diff --git a/src/zenml/artifacts/external_artifact_config.py b/src/zenml/artifacts/external_artifact_config.py index 49290fdff12..4c32493472e 100644 --- a/src/zenml/artifacts/external_artifact_config.py +++ b/src/zenml/artifacts/external_artifact_config.py @@ -93,8 +93,6 @@ def _get_artifact_from_model( RuntimeError: If `model_artifact_name` is set, but `model_name` is empty and model configuration is missing in @step and @pipeline. """ - from zenml.model.model_config import ModelConfig - if self.model_name is None: if model_config is None: raise RuntimeError( @@ -104,13 +102,18 @@ def _get_artifact_from_model( ) self.model_name = model_config.name self.model_version = model_config.version - - _model_config = ModelConfig( - name=self.model_name, - version=self.model_version, - suppress_warnings=True, - ) - model_version = _model_config._get_model_version() + if ( + model_config is None + or self.model_name != model_config.name + or self.model_version != model_config.version + ): + from zenml.model.model_config import ModelConfig + + model_config = ModelConfig( + name=self.model_name, + version=self.model_version, + ) + model_version = model_config._get_model_version() for artifact_getter in [ model_version.get_artifact_object, diff --git a/src/zenml/cli/model.py b/src/zenml/cli/model.py index f1052403040..dc5aa2a8e78 100644 --- a/src/zenml/cli/model.py +++ b/src/zenml/cli/model.py @@ -159,7 +159,7 @@ def register_model( audience=audience, use_cases=use_cases, trade_offs=tradeoffs, - ethic=ethical, + ethics=ethical, limitations=limitations, tags=tag, user=Client().active_user.id, @@ -265,7 +265,7 @@ def update_model( audience=audience, use_cases=use_cases, trade_offs=tradeoffs, - ethic=ethical, + ethics=ethical, limitations=limitations, tags=tag, user=Client().active_user.id, @@ -451,7 +451,7 @@ def update_model_version( def _print_artifacts_links_generic( model_name_or_id: str, model_version_name_or_number_or_id: str, - only_artifacts: bool = False, + only_artifact_objects: bool = False, only_deployments: bool = False, only_model_objects: bool = False, **kwargs: Any, @@ -461,36 +461,42 @@ def _print_artifacts_links_generic( Args: model_name_or_id: The ID or name of the model containing version. model_version_name_or_number_or_id: The name, number or ID of the model version. - only_artifacts: If set, only print artifacts. + only_artifact_objects: If set, only print artifacts. only_deployments: If set, only print deployments. only_model_objects: If set, only print model objects. **kwargs: Keyword arguments to filter models. """ model_version = Client().get_model_version( model_name_or_id=model_name_or_id, - model_version_name_or_number_or_id=model_version_name_or_number_or_id, + model_version_name_or_number_or_id=None + if model_version_name_or_number_or_id == "0" + else model_version_name_or_number_or_id, + ) + type_ = ( + "artifacts" + if only_artifact_objects + else "deployments" + if only_deployments + else "model objects" ) if ( - (only_artifacts and not model_version.artifact_object_ids) + (only_artifact_objects and not model_version.artifact_object_ids) or (only_deployments and not model_version.deployment_ids) or (only_model_objects and not model_version.model_object_ids) ): - _type = ( - "artifacts" - if only_artifacts - else "deployments" - if only_deployments - else "model objects" - ) - cli_utils.declare(f"No {_type} linked to the model version found.") + cli_utils.declare(f"No {type_} linked to the model version found.") return + cli_utils.title( + f"{type_} linked to the model version `{model_version.name}[{model_version.number}]`:" + ) + links = Client().list_model_version_artifact_links( ModelVersionArtifactFilterModel( model_id=model_version.model.id, model_version_id=model_version.id, - only_artifacts=only_artifacts, + only_artifacts=only_artifact_objects, only_deployments=only_deployments, only_model_objects=only_model_objects, **kwargs, @@ -515,7 +521,7 @@ def _print_artifacts_links_generic( help="List artifacts linked to a model version.", ) @click.argument("model_name_or_id") -@click.argument("model_version_name_or_number_or_id") +@click.argument("model_version_name_or_number_or_id", default="0") @cli_utils.list_options(ModelVersionArtifactFilterModel) def list_model_version_artifacts( model_name_or_id: str, @@ -527,12 +533,13 @@ def list_model_version_artifacts( Args: model_name_or_id: The ID or name of the model containing version. model_version_name_or_number_or_id: The name, number or ID of the model version. + Or use 0 for the latest version. **kwargs: Keyword arguments to filter models. """ _print_artifacts_links_generic( model_name_or_id=model_name_or_id, model_version_name_or_number_or_id=model_version_name_or_number_or_id, - only_artifacts=True, + only_artifact_objects=True, **kwargs, ) @@ -542,7 +549,7 @@ def list_model_version_artifacts( help="List model objects linked to a model version.", ) @click.argument("model_name_or_id") -@click.argument("model_version_name_or_number_or_id") +@click.argument("model_version_name_or_number_or_id", default="0") @cli_utils.list_options(ModelVersionArtifactFilterModel) def list_model_version_model_objects( model_name_or_id: str, @@ -554,6 +561,7 @@ def list_model_version_model_objects( Args: model_name_or_id: The ID or name of the model containing version. model_version_name_or_number_or_id: The name, number or ID of the model version. + Or use 0 for the latest version. **kwargs: Keyword arguments to filter models. """ _print_artifacts_links_generic( @@ -569,7 +577,7 @@ def list_model_version_model_objects( help="List deployments linked to a model version.", ) @click.argument("model_name_or_id") -@click.argument("model_version_name_or_number_or_id") +@click.argument("model_version_name_or_number_or_id", default="0") @cli_utils.list_options(ModelVersionArtifactFilterModel) def list_model_version_deployments( model_name_or_id: str, @@ -581,6 +589,7 @@ def list_model_version_deployments( Args: model_name_or_id: The ID or name of the model containing version. model_version_name_or_number_or_id: The name, number or ID of the model version. + Or use 0 for the latest version. **kwargs: Keyword arguments to filter models. """ _print_artifacts_links_generic( @@ -596,7 +605,7 @@ def list_model_version_deployments( help="List pipeline runs of a model version.", ) @click.argument("model_name_or_id") -@click.argument("model_version_name_or_number_or_id") +@click.argument("model_version_name_or_number_or_id", default="0") @cli_utils.list_options(ModelVersionPipelineRunFilterModel) def list_model_version_pipeline_runs( model_name_or_id: str, @@ -608,16 +617,22 @@ def list_model_version_pipeline_runs( Args: model_name_or_id: The ID or name of the model containing version. model_version_name_or_number_or_id: The name, number or ID of the model version. + Or use 0 for the latest version. **kwargs: Keyword arguments to filter models. """ model_version = Client().get_model_version( model_name_or_id=model_name_or_id, - model_version_name_or_number_or_id=model_version_name_or_number_or_id, + model_version_name_or_number_or_id=None + if model_version_name_or_number_or_id == "0" + else model_version_name_or_number_or_id, ) if not model_version.pipeline_run_ids: cli_utils.declare("No pipeline runs attached to model version found.") return + cli_utils.title( + f"Pipeline runs linked to the model version `{model_version.name}[{model_version.number}]`:" + ) links = Client().list_model_version_pipeline_run_links( ModelVersionPipelineRunFilterModel( diff --git a/src/zenml/config/pipeline_configurations.py b/src/zenml/config/pipeline_configurations.py index 8e2f4f55640..5797ee563ab 100644 --- a/src/zenml/config/pipeline_configurations.py +++ b/src/zenml/config/pipeline_configurations.py @@ -19,11 +19,10 @@ from zenml.config.constants import DOCKER_SETTINGS_KEY from zenml.config.source import Source, convert_source_validator from zenml.config.strict_base_model import StrictBaseModel -from zenml.models.model_base_model import ModelConfigModel +from zenml.model.model_config import ModelConfig if TYPE_CHECKING: from zenml.config import DockerSettings - from zenml.model.model_config import ModelConfig from zenml.config.base_settings import BaseSettings, SettingsOrDict @@ -41,28 +40,12 @@ class PipelineConfigurationUpdate(StrictBaseModel): extra: Dict[str, Any] = {} failure_hook_source: Optional[Source] = None success_hook_source: Optional[Source] = None - model_config_model: Optional[ModelConfigModel] = None + model_config: Optional[ModelConfig] = None _convert_source = convert_source_validator( "failure_hook_source", "success_hook_source" ) - @property - def model_config(self) -> Optional["ModelConfig"]: - """Gets a ModelConfig object out of the model config model. - - This is a technical circular import resolver. - - Returns: - The model config object, if configured. - """ - if self.model_config_model is None: - return None - - from zenml.model.model_config import ModelConfig - - return ModelConfig.parse_obj(self.model_config_model.dict()) - class PipelineConfiguration(PipelineConfigurationUpdate): """Pipeline configuration class.""" diff --git a/src/zenml/config/step_configurations.py b/src/zenml/config/step_configurations.py index 3d3ecdd3ea1..cbb95d5c5ce 100644 --- a/src/zenml/config/step_configurations.py +++ b/src/zenml/config/step_configurations.py @@ -34,12 +34,11 @@ from zenml.config.source import Source, convert_source_validator from zenml.config.strict_base_model import StrictBaseModel from zenml.logger import get_logger -from zenml.models.model_base_model import ModelConfigModel +from zenml.model.model_config import ModelConfig from zenml.utils import deprecation_utils if TYPE_CHECKING: from zenml.config import DockerSettings, ResourceSettings - from zenml.model.model_config import ModelConfig logger = get_logger(__name__) @@ -135,7 +134,7 @@ class StepConfigurationUpdate(StrictBaseModel): extra: Dict[str, Any] = {} failure_hook_source: Optional[Source] = None success_hook_source: Optional[Source] = None - model_config_model: Optional[ModelConfigModel] = None + model_config: Optional[ModelConfig] = None outputs: Mapping[str, PartialArtifactConfiguration] = {} @@ -146,22 +145,6 @@ class StepConfigurationUpdate(StrictBaseModel): "name" ) - @property - def model_config(self) -> Optional["ModelConfig"]: - """Gets a ModelConfig object out of the model config model. - - This is a technical circular import resolver. - - Returns: - The model config object, if configured. - """ - if self.model_config_model is None: - return None - - from zenml.model.model_config import ModelConfig - - return ModelConfig.parse_obj(self.model_config_model.dict()) - class PartialStepConfiguration(StepConfigurationUpdate): """Class representing a partial step configuration.""" diff --git a/src/zenml/model/artifact_config.py b/src/zenml/model/artifact_config.py index dd9f2ae1656..ac3499e7655 100644 --- a/src/zenml/model/artifact_config.py +++ b/src/zenml/model/artifact_config.py @@ -21,14 +21,9 @@ from zenml.enums import ModelStages from zenml.exceptions import StepContextError from zenml.logger import get_logger -from zenml.models.model_models import ( - ModelVersionArtifactFilterModel, - ModelVersionArtifactRequestModel, -) if TYPE_CHECKING: from zenml.model.model_config import ModelConfig - from zenml.models import ModelResponseModel, ModelVersionResponseModel logger = get_logger(__name__) @@ -87,7 +82,6 @@ def _model_config(self) -> "ModelConfig": name=self.model_name, version=self.model_version, create_new_model_version=False, - suppress_warnings=True, ) return on_the_fly_config @@ -100,27 +94,10 @@ def _model_config(self) -> "ModelConfig": # Return the model from the context return model_config - @property - def _model(self) -> "ModelResponseModel": - """Get the `ModelResponseModel`. - - Returns: - ModelResponseModel: The fetched or created model. - """ - return self._model_config.get_or_create_model() - - @property - def _model_version(self) -> "ModelVersionResponseModel": - """Get the `ModelVersionResponseModel`. - - Returns: - ModelVersionResponseModel: The model version. - """ - return self._model_config.get_or_create_model_version() - def _link_to_model_version( self, artifact_uuid: UUID, + model_config: "ModelConfig", is_model_object: bool = False, is_deployment: bool = False, ) -> None: @@ -130,14 +107,21 @@ def _link_to_model_version( Args: artifact_uuid: The UUID of the artifact to link. + model_config: The model configuration from caller. is_model_object: Whether the artifact is a model object. Defaults to False. is_deployment: Whether the artifact is a deployment. Defaults to False. """ from zenml.client import Client + from zenml.models.model_models import ( + ModelVersionArtifactFilterModel, + ModelVersionArtifactRequestModel, + ) # Create a ZenML client client = Client() + model_version = model_config._get_model_version() + artifact_name = self.artifact_name if artifact_name is None: artifact = client.zen_store.get_artifact(artifact_id=artifact_uuid) @@ -149,8 +133,8 @@ def _link_to_model_version( workspace=client.active_workspace.id, name=artifact_name, artifact=artifact_uuid, - model=self._model.id, - model_version=self._model_version.id, + model=model_version.model.id, + model_version=model_version.id, is_model_object=is_model_object, is_deployment=is_deployment, overwrite=self.overwrite, @@ -164,11 +148,13 @@ def _link_to_model_version( user_id=client.active_user.id, workspace_id=client.active_workspace.id, name=artifact_name, - model_id=self._model.id, - model_version_id=self._model_version.id, + model_id=model_version.model.id, + model_version_id=model_version.id, only_artifacts=not (is_model_object or is_deployment), only_deployments=is_deployment, only_model_objects=is_model_object, + pipeline_name=self._pipeline_name, + step_name=self._step_name, ) ) if len(existing_links): @@ -178,8 +164,8 @@ def _link_to_model_version( f"Existing artifact link(s) `{artifact_name}` found and will be deleted." ) client.zen_store.delete_model_version_artifact_link( - model_name_or_id=self._model.id, - model_version_name_or_id=self._model_version.id, + model_name_or_id=model_version.model.id, + model_version_name_or_id=model_version.id, model_version_artifact_link_name_or_id=artifact_name, ) else: @@ -189,16 +175,17 @@ def _link_to_model_version( client.zen_store.create_model_version_artifact_link(request) def link_to_model( - self, - artifact_uuid: UUID, + self, artifact_uuid: UUID, model_config: "ModelConfig" ) -> None: """Link artifact to the model version. Args: - artifact_uuid (UUID): The UUID of the artifact to link. + artifact_uuid: The UUID of the artifact to link. + model_config: The model configuration from caller. """ self._link_to_model_version( artifact_uuid, + model_config=model_config, is_model_object=self.IS_MODEL_ARTIFACT, is_deployment=self.IS_DEPLOYMENT_ARTIFACT, ) diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index 4ce9264a07b..4fc00954e35 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -13,13 +13,21 @@ # permissions and limitations under the License. """ModelConfig user facing interface to pass into pipeline or step.""" -from typing import TYPE_CHECKING, Optional +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Union, +) -from pydantic import PrivateAttr +from pydantic import BaseModel, root_validator +from zenml.constants import RUNNING_MODEL_VERSION +from zenml.enums import ExecutionStatus, ModelStages from zenml.exceptions import EntityExistsError from zenml.logger import get_logger -from zenml.models.model_base_model import ModelConfigModel if TYPE_CHECKING: from zenml.models.model_models import ( @@ -30,12 +38,20 @@ logger = get_logger(__name__) -class ModelConfig(ModelConfigModel): +class ModelConfig(BaseModel): """ModelConfig class to pass into pipeline or step to set it into a model context. name: The name of the model. + license: The license under which the model is created. + description: The description of the model. + audience: The target audience of the model. + use_cases: The use cases of the model. + limitations: The known limitations of the model. + trade_offs: The tradeoffs of the model. + ethics: The ethical implications of the model. + tags: Tags associated with the model. version: The model version name, number or stage is optional and points model context - to a specific version/stage, if skipped and `create_new_model_version` is False - + to a specific version/stage. If skipped and `create_new_model_version` is False - latest model version will be used. version_description: The description of the model version. create_new_model_version: Whether to create a new model version during execution @@ -44,10 +60,112 @@ class ModelConfig(ModelConfigModel): delete_new_version_on_failure: Whether to delete failed runs with new versions for later recovery from it. """ - _model: Optional["ModelResponseModel"] = PrivateAttr(default=None) - _model_version: Optional["ModelVersionResponseModel"] = PrivateAttr( - default=None - ) + name: str + license: Optional[str] + description: Optional[str] + audience: Optional[str] + use_cases: Optional[str] + limitations: Optional[str] + trade_offs: Optional[str] + ethics: Optional[str] + tags: Optional[List[str]] + version: Optional[Union[ModelStages, int, str]] + version_description: Optional[str] + create_new_model_version: bool = False + save_models_to_registry: bool = True + delete_new_version_on_failure: bool = True + + suppress_class_validation_warnings: bool = False + + class Config: + """Config class.""" + + smart_union = True + + @root_validator(pre=True) + def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Validate all in one. + + Args: + values: Dict of values. + + Returns: + Dict of validated values. + + Raises: + ValueError: If validation failed on one of the checks. + """ + create_new_model_version = values.get( + "create_new_model_version", False + ) + suppress_class_validation_warnings = values.get( + "suppress_class_validation_warnings", False + ) + version = values.get("version", None) + + if create_new_model_version: + misuse_message = ( + "`version` set to {set} cannot be used with `create_new_model_version`." + "You can leave it default or set to a non-stage and non-numeric string.\n" + "Examples:\n" + " - `version` set to 1 or '1' is interpreted as a version number\n" + " - `version` set to 'production' is interpreted as a stage\n" + " - `version` set to 'my_first_version_in_2023' is a valid version to be created\n" + " - `version` set to 'My Second Version!' is a valid version to be created\n" + ) + if isinstance(version, ModelStages) or version in [ + stage.value for stage in ModelStages + ]: + raise ValueError( + misuse_message.format(set="a `ModelStages` instance") + ) + if str(version).isnumeric(): + raise ValueError(misuse_message.format(set="a numeric value")) + if version is None: + if not suppress_class_validation_warnings: + logger.info( + "Creation of new model version was requested, but no version name was explicitly provided. " + f"Setting `version` to `{RUNNING_MODEL_VERSION}`." + ) + values["version"] = RUNNING_MODEL_VERSION + if ( + version in [stage.value for stage in ModelStages] + and not suppress_class_validation_warnings + ): + logger.info( + f"`version` `{version}` matches one of the possible `ModelStages` and will be fetched using stage." + ) + if str(version).isnumeric() and not suppress_class_validation_warnings: + logger.info( + f"`version` `{version}` is numeric and will be fetched using version number." + ) + values["suppress_class_validation_warnings"] = True + return values + + def _validate_config_in_runtime(self) -> None: + """Validate that config doesn't conflict with runtime environment. + + Raises: + RuntimeError: If recovery not requested, but model version already exists. + RuntimeError: If there is unfinished pipeline run for requested new model version. + """ + try: + model_version = self._get_model_version() + if self.create_new_model_version: + for run_name, run in model_version.pipeline_runs.items(): + if run.status == ExecutionStatus.RUNNING: + raise RuntimeError( + f"New model version was requested, but pipeline run `{run_name}` " + f"is still running with version `{model_version.name}`." + ) + + if self.delete_new_version_on_failure: + raise RuntimeError( + f"Cannot create version `{self.version}` " + f"for model `{self.name}` since it already exists" + ) + except KeyError: + self.get_or_create_model_version() def get_or_create_model(self) -> "ModelResponseModel": """This method should get or create a model from Model Control Plane. @@ -57,15 +175,12 @@ def get_or_create_model(self) -> "ModelResponseModel": Returns: The model based on configuration. """ - if self._model is not None: - return self._model - from zenml.client import Client from zenml.models.model_models import ModelRequestModel zenml_client = Client() try: - self._model = zenml_client.get_model(model_name_or_id=self.name) + model = zenml_client.get_model(model_name_or_id=self.name) except KeyError: model_request = ModelRequestModel( name=self.name, @@ -75,22 +190,20 @@ def get_or_create_model(self) -> "ModelResponseModel": use_cases=self.use_cases, limitations=self.limitations, trade_offs=self.trade_offs, - ethic=self.ethic, + ethics=self.ethics, tags=self.tags, user=zenml_client.active_user.id, workspace=zenml_client.active_workspace.id, ) model_request = ModelRequestModel.parse_obj(model_request) try: - self._model = zenml_client.create_model(model=model_request) + model = zenml_client.create_model(model=model_request) logger.info(f"New model `{self.name}` was created implicitly.") except EntityExistsError: # this is backup logic, if model was created somehow in between get and create calls - self._model = zenml_client.get_model( - model_name_or_id=self.name - ) + model = zenml_client.get_model(model_name_or_id=self.name) - return self._model + return model def _create_model_version( self, model: "ModelResponseModel" @@ -103,9 +216,6 @@ def _create_model_version( Returns: The model version based on configuration. """ - if self._model_version is not None: - return self._model_version - from zenml.client import Client from zenml.models.model_models import ModelVersionRequestModel @@ -123,14 +233,14 @@ def _create_model_version( model_name_or_id=self.name, model_version_name_or_number_or_id=self.version, ) - self._model_version = mv + model_version = mv except KeyError: - self._model_version = zenml_client.create_model_version( + model_version = zenml_client.create_model_version( model_version=mv_request ) logger.info(f"New model version `{self.version}` was created.") - return self._model_version + return model_version def _get_model_version(self) -> "ModelVersionResponseModel": """This method gets a model version from Model Control Plane. @@ -138,25 +248,22 @@ def _get_model_version(self) -> "ModelVersionResponseModel": Returns: The model version based on configuration. """ - if self._model_version is not None: - return self._model_version - from zenml.client import Client zenml_client = Client() if self.version is None: # raise if not found - self._model_version = zenml_client.get_model_version( + model_version = zenml_client.get_model_version( model_name_or_id=self.name ) else: # by version name or stage or number # raise if not found - self._model_version = zenml_client.get_model_version( + model_version = zenml_client.get_model_version( model_name_or_id=self.name, model_version_name_or_number_or_id=self.version, ) - return self._model_version + return model_version def get_or_create_model_version(self) -> "ModelVersionResponseModel": """This method should get or create a model and a model version from Model Control Plane. @@ -184,16 +291,18 @@ def get_or_create_model_version(self) -> "ModelVersionResponseModel": mv = self._get_model_version() return mv - def _merge_with_config(self, model_config: ModelConfigModel) -> None: + def _merge(self, model_config: "ModelConfig") -> None: self.license = self.license or model_config.license self.description = self.description or model_config.description self.audience = self.audience or model_config.audience self.use_cases = self.use_cases or model_config.use_cases self.limitations = self.limitations or model_config.limitations self.trade_offs = self.trade_offs or model_config.trade_offs - self.ethic = self.ethic or model_config.ethic + self.ethics = self.ethics or model_config.ethics if model_config.tags is not None: - self.tags = (self.tags or []) + model_config.tags + self.tags = list( + {t for t in self.tags or []}.union(set(model_config.tags)) + ) self.delete_new_version_on_failure &= ( model_config.delete_new_version_on_failure diff --git a/src/zenml/models/model_base_model.py b/src/zenml/models/model_base_model.py index e568dd31827..411a3b32aa8 100644 --- a/src/zenml/models/model_base_model.py +++ b/src/zenml/models/model_base_model.py @@ -13,14 +13,10 @@ # permissions and limitations under the License. """Model base model to support Model Control Plane feature.""" -from typing import Any, Dict, List, Optional, Union +from typing import List, Optional -from pydantic import BaseModel, Field, root_validator +from pydantic import BaseModel, Field -from zenml.constants import ( - RUNNING_MODEL_VERSION, -) -from zenml.enums import ModelStages from zenml.logger import get_logger from zenml.models.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH @@ -58,106 +54,10 @@ class ModelBaseModel(BaseModel): title="The trade offs of the model", max_length=TEXT_FIELD_MAX_LENGTH, ) - ethic: Optional[str] = Field( + ethics: Optional[str] = Field( title="The ethical implications of the model", max_length=TEXT_FIELD_MAX_LENGTH, ) tags: Optional[List[str]] = Field( title="Tags associated with the model", ) - - -class ModelConfigModel(ModelBaseModel): - """ModelConfig class to pass into pipeline or step to set it into a model context. - - name: The name of the model. - version: The model version name, number or stage is optional and points model context - to a specific version/stage, if skipped and `create_new_model_version` is False - - latest model version will be used. - version_description: The description of the model version. - create_new_model_version: Whether to create a new model version during execution - save_models_to_registry: Whether to save all ModelArtifacts to Model Registry, - if available in active stack. - delete_new_version_on_failure: Whether to delete failed runs with new versions for later recovery from it. - suppress_warnings: Whether to suppress warnings during validation. - """ - - version: Optional[Union[ModelStages, int, str]] - version_description: Optional[str] - create_new_model_version: bool = False - save_models_to_registry: bool = True - delete_new_version_on_failure: bool = True - suppress_warnings: bool = False - - class Config: - """Config class.""" - - smart_union = True - - @root_validator - def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: - """Validate all in one. - - Args: - values: Dict of values. - - Returns: - Dict of validated values. - - Raises: - ValueError: If validation failed on one of the checks. - """ - create_new_model_version = values.get( - "create_new_model_version", False - ) - delete_new_version_on_failure = values.get( - "delete_new_version_on_failure", True - ) - suppress_warnings = values.get("suppress_warnings", False) - if not delete_new_version_on_failure and not create_new_model_version: - if not suppress_warnings: - logger.warning( - "Using `delete_new_version_on_failure=False` and `create_new_model_version=False` has no effect." - "Setting `delete_new_version_on_failure` to `True`." - ) - values["delete_new_version_on_failure"] = True - - version = values.get("version", None) - - if create_new_model_version: - misuse_message = ( - "`version` set to {set} cannot be used with `create_new_model_version`." - "You can leave it default or set to a non-stage and non-numeric string.\n" - "Examples:\n" - " - `version` set to 1 or '1' is interpreted as a version number\n" - " - `version` set to 'production' is interpreted as a stage\n" - " - `version` set to 'my_first_version_in_2023' is a valid version to be created\n" - " - `version` set to 'My Second Version!' is a valid version to be created\n" - ) - if isinstance(version, ModelStages) or version in [ - stage.value for stage in ModelStages - ]: - raise ValueError( - misuse_message.format(set="a `ModelStages` instance") - ) - if str(version).isnumeric(): - raise ValueError(misuse_message.format(set="a numeric value")) - if version is None: - if not suppress_warnings: - logger.info( - "Creation of new model version was requested, but no version name was explicitly provided. " - f"Setting `version` to `{RUNNING_MODEL_VERSION}`." - ) - values["version"] = RUNNING_MODEL_VERSION - if ( - version in [stage.value for stage in ModelStages] - and not suppress_warnings - ): - logger.info( - f"`version` `{version}` matches one of the possible `ModelStages` and will be fetched using stage." - ) - if str(version).isnumeric() and not suppress_warnings: - logger.info( - f"`version` `{version}` is numeric and will be fetched using version number." - ) - return values diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 475391094a4..4f83574ec7a 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -618,5 +618,5 @@ class ModelUpdateModel(BaseModel): use_cases: Optional[str] limitations: Optional[str] trade_offs: Optional[str] - ethic: Optional[str] + ethics: Optional[str] tags: Optional[List[str]] diff --git a/src/zenml/new/pipelines/model_utils.py b/src/zenml/new/pipelines/model_utils.py index a9328f525f6..2ff509adaaf 100644 --- a/src/zenml/new/pipelines/model_utils.py +++ b/src/zenml/new/pipelines/model_utils.py @@ -18,7 +18,6 @@ from pydantic import BaseModel, PrivateAttr from zenml.model.model_config import ModelConfig -from zenml.models.model_base_model import ModelConfigModel class NewModelVersionRequest(BaseModel): @@ -57,7 +56,7 @@ def model_config(self) -> ModelConfig: def update_request( self, - model_config: ModelConfigModel, + model_config: ModelConfig, requester: "NewModelVersionRequest.Requester", ) -> None: """Update from Model Config Model object in place. @@ -71,7 +70,7 @@ def update_request( """ self.requesters.append(requester) if self._model_config is None: - self._model_config = ModelConfig.parse_obj(model_config) + self._model_config = model_config if self._model_config.version != model_config.version: raise ValueError( @@ -79,4 +78,4 @@ def update_request( "Since a new model version is requested for this model, all `version` names must match or left default." ) - self._model_config._merge_with_config(model_config) + self._model_config._merge(model_config) diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index 68a9c4d050b..85378b35730 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -53,7 +53,7 @@ from zenml.config.pipeline_spec import PipelineSpec from zenml.config.schedule import Schedule from zenml.config.step_configurations import StepConfigurationUpdate -from zenml.enums import ExecutionStatus, StackComponentType +from zenml.enums import StackComponentType from zenml.hooks.hook_validators import resolve_and_validate_hook from zenml.logger import get_logger from zenml.models import ( @@ -66,7 +66,6 @@ PipelineRunResponseModel, ScheduleRequestModel, ) -from zenml.models.model_base_model import ModelConfigModel from zenml.models.pipeline_build_models import ( PipelineBuildBaseModel, ) @@ -354,8 +353,6 @@ def configure( # string of on_success hook function to be used for this pipeline success_hook_source = resolve_and_validate_hook(on_success) - if model_config: - model_config.suppress_warnings = True values = dict_utils.remove_none_values( { "enable_cache": enable_cache, @@ -366,18 +363,12 @@ def configure( "extra": extra, "failure_hook_source": failure_hook_source, "success_hook_source": success_hook_source, - "model_config_model": ModelConfigModel.parse_obj( - model_config.dict() - ) - if model_config is not None - else None, + "model_config": model_config, } ) if not self.__suppress_warnings_flag__: to_be_reapplied = [] for param_, value_ in values.items(): - if param_ == "model_config_model": - param_ = "model_config" if ( param_ in PipelineRunConfiguration.__fields__ and param_ in self._from_config_file @@ -825,46 +816,52 @@ def get_new_version_requests( new_versions_requested: Dict[ str, NewModelVersionRequest ] = defaultdict(NewModelVersionRequest) + other_model_configs: List["ModelConfig"] = [] all_steps_have_own_config = True for step in deployment.step_configurations.values(): - step_model_config = step.config.model_config_model + step_model_config = step.config.model_config all_steps_have_own_config = ( all_steps_have_own_config - and step.config.model_config_model is not None + and step.config.model_config is not None ) - if ( - step_model_config - and step_model_config.create_new_model_version - ): - new_versions_requested[step_model_config.name].update_request( - step_model_config, - NewModelVersionRequest.Requester( - source="step", name=step.config.name - ), - ) + if step_model_config: + if step_model_config.create_new_model_version: + new_versions_requested[ + step_model_config.name + ].update_request( + step_model_config, + NewModelVersionRequest.Requester( + source="step", name=step.config.name + ), + ) + else: + other_model_configs.append(step_model_config) if not all_steps_have_own_config: pipeline_model_config = ( - deployment.pipeline_configuration.model_config_model + deployment.pipeline_configuration.model_config ) - if ( - pipeline_model_config - and pipeline_model_config.create_new_model_version - ): - new_versions_requested[ - pipeline_model_config.name - ].update_request( - pipeline_model_config, - NewModelVersionRequest.Requester( - source="pipeline", name=self.name - ), - ) - elif deployment.pipeline_configuration.model_config_model is not None: + if pipeline_model_config: + if pipeline_model_config.create_new_model_version: + new_versions_requested[ + pipeline_model_config.name + ].update_request( + pipeline_model_config, + NewModelVersionRequest.Requester( + source="pipeline", name=self.name + ), + ) + else: + other_model_configs.append(pipeline_model_config) + elif deployment.pipeline_configuration.model_config is not None: logger.warning( f"ModelConfig of pipeline `{self.name}` is overridden in all steps. " ) self._validate_new_version_requests(new_versions_requested) + for other_model_config in other_model_configs: + other_model_config._validate_config_in_runtime() + return new_versions_requested def _validate_new_version_requests( @@ -875,11 +872,6 @@ def _validate_new_version_requests( Args: new_versions_requested: A dict of new model version request objects. - - Raises: - RuntimeError: If there is unfinished pipeline run for requested model version. - RuntimeError: If recovery not requested, but model version already exists. - """ for model_name, data in new_versions_requested.items(): if len(data.requesters) > 1: @@ -888,25 +880,7 @@ def _validate_new_version_requests( f"{data.requesters}\n We recommend that `create_new_model_version` is configured " "only in one place of the pipeline." ) - try: - model_version = data.model_config._get_model_version() - - for run_name, run in model_version.pipeline_runs.items(): - if run.status == ExecutionStatus.RUNNING: - raise RuntimeError( - f"New model version was requested, but pipeline run `{run_name}` " - f"is still running with version `{model_version.name}`." - ) - if ( - data.model_config.version - and data.model_config.delete_new_version_on_failure - ): - raise RuntimeError( - f"Cannot create version `{data.model_config.version}` " - f"for model `{data.model_config.name}` since it already exists" - ) - except KeyError: - pass + data.model_config._validate_config_in_runtime() def update_new_versions_requests( self, @@ -928,7 +902,7 @@ def update_new_versions_requests( for step_name in deployment.step_configurations: step_model_config = deployment.step_configurations[ step_name - ].config.model_config_model + ].config.model_config if ( step_model_config is not None and step_model_config.name in new_version_requests @@ -937,9 +911,7 @@ def update_new_versions_requests( step_model_config.name ].model_config.version step_model_config.create_new_model_version = True - pipeline_model_config = ( - deployment.pipeline_configuration.model_config_model - ) + pipeline_model_config = deployment.pipeline_configuration.model_config if ( pipeline_model_config is not None and pipeline_model_config.name in new_version_requests @@ -1156,7 +1128,7 @@ def _compile( integration_registry.activate_integrations() - self._from_config_file = self._parse_config_file( + self._parse_config_file( config_path=config_path, matcher=list(PipelineRunConfiguration.__fields__.keys()), ) @@ -1393,15 +1365,12 @@ def __exit__(self, *args: Any) -> None: def _parse_config_file( self, config_path: Optional[str], matcher: List[str] - ) -> Dict[str, Any]: - """Parses the given configuration file. + ) -> None: + """Parses the given configuration file and sets `self._from_config_file`. Args: config_path: Path to a yaml configuration file. matcher: List of keys to match in the configuration file. - - Returns: - The parsed configuration file as a dictionary. """ _from_config_file: Dict[str, Any] = {} if config_path: @@ -1416,12 +1385,17 @@ def _parse_config_file( ) if "model_config" in _from_config_file: - from zenml.model.model_config import ModelConfig + if "model_config" in self._from_config_file: + _from_config_file["model_config"] = self._from_config_file[ + "model_config" + ] + else: + from zenml.model.model_config import ModelConfig - _from_config_file["model_config"] = ModelConfig.parse_obj( - _from_config_file["model_config"] - ) - return _from_config_file + _from_config_file["model_config"] = ModelConfig.parse_obj( + _from_config_file["model_config"] + ) + self._from_config_file = _from_config_file def with_options( self, @@ -1460,7 +1434,7 @@ def with_options( """ pipeline_copy = self.copy() - pipeline_copy._from_config_file = self._parse_config_file( + pipeline_copy._parse_config_file( config_path=config_path, matcher=inspect.getfullargspec(self.configure)[0], ) diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index b3072aef022..e5eae10f17e 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -381,11 +381,10 @@ def _prepare( output_name: artifact.id for output_name, artifact in cached_outputs.items() } - if model_config: - self._link_cached_artifacts_to_model_version( - model_config=model_config, - step_run=step_run, - ) + self._link_cached_artifacts_to_model_version( + model_config_from_context=model_config, + step_run=step_run, + ) step_run.status = ExecutionStatus.CACHED step_run.end_time = step_run.start_time @@ -393,20 +392,19 @@ def _prepare( def _link_cached_artifacts_to_model_version( self, - model_config: "ModelConfig", + model_config_from_context: Optional["ModelConfig"], step_run: StepRunRequestModel, ) -> None: """Links the output artifacts of the cached step to the model version in Control Plane. Args: - model_config: The model config of the current step. + model_config_from_context: The model config of the current step. step_run: The step to run. """ from zenml.model.artifact_config import ArtifactConfig from zenml.steps.base_step import BaseStep from zenml.steps.utils import parse_return_type_annotations - model_version = model_config.get_or_create_model_version() step_instance = BaseStep.load_from_source(self._step.spec.source) output_annotations = parse_return_type_annotations( step_instance.entrypoint @@ -414,23 +412,37 @@ def _link_cached_artifacts_to_model_version( for output_name_, output_ in step_run.outputs.items(): if output_name_ in output_annotations: annotation = output_annotations.get(output_name_, None) - artifact_config = ( - annotation.artifact_config - if annotation and annotation.artifact_config is not None - else ArtifactConfig() - ) - artifact_config_ = artifact_config.copy() - artifact_config_.model_name = ( - artifact_config.model_name or model_version.model.name - ) - artifact_config_.model_version = ( - artifact_config_.model_version or model_version.name - ) - artifact_config_._pipeline_name = ( - self._deployment.pipeline_configuration.name - ) - artifact_config_._step_name = self._step_name - artifact_config_.link_to_model(output_) + if annotation and annotation.artifact_config is not None: + artifact_config_ = annotation.artifact_config.copy() + else: + artifact_config_ = ArtifactConfig( + artifact_name=output_name_ + ) + logger.info( + f"Linking artifact `{artifact_config_.artifact_name}` to " + f"model `{artifact_config_.model_name}` version " + f"`{artifact_config_.model_version}` implicitly." + ) + if artifact_config_.model_name is None: + model_config = model_config_from_context + else: + from zenml.model.model_config import ModelConfig + + model_config = ModelConfig( + name=artifact_config_.model_name, + version=artifact_config_.model_version, + ) + if model_config: + model_config.get_or_create_model_version() + + artifact_config_._pipeline_name = ( + self._deployment.pipeline_configuration.name + ) + artifact_config_._step_name = self._step_name + artifact_config_.link_to_model( + artifact_uuid=output_, + model_config=model_config, + ) def _run_step( self, diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index 8752c0f82bb..d9385ab3a6b 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -623,38 +623,57 @@ def _link_artifacts_to_model( """ from zenml.model.artifact_config import ArtifactConfig + context = get_step_context() try: - mc = get_step_context().model_config + model_config_from_context = context.model_config except StepContextError: - mc = None - logger.warning( - "No model context found, unable to auto-link artifacts." - ) + model_config_from_context = None for artifact_name in artifact_ids: artifact_uuid = artifact_ids[artifact_name] - artifact_config = ( - get_step_context()._get_output(artifact_name).artifact_config - ) - if artifact_config is None and mc is not None: - artifact_config = ArtifactConfig( - model_name=mc.name, - model_version=mc.version, - artifact_name=artifact_name, - ) - logger.info( - f"Linking artifact `{artifact_name}` to model `{mc.name}` version `{mc.version}` implicitly." - ) + artifact_config_ = context._get_output( + artifact_name + ).artifact_config + if artifact_config_ is None: + if model_config_from_context is not None: + artifact_config_ = ArtifactConfig( + artifact_name=artifact_name, + ) + logger.info( + f"Linking artifact `{artifact_name}` to model `{model_config_from_context.name}` version `{model_config_from_context.version}` implicitly." + ) + else: + artifact_config_ = artifact_config_.copy() + + if artifact_config_ is not None: + model_config = None + if model_config_from_context is None: + if artifact_config_.model_name is None: + logger.warning( + "No model context found, unable to auto-link artifacts." + ) + return - if artifact_config is not None: - artifact_config.artifact_name = ( - artifact_config.artifact_name or artifact_name - ) - artifact_config._pipeline_name = ( - get_step_context().pipeline.name - ) - artifact_config._step_name = get_step_context().step_run.name - artifact_config.link_to_model(artifact_uuid=artifact_uuid) + if artifact_config_.model_name is not None: + from zenml.model.model_config import ModelConfig + + model_config = ModelConfig( + name=artifact_config_.model_name, + version=artifact_config_.model_version, + ) + else: + model_config = model_config_from_context + + if model_config: + artifact_config_.artifact_name = ( + artifact_config_.artifact_name or artifact_name + ) + artifact_config_._pipeline_name = context.pipeline.name + artifact_config_._step_name = context.step_run.name + artifact_config_.link_to_model( + artifact_uuid=artifact_uuid, + model_config=model_config, + ) def _get_model_versions_from_artifacts( self, diff --git a/src/zenml/steps/base_step.py b/src/zenml/steps/base_step.py index fa453251014..e9c75f167c4 100644 --- a/src/zenml/steps/base_step.py +++ b/src/zenml/steps/base_step.py @@ -699,7 +699,6 @@ def configure( """ from zenml.config.step_configurations import StepConfigurationUpdate from zenml.hooks.hook_validators import resolve_and_validate_hook - from zenml.models.model_base_model import ModelConfigModel if name: logger.warning("Configuring the name of a step is deprecated.") @@ -748,8 +747,6 @@ def _convert_to_tuple(value: Any) -> Tuple[Source, ...]: if isinstance(parameters, BaseParameters): parameters = parameters.dict() - if model_config: - model_config.suppress_warnings = True values = dict_utils.remove_none_values( { "enable_cache": enable_cache, @@ -764,11 +761,7 @@ def _convert_to_tuple(value: Any) -> Tuple[Source, ...]: "extra": extra, "failure_hook_source": failure_hook_source, "success_hook_source": success_hook_source, - "model_config_model": ModelConfigModel.parse_obj( - model_config.dict() - ) - if model_config is not None - else None, + "model_config": model_config, } ) config = StepConfigurationUpdate(**values) diff --git a/src/zenml/zen_stores/migrations/versions/4f66af55fbb9_rename_model_config_model_to_model_.py b/src/zenml/zen_stores/migrations/versions/4f66af55fbb9_rename_model_config_model_to_model_.py new file mode 100644 index 00000000000..13e3b296a2c --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/4f66af55fbb9_rename_model_config_model_to_model_.py @@ -0,0 +1,103 @@ +"""rename model_config_model to model_config in pipeline and step configs [4f66af55fbb9]. + +Revision ID: 4f66af55fbb9 +Revises: 0.45.2 +Create Date: 2023-10-17 13:57:35.810054 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.sql import text + +# revision identifiers, used by Alembic. +revision = "4f66af55fbb9" +down_revision = "729263e47b55" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + table_name="model", + column_name="ethic", + new_column_name="ethics", + existing_type=sa.TEXT(), + ) + + connection = op.get_bind() + + update_config_fields = text( + """ + UPDATE pipeline_deployment + SET pipeline_configuration = REPLACE( + pipeline_configuration, + '"model_config_model"', + '"model_config"' + ), + step_configurations = REPLACE( + step_configurations, + '"model_config_model"', + '"model_config"' + ) + """ + ) + connection.execute(update_config_fields) + + update_config_fields = text( + """ + UPDATE step_run + SET step_configuration = REPLACE( + step_configuration, + '"model_config_model"', + '"model_config"' + ) + """ + ) + connection.execute(update_config_fields) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + table_name="model", + column_name="ethics", + new_column_name="ethic", + existing_type=sa.TEXT(), + ) + + connection = op.get_bind() + + update_config_fields = text( + """ + UPDATE pipeline_deployment + SET pipeline_configuration = REPLACE( + pipeline_configuration, + '"model_config"', + '"model_config_model"' + ), + step_configurations = REPLACE( + step_configurations, + '"model_config"', + '"model_config_model"' + ) + """ + ) + connection.execute(update_config_fields) + + update_config_fields = text( + """ + UPDATE step_run + SET step_configuration = REPLACE( + step_configuration, + '"model_config"', + '"model_config_model"' + ) + """ + ) + connection.execute(update_config_fields) + + # ### end Alembic commands ### diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 774bb1bc684..3078069c768 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -72,7 +72,7 @@ class ModelSchema(NamedSchema, table=True): use_cases: str = Field(sa_column=Column(TEXT, nullable=True)) limitations: str = Field(sa_column=Column(TEXT, nullable=True)) trade_offs: str = Field(sa_column=Column(TEXT, nullable=True)) - ethic: str = Field(sa_column=Column(TEXT, nullable=True)) + ethics: str = Field(sa_column=Column(TEXT, nullable=True)) tags: str = Field(sa_column=Column(TEXT, nullable=True)) model_versions: List["ModelVersionSchema"] = Relationship( back_populates="model", @@ -107,7 +107,7 @@ def from_request(cls, model_request: ModelRequestModel) -> "ModelSchema": use_cases=model_request.use_cases, limitations=model_request.limitations, trade_offs=model_request.trade_offs, - ethic=model_request.ethic, + ethics=model_request.ethics, tags=json.dumps(model_request.tags) if model_request.tags else None, @@ -132,7 +132,7 @@ def to_model(self) -> ModelResponseModel: use_cases=self.use_cases, limitations=self.limitations, trade_offs=self.trade_offs, - ethic=self.ethic, + ethics=self.ethics, tags=json.loads(self.tags) if self.tags else None, ) diff --git a/tests/integration/functional/model/test_artifact_config.py b/tests/integration/functional/model/test_artifact_config.py index fadb67b8dd5..2fcc349571d 100644 --- a/tests/integration/functional/model/test_artifact_config.py +++ b/tests/integration/functional/model/test_artifact_config.py @@ -228,13 +228,13 @@ def multi_named_output_step_from_self() -> ( return 1, 2, 3 -@pipeline(enable_cache=False) -def multi_named_pipeline_from_self(): +@pipeline +def multi_named_pipeline_from_self(enable_cache: bool): """Multi output linking from Annotated.""" - multi_named_output_step_from_self() + multi_named_output_step_from_self.with_options(enable_cache=enable_cache)() -def test_link_multiple_named_outputs_with_self_context(): +def test_link_multiple_named_outputs_with_self_context_and_caching(): """Test multi output linking with context defined in Annotated.""" with model_killer(): client = Client() @@ -266,32 +266,42 @@ def test_link_multiple_named_outputs_with_self_context(): ) ) - multi_named_pipeline_from_self() + for run_count in range(1, 3): + multi_named_pipeline_from_self(run_count == 2) - al1 = client.list_model_version_artifact_links( - ModelVersionArtifactFilterModel( - user_id=user, - workspace_id=ws, - model_id=mv1.model.id, - model_version_id=mv1.id, + al1 = client.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( + user_id=user, + workspace_id=ws, + model_id=mv1.model.id, + model_version_id=mv1.id, + ) ) - ) - al2 = client.list_model_version_artifact_links( - ModelVersionArtifactFilterModel( - user_id=user, - workspace_id=ws, - model_id=mv2.model.id, - model_version_id=mv2.id, + al2 = client.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( + user_id=user, + workspace_id=ws, + model_id=mv2.model.id, + model_version_id=mv2.id, + ) ) - ) - assert al1.size == 2 - assert al2.size == 1 - - assert {al.name for al in al1} == { - "1", - "2", - } - assert al2[0].name == "3" + assert al1.size == 2 + assert al2.size == 1 + + assert {al.name for al in al1} == { + "1", + "2", + } + assert al2[0].name == "3" + + # clean-up links to test caching linkage + for mv, al in zip([mv1, mv2], [al1, al2]): + for al_ in al: + client.zen_store.delete_model_version_artifact_link( + model_name_or_id=mv.model.id, + model_version_name_or_id=mv.id, + model_version_artifact_link_name_or_id=al_.id, + ) @step(model_config=ModelConfig(name="step", version="step")) diff --git a/tests/integration/functional/model/test_model_config.py b/tests/integration/functional/model/test_model_config.py index 20c613701fe..3a38052af91 100644 --- a/tests/integration/functional/model/test_model_config.py +++ b/tests/integration/functional/model/test_model_config.py @@ -171,26 +171,9 @@ def test_init_create_new_version_with_version_fails(self): assert mc.create_new_model_version assert mc.version == RUNNING_MODEL_VERSION - def test_init_recovery_without_create_new_version_warns(self): - """Test that use of `recovery` warn on `create_new_model_version` set to False.""" - with mock.patch( - "zenml.models.model_base_model.logger.warning" - ) as logger: - ModelConfig(name=MODEL_NAME, delete_new_version_on_failure=False) - logger.assert_called_once() - with mock.patch( - "zenml.models.model_base_model.logger.warning" - ) as logger: - ModelConfig( - name=MODEL_NAME, - delete_new_version_on_failure=False, - create_new_model_version=True, - ) - logger.assert_not_called() - def test_init_stage_logic(self): """Test that if version is set to string contained in ModelStages user is informed about it.""" - with mock.patch("zenml.models.model_base_model.logger.info") as logger: + with mock.patch("zenml.model.model_config.logger.info") as logger: mc = ModelConfig( name=MODEL_NAME, version=ModelStages.PRODUCTION.value, diff --git a/tests/integration/functional/pipelines/test_pipeline_config.py b/tests/integration/functional/pipelines/test_pipeline_config.py index 6173da1a060..5eccf78bfc3 100644 --- a/tests/integration/functional/pipelines/test_pipeline_config.py +++ b/tests/integration/functional/pipelines/test_pipeline_config.py @@ -37,7 +37,7 @@ def assert_model_config_step(): assert model_config.use_cases == "use_cases" assert model_config.limitations == "limitations" assert model_config.trade_offs == "trade_offs" - assert model_config.ethic == "ethic" + assert model_config.ethics == "ethics" assert model_config.tags == ["tag"] assert model_config.version_description == "version_description" assert model_config.save_models_to_registry @@ -62,7 +62,7 @@ def test_pipeline_with_model_config_from_yaml(clean_workspace, tmp_path): use_cases="use_cases", limitations="limitations", trade_offs="trade_offs", - ethic="ethic", + ethics="ethics", tags=["tag"], version_description="version_description", save_models_to_registry=True, @@ -146,7 +146,7 @@ def assert_model_config_pipeline(): use_cases="use_cases", limitations="limitations", trade_offs="trade_offs", - ethic="ethic", + ethics="ethics", tags=["tag"], version_description="version_description", save_models_to_registry=True, @@ -165,7 +165,7 @@ def assert_model_config_pipeline(): assert p.configuration.model_config.use_cases == "use_cases" assert p.configuration.model_config.limitations == "limitations" assert p.configuration.model_config.trade_offs == "trade_offs" - assert p.configuration.model_config.ethic == "ethic" + assert p.configuration.model_config.ethics == "ethics" assert p.configuration.model_config.tags == ["tag"] assert ( p.configuration.model_config.version_description diff --git a/tests/integration/functional/steps/test_external_artifact.py b/tests/integration/functional/steps/test_external_artifact.py index 98164c80f29..2f6b92f77d0 100644 --- a/tests/integration/functional/steps/test_external_artifact.py +++ b/tests/integration/functional/steps/test_external_artifact.py @@ -62,6 +62,25 @@ def consumer_pipeline( ) +@pipeline(name="bar", enable_cache=False) +def consumer_pipeline_with_external_artifact_from_another_model( + model_artifact_version: int, + model_artifact_pipeline_name: str = None, + model_artifact_step_name: str = None, +): + consumer( + ExternalArtifact( + model_name="foo", + model_version=1, + model_artifact_name="predictions", + model_artifact_version=model_artifact_version, + model_artifact_pipeline_name=model_artifact_pipeline_name, + model_artifact_step_name=model_artifact_step_name, + ), + model_artifact_version, + ) + + @pipeline( name="bar", enable_cache=False, @@ -72,7 +91,15 @@ def two_step_producer_pipeline(): producer(1) -def test_exchange_of_model_artifacts_between_pipelines(): +@pytest.mark.parametrize( + "consumer_pipeline", + [ + consumer_pipeline, + consumer_pipeline_with_external_artifact_from_another_model, + ], + ids=["model context given", "no model context"], +) +def test_exchange_of_model_artifacts_between_pipelines(consumer_pipeline): """Test that ExternalArtifact helps to exchange data from Model between pipelines.""" with model_killer(): producer_pipeline.with_options( diff --git a/tests/integration/functional/zen_stores/utils.py b/tests/integration/functional/zen_stores/utils.py index 8f8bf5a57ab..2d21acf2d41 100644 --- a/tests/integration/functional/zen_stores/utils.py +++ b/tests/integration/functional/zen_stores/utils.py @@ -945,7 +945,7 @@ def update_method( use_cases="all", limitations="none", trade_offs="secret", - ethic="all good", + ethics="all good", tags=["cool", "stuff"], ), update_model=ModelUpdateModel( diff --git a/tests/unit/model/test_model_config_init.py b/tests/unit/model/test_model_config_init.py new file mode 100644 index 00000000000..adb68efa578 --- /dev/null +++ b/tests/unit/model/test_model_config_init.py @@ -0,0 +1,64 @@ +from unittest.mock import patch + +import pytest + +from zenml.enums import ModelStages +from zenml.model import ModelConfig + + +@pytest.mark.parametrize( + "version_name,create_new_model_version,delete_new_version_on_failure,logger", + [ + [None, True, False, "info"], + ["staging", False, False, "info"], + ["1", False, False, "info"], + [1, False, False, "info"], + ], + ids=[ + "Default running version", + "Pick model by text stage", + "Pick model by text version number", + "Pick model by integer version number", + ], +) +def test_init_warns( + version_name, + create_new_model_version, + delete_new_version_on_failure, + logger, +): + with patch(f"zenml.model.model_config.logger.{logger}") as logger: + ModelConfig( + name="foo", + version=version_name, + create_new_model_version=create_new_model_version, + delete_new_version_on_failure=delete_new_version_on_failure, + ) + logger.assert_called_once() + + +@pytest.mark.parametrize( + "version_name,create_new_model_version", + [ + [1, True], + ["1", True], + [ModelStages.PRODUCTION, True], + ["production", True], + ], + ids=[ + "Version number as integer and new version request", + "Version number as string and new version request", + "Version stage as instance and new version request", + "Version stage as string and new version request", + ], +) +def test_init_raises( + version_name, + create_new_model_version, +): + with pytest.raises(ValueError): + ModelConfig( + name="foo", + version=version_name, + create_new_model_version=create_new_model_version, + ) diff --git a/tests/unit/models/test_model_models.py b/tests/unit/models/test_model_models.py index 4e2f5a89c36..9bb562fe360 100644 --- a/tests/unit/models/test_model_models.py +++ b/tests/unit/models/test_model_models.py @@ -19,8 +19,6 @@ import pytest from tests.unit.steps.test_external_artifact import MockZenmlClient -from zenml.enums import ModelStages -from zenml.models.model_base_model import ModelConfigModel from zenml.models.model_models import ( ModelResponseModel, ModelVersionResponseModel, @@ -181,63 +179,3 @@ def test_getters( step_name=query_step, version=query_version, ) - - -@pytest.mark.parametrize( - "version_name,create_new_model_version,delete_new_version_on_failure,logger", - [ - [None, False, False, "warning"], - [None, True, False, "info"], - ["staging", False, False, "info"], - ["1", False, False, "info"], - [1, False, False, "info"], - ], - ids=[ - "No new version, but recovery", - "Default running version", - "Pick model by text stage", - "Pick model by text version number", - "Pick model by integer version number", - ], -) -def test_init_warns( - version_name, - create_new_model_version, - delete_new_version_on_failure, - logger, -): - with patch(f"zenml.models.model_base_model.logger.{logger}") as logger: - ModelConfigModel( - name="foo", - version=version_name, - create_new_model_version=create_new_model_version, - delete_new_version_on_failure=delete_new_version_on_failure, - ) - logger.assert_called_once() - - -@pytest.mark.parametrize( - "version_name,create_new_model_version", - [ - [1, True], - ["1", True], - [ModelStages.PRODUCTION, True], - ["production", True], - ], - ids=[ - "Version number as integer and new version request", - "Version number as string and new version request", - "Version stage as instance and new version request", - "Version stage as string and new version request", - ], -) -def test_init_raises( - version_name, - create_new_model_version, -): - with pytest.raises(ValueError): - ModelConfigModel( - name="foo", - version=version_name, - create_new_model_version=create_new_model_version, - )