diff --git a/src/sasctl/_services/model_repository.py b/src/sasctl/_services/model_repository.py index dfbbb95d..74fb6446 100644 --- a/src/sasctl/_services/model_repository.py +++ b/src/sasctl/_services/model_repository.py @@ -8,10 +8,17 @@ import datetime from warnings import warn +import requests +from requests.exceptions import HTTPError +import urllib -from ..core import HTTPError, current_session, delete, get, sasctl_command +# import traceback +# import sys + +from ..core import current_session, delete, get, sasctl_command, RestObj from .service import Service + FUNCTIONS = { "Analytical", "Classification", @@ -615,11 +622,222 @@ def list_model_versions(cls, model): list """ - model = cls.get_model(model) - if cls.get_model_link(model, "modelVersions") is None: - raise ValueError("Unable to retrieve versions for model '%s'" % model) - return cls.request_link(model, "modelVersions") + if current_session().version_info() < 4: + model = cls.get_model(model) + if cls.get_model_link(model, "modelVersions") is None: + raise ValueError("Unable to retrieve versions for model '%s'" % model) + + return cls.request_link(model, "modelVersions") + else: + link = cls.get_model_link(model, "modelHistory") + if link is None: + raise ValueError( + "Cannot find link for version history for model '%s'" % model + ) + + modelHistory = cls.request_link( + link, + "modelHistory", + headers={"Accept": "application/vnd.sas.collection+json"}, + ) + + if isinstance(modelHistory, RestObj): + return [modelHistory] + return modelHistory + + @classmethod + def get_model_version(cls, model, version_id): + """Get a specific version of a model. + + Parameters + ---------- + model : str or dict + The name, id, or dictionary representation of a model. + version_id: str + The id of a model version. + + Returns + ------- + RestObj + + """ + + model_history = cls.list_model_versions(model) + + for item in model_history: + if item["id"] == version_id: + return cls.request_link( + item, + "self", + headers={"Accept": "application/vnd.sas.models.model.version+json"}, + ) + + raise ValueError("The version id specified could not be found.") + + @classmethod + def get_model_with_versions(cls, model): + """Get the current model with its version history. + + Parameters + ---------- + model : str or dict + The name, id, or dictionary representation of a model. + + Returns + ------- + list + + """ + + if cls.is_uuid(model): + model_id = model + elif isinstance(model, dict) and "id" in model: + model_id = model["id"] + else: + model = cls.get_model(model) + if not model: + raise HTTPError( + "This model may not exist in a project or the model may not exist at all." + ) + model_id = model["id"] + + versions_uri = f"/models/{model_id}/versions" + try: + version_history = cls.request( + "GET", + versions_uri, + headers={"Accept": "application/vnd.sas.collection+json"}, + ) + except urllib.error.HTTPError as e: + raise HTTPError( + f"Request failed: Model id may be referencing a non-existing model." + ) from None + + if isinstance(version_history, RestObj): + return [version_history] + + return version_history + + @classmethod + def get_model_or_version(cls, model, version_id): + """Get a specific version of a model but if model id and version id are the same, the current model is returned. + + Parameters + ---------- + model : str or dict + The name, id, or dictionary representation of a model. + version_id: str + The id of a model version. + + Returns + ------- + RestObj + + """ + + version_history = cls.get_model_with_versions(model) + + for item in version_history: + if item["id"] == version_id: + return cls.request_link( + item, + "self", + headers={ + "Accept": "application/vnd.sas.models.model.version+json, application/vnd.sas.models.model+json" + }, + ) + + raise ValueError("The version id specified could not be found.") + + @classmethod + def get_model_version_contents(cls, model, version_id): + """Get the contents of a model version. + + Parameters + ---------- + model : str or dict + The name, id, or dictionary representation of a model. + version_id: str + The id of a model version. + + Returns + ------- + list + + """ + model_version = cls.get_model_version(model, version_id) + version_contents = cls.request_link( + model_version, + "contents", + headers={"Accept": "application/vnd.sas.collection+json"}, + ) + + if isinstance(version_contents, RestObj): + return [version_contents] + + return version_contents + + @classmethod + def get_model_version_content_metadata(cls, model, version_id, content_id): + """Get the content metadata header information for a model version. + + Parameters + ---------- + model : str or dict + The name, id, or dictionary representation of a model. + version_id: str + The id of a model version. + content_id: str + The id of the content file. + + Returns + ------- + RestObj + + """ + model_version_contents = cls.get_model_version_contents(model, version_id) + + for item in model_version_contents: + if item["id"] == content_id: + return cls.request_link( + item, + "self", + headers={"Accept": "application/vnd.sas.models.model.content+json"}, + ) + + raise ValueError("The content id specified could not be found.") + + @classmethod + def get_model_version_content(cls, model, version_id, content_id): + """Get the specific content inside the content file for a model version. + + Parameters + ---------- + model : str or dict + The name, id, or dictionary representation of a model. + version_id: str + The id of a model version. + content_id: str + The id of the specific content file. + + Returns + ------- + list + + """ + + metadata = cls.get_model_version_content_metadata(model, version_id, content_id) + version_content_file = cls.request_link( + metadata, "content", headers={"Accept": "text/plain"} + ) + + if version_content_file is None: + raise HTTPError("Something went wrong while accessing the metadata file.") + + if isinstance(version_content_file, RestObj): + return [version_content_file] + return version_content_file @classmethod def copy_analytic_store(cls, model): diff --git a/tests/unit/test_model_repository.py b/tests/unit/test_model_repository.py index 9232896b..bf4f9284 100644 --- a/tests/unit/test_model_repository.py +++ b/tests/unit/test_model_repository.py @@ -13,6 +13,10 @@ from sasctl import current_session from sasctl.services import model_repository as mr +from sasctl.core import RestObj, VersionInfo, request +from requests import HTTPError +import urllib.error + def test_create_model(): MODEL_NAME = "Test Model" @@ -230,3 +234,343 @@ def test_add_model_content(): assert post.call_args[1]["files"] == { "files": ("test.pkl", binary_data, "application/image") } + + +def test_create_model_version(): + model_mock = {"id": 12345} + new_model_mock = {"id": 34567} + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model", + side_effect=[ + model_mock, + model_mock, + new_model_mock, + model_mock, + new_model_mock, + ], + ) as get_model: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model_link" + ) as get_model_link: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + get_model_link_mock = { + "method": "GET", + "rel": "modelHistory", + "href": "/modelRepository/models/12345/history", + "uri": "/modelRepository/models/12345/history", + "type": "application/vnd.sas.collection", + "responseItemType": "application/vnd.sas.models.model.version", + } + + get_model_link.return_value = None + with pytest.raises(ValueError): + mr.create_model_version(model=model_mock, minor=False) + + get_model_link.return_value = get_model_link_mock + response = mr.create_model_version(model=model_mock, minor=False) + + request_link.assert_called_with( + model_mock, "addModelVersion", json={"option": "major"} + ) + assert response == new_model_mock + + response = mr.create_model_version(model=model_mock, minor=True) + request_link.assert_called_with( + model_mock, "addModelVersion", json={"option": "minor"} + ) + assert response == new_model_mock + + +def test_list_model_versions(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model_link" + ) as get_model_link: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + with mock.patch("sasctl.core.Session.version_info") as version: + version.return_value = VersionInfo(4) + get_model_link.return_value = None + with pytest.raises(ValueError): + mr.list_model_versions( + model="12345", + ) + + get_model_link_mock = { + "method": "GET", + "rel": "modelHistory", + "href": "/modelRepository/models/12345/history", + "uri": "/modelRepository/models/12345/history", + "type": "application/vnd.sas.collection", + "responseItemType": "application/vnd.sas.models.model.version", + } + + get_model_link.return_value = get_model_link_mock + + response = mr.list_model_versions(model="12345") + assert response + + request_link.return_value = RestObj({"id": "12345"}) + response = mr.list_model_versions(model="12345") + assert isinstance(response, list) + + request_link.return_value = [ + RestObj({"id": "12345"}), + RestObj({"id": "3456"}), + ] + response = mr.list_model_versions(model="12345") + assert isinstance(response, list) + + +def test_get_model_version(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.list_model_versions" + ) as list_model_versions: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + + list_model_versions_mock = [ + { + "id": "123", + "links": [ + { + "method": "GET", + "rel": "self", + "href": "/modelRepository/models/abc/history/123", + "uri": "/modelRepository/models/abc/history/123", + "type": "demo", + } + ], + }, + {"id": "345", "links": []}, + ] + + list_model_versions.return_value = list_model_versions_mock + + with pytest.raises(ValueError): + mr.get_model_version(model="000", version_id="000") + + response = mr.get_model_version(model="000", version_id="123") + request_link.assert_called_once_with( + list_model_versions_mock[0], + "self", + headers={"Accept": "application/vnd.sas.models.model.version+json"}, + ) + + +def test_get_model_with_versions(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.is_uuid" + ) as is_uuid: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model" + ) as get_model: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request" + ) as request: + + is_uuid.return_value = True + response = mr.get_model_with_versions(model="12345") + assert response + + is_uuid.return_value = False + get_model.return_value = None + response = mr.get_model_with_versions(model={"id": "12345"}) + assert response + + is_uuid.return_value = False + get_model.return_value = None + with pytest.raises(HTTPError): + mr.get_model_with_versions(model=RestObj) + + is_uuid.return_value = False + get_model.return_value = RestObj({"id": "123456"}) + request.side_effect = urllib.error.HTTPError( + url="http://demo.sas.com", + code=404, + msg="Not Found", + hdrs=None, + fp=None, + ) + with pytest.raises(HTTPError): + mr.get_model_with_versions(model=RestObj) + + request.side_effect = None + request.return_value = RestObj({"id": "12345"}) + response = mr.get_model_with_versions(model=RestObj) + assert isinstance(response, list) + + request.return_value = [ + RestObj({"id": "12345"}), + RestObj({"id": "3456"}), + ] + response = mr.get_model_with_versions(model=RestObj) + assert isinstance(response, list) + + request.assert_any_call( + "GET", + "/models/123456/versions", + headers={"Accept": "application/vnd.sas.collection+json"}, + ) + + request.assert_any_call( + "GET", + "/models/12345/versions", + headers={"Accept": "application/vnd.sas.collection+json"}, + ) + + +def test_get_model_or_version(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model_with_versions" + ) as get_model_with_versions: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + + get_model_with_versions_mock = [ + { + "id": "123", + "links": [ + { + "method": "GET", + "rel": "self", + "href": "/modelRepository/models/abc/history/123", + "uri": "/modelRepository/models/abc/history/123", + "type": "demo", + } + ], + }, + {"id": "345", "links": []}, + ] + + get_model_with_versions.return_value = [] + with pytest.raises(ValueError): + mr.get_model_or_version(model="000", version_id="000") + + get_model_with_versions.return_value = get_model_with_versions_mock + with pytest.raises(ValueError): + mr.get_model_or_version(model="000", version_id="000") + + response = mr.get_model_or_version(model="000", version_id="123") + request_link.assert_called_once_with( + get_model_with_versions_mock[0], + "self", + headers={ + "Accept": "application/vnd.sas.models.model.version+json, application/vnd.sas.models.model+json" + }, + ) + + +def test_get_model_version_contents(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model_version" + ) as get_model_version: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + + get_model_version.return_value = {"id": "000"} + request_link.return_value = RestObj({"id": "12345"}) + response = mr.get_model_version_contents(model="12345", version_id="3456") + assert isinstance(response, list) + + request_link.return_value = [ + RestObj({"id": "12345"}), + RestObj({"id": "3456"}), + ] + response = mr.get_model_version_contents(model="12345", version_id="3456") + assert isinstance(response, list) + + request_link.assert_any_call( + {"id": "000"}, + "contents", + headers={"Accept": "application/vnd.sas.collection+json"}, + ) + + +def test_get_model_version_content_metadata(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model_version_contents" + ) as get_model_version_contents: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + + get_model_with_metadata_mock = [ + { + "id": "123", + "links": [ + { + "method": "GET", + "rel": "self", + "href": "/modelRepository/models/abc/history/123", + "uri": "/modelRepository/models/abc/history/123", + "type": "demo", + } + ], + }, + {"id": "345", "links": []}, + ] + + get_model_version_contents.return_value = [] + with pytest.raises(ValueError): + mr.get_model_version_content_metadata( + model="000", version_id="123", content_id="000" + ) + + get_model_version_contents.return_value = get_model_with_metadata_mock + with pytest.raises(ValueError): + mr.get_model_version_content_metadata( + model="abc", version_id="123", content_id="000" + ) + + response = mr.get_model_version_content_metadata( + model="abc", version_id="123", content_id="345" + ) + assert response + request_link.assert_called_once_with( + get_model_with_metadata_mock[1], + "self", + headers={"Accept": "application/vnd.sas.models.model.content+json"}, + ) + + +def test_get_model_version_content(): + with mock.patch( + "sasctl._services.model_repository.ModelRepository.get_model_version_content_metadata" + ) as get_model_version_content_metadata: + with mock.patch( + "sasctl._services.model_repository.ModelRepository.request_link" + ) as request_link: + + get_model_version_content_metadata.return_value = {"id": 000} + request_link.return_value = None + with pytest.raises(HTTPError): + mr.get_model_version_content( + model="abc", version_id="123", content_id="345" + ) + + request_link.return_value = RestObj({"id": "12345"}) + response = mr.get_model_version_content( + model="abc", version_id="123", content_id="345" + ) + assert isinstance(response, list) + + request_link.return_value = [ + RestObj({"id": "12345"}), + RestObj({"id": "3456"}), + ] + response = mr.get_model_version_content( + model="abc", version_id="123", content_id="345" + ) + assert isinstance(response, list) + + request_link.assert_any_call( + {"id": 000}, + "content", + headers={"Accept": "text/plain"}, + )