diff --git a/clients/python/src/model_registry/core.py b/clients/python/src/model_registry/core.py index 270d89b8b..9e57da09c 100644 --- a/clients/python/src/model_registry/core.py +++ b/clients/python/src/model_registry/core.py @@ -1,7 +1,6 @@ """Client for the model registry.""" -from __future__ import annotations -from collections.abc import Sequence +from __future__ import annotations from ml_metadata.proto import MetadataStoreClientConfig @@ -130,7 +129,7 @@ def get_registered_model_by_params( def get_registered_models( self, options: ListOptions | None = None - ) -> Sequence[RegisteredModel]: + ) -> list[RegisteredModel]: """Fetch registered models. Args: @@ -194,7 +193,7 @@ def get_model_version_by_id(self, model_version_id: str) -> ModelVersion | None: def get_model_versions( self, registered_model_id: str, options: ListOptions | None = None - ) -> Sequence[ModelVersion]: + ) -> list[ModelVersion]: """Fetch model versions by registered model ID. Args: @@ -344,7 +343,7 @@ def get_model_artifacts( self, model_version_id: str | None = None, options: ListOptions | None = None, - ) -> Sequence[ModelArtifact]: + ) -> list[ModelArtifact]: """Fetches model artifacts. Args: diff --git a/clients/python/src/model_registry/store/wrapper.py b/clients/python/src/model_registry/store/wrapper.py index 57c56dd78..750d0642d 100644 --- a/clients/python/src/model_registry/store/wrapper.py +++ b/clients/python/src/model_registry/store/wrapper.py @@ -124,7 +124,7 @@ def put_context(self, context: Context) -> int: def _filter_type( self, type_name: str, protos: Sequence[ProtoType] - ) -> Sequence[ProtoType]: + ) -> list[ProtoType]: return [proto for proto in protos if proto.type == type_name] def get_context( @@ -168,9 +168,7 @@ def get_context( return None - def get_contexts( - self, ctx_type_name: str, options: ListOptions - ) -> Sequence[Context]: + def get_contexts(self, ctx_type_name: str, options: ListOptions) -> list[Context]: """Get contexts from the store. Args: @@ -179,6 +177,11 @@ def get_contexts( Returns: Contexts. + + Raises: + TypeNotFoundException: If the type doesn't exist. + ServerException: If there was an error getting the type. + StoreException: Invalid arguments. """ # TODO: should we make options optional? # if options is not None: @@ -195,9 +198,11 @@ def get_contexts( # else: # contexts = self._mlmd_store.get_contexts_by_type(ctx_type_name) - if not contexts: + if not contexts and ctx_type_name not in [ + t.name for t in self._mlmd_store.get_context_types() + ]: msg = f"Context type {ctx_type_name} does not exist" - raise StoreException(msg) + raise TypeNotFoundException(msg) return contexts @@ -309,9 +314,7 @@ def get_attributed_artifact(self, art_type_name: str, ctx_id: int) -> Artifact: return None - def get_artifacts( - self, art_type_name: str, options: ListOptions - ) -> Sequence[Artifact]: + def get_artifacts(self, art_type_name: str, options: ListOptions) -> list[Artifact]: """Get artifacts from the store. Args: @@ -320,6 +323,11 @@ def get_artifacts( Returns: Artifacts. + + Raises: + TypeNotFoundException: If the type doesn't exist. + ServerException: If there was an error getting the type. + StoreException: Invalid arguments. """ try: artifacts = self._mlmd_store.get_artifacts(options) @@ -331,8 +339,10 @@ def get_artifacts( raise ServerException(msg) from e artifacts = self._filter_type(art_type_name, artifacts) - if not artifacts: + if not artifacts and art_type_name not in [ + t.name for t in self._mlmd_store.get_artifact_types() + ]: msg = f"Artifact type {art_type_name} does not exist" - raise StoreException(msg) + raise TypeNotFoundException(msg) return artifacts diff --git a/clients/python/tests/store/test_wrapper.py b/clients/python/tests/store/test_wrapper.py index c2d379de1..c6f4dbe2d 100644 --- a/clients/python/tests/store/test_wrapper.py +++ b/clients/python/tests/store/test_wrapper.py @@ -17,6 +17,7 @@ TypeNotFoundException, ) from model_registry.store import MLMDStore +from model_registry.types.options import MLMDListOptions @pytest.fixture() @@ -53,6 +54,28 @@ def test_get_undefined_context_type_id(plain_wrapper: MLMDStore): plain_wrapper.get_type_id(Context, "undefined") +@pytest.mark.usefixtures("artifact") +def test_get_no_artifacts(plain_wrapper: MLMDStore): + arts = plain_wrapper.get_artifacts("test_artifact", MLMDListOptions()) + assert arts == [] + + +def test_get_undefined_artifacts(plain_wrapper: MLMDStore): + with pytest.raises(TypeNotFoundException): + plain_wrapper.get_artifacts("undefined", MLMDListOptions()) + + +@pytest.mark.usefixtures("context") +def test_get_no_contexts(plain_wrapper: MLMDStore): + ctxs = plain_wrapper.get_contexts("test_context", MLMDListOptions()) + assert ctxs == [] + + +def test_get_undefined_contexts(plain_wrapper: MLMDStore): + with pytest.raises(TypeNotFoundException): + plain_wrapper.get_contexts("undefined", MLMDListOptions()) + + def test_put_invalid_artifact(plain_wrapper: MLMDStore, artifact: Artifact): artifact.properties["null"].int_value = 0