From 7e4171aadc14cc292dae65c53139608a8472b5f2 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 12:16:52 +0200 Subject: [PATCH 01/41] flatten `ModelConfig` code --- .../artifacts/external_artifact_config.py | 19 +- src/zenml/config/pipeline_configurations.py | 21 +- src/zenml/config/step_configurations.py | 21 +- src/zenml/model/artifact_config.py | 9 +- src/zenml/model/model_config.py | 249 +++++++++++++++--- src/zenml/models/model_base_model.py | 104 +------- src/zenml/new/pipelines/model_utils.py | 7 +- src/zenml/new/pipelines/pipeline.py | 52 +--- src/zenml/steps/base_step.py | 9 +- src/zenml/utils/pydantic_utils.py | 4 + tests/unit/model/test_model_config.py | 66 +++++ tests/unit/models/test_model_models.py | 62 ----- 12 files changed, 314 insertions(+), 309 deletions(-) create mode 100644 tests/unit/model/test_model_config.py diff --git a/src/zenml/artifacts/external_artifact_config.py b/src/zenml/artifacts/external_artifact_config.py index 49290fdff12..90d279a8d73 100644 --- a/src/zenml/artifacts/external_artifact_config.py +++ b/src/zenml/artifacts/external_artifact_config.py @@ -93,8 +93,6 @@ def _get_artifact_from_model( RuntimeError: If `model_artifact_name` is set, but `model_name` is empty and model configuration is missing in @step and @pipeline. """ - from zenml.model.model_config import ModelConfig - if self.model_name is None: if model_config is None: raise RuntimeError( @@ -104,13 +102,16 @@ def _get_artifact_from_model( ) self.model_name = model_config.name self.model_version = model_config.version - - _model_config = ModelConfig( - name=self.model_name, - version=self.model_version, - suppress_warnings=True, - ) - model_version = _model_config._get_model_version() + if ( + model_config is None + or self.model_name != model_config.name + or self.model_version != model_config.version + ): + model_config = ModelConfig( + name=self.model_name, + version=self.model_version, + ) + model_version = model_config._get_model_version() for artifact_getter in [ model_version.get_artifact_object, diff --git a/src/zenml/config/pipeline_configurations.py b/src/zenml/config/pipeline_configurations.py index 8e2f4f55640..5797ee563ab 100644 --- a/src/zenml/config/pipeline_configurations.py +++ b/src/zenml/config/pipeline_configurations.py @@ -19,11 +19,10 @@ from zenml.config.constants import DOCKER_SETTINGS_KEY from zenml.config.source import Source, convert_source_validator from zenml.config.strict_base_model import StrictBaseModel -from zenml.models.model_base_model import ModelConfigModel +from zenml.model.model_config import ModelConfig if TYPE_CHECKING: from zenml.config import DockerSettings - from zenml.model.model_config import ModelConfig from zenml.config.base_settings import BaseSettings, SettingsOrDict @@ -41,28 +40,12 @@ class PipelineConfigurationUpdate(StrictBaseModel): extra: Dict[str, Any] = {} failure_hook_source: Optional[Source] = None success_hook_source: Optional[Source] = None - model_config_model: Optional[ModelConfigModel] = None + model_config: Optional[ModelConfig] = None _convert_source = convert_source_validator( "failure_hook_source", "success_hook_source" ) - @property - def model_config(self) -> Optional["ModelConfig"]: - """Gets a ModelConfig object out of the model config model. - - This is a technical circular import resolver. - - Returns: - The model config object, if configured. - """ - if self.model_config_model is None: - return None - - from zenml.model.model_config import ModelConfig - - return ModelConfig.parse_obj(self.model_config_model.dict()) - class PipelineConfiguration(PipelineConfigurationUpdate): """Pipeline configuration class.""" diff --git a/src/zenml/config/step_configurations.py b/src/zenml/config/step_configurations.py index 3d3ecdd3ea1..cbb95d5c5ce 100644 --- a/src/zenml/config/step_configurations.py +++ b/src/zenml/config/step_configurations.py @@ -34,12 +34,11 @@ from zenml.config.source import Source, convert_source_validator from zenml.config.strict_base_model import StrictBaseModel from zenml.logger import get_logger -from zenml.models.model_base_model import ModelConfigModel +from zenml.model.model_config import ModelConfig from zenml.utils import deprecation_utils if TYPE_CHECKING: from zenml.config import DockerSettings, ResourceSettings - from zenml.model.model_config import ModelConfig logger = get_logger(__name__) @@ -135,7 +134,7 @@ class StepConfigurationUpdate(StrictBaseModel): extra: Dict[str, Any] = {} failure_hook_source: Optional[Source] = None success_hook_source: Optional[Source] = None - model_config_model: Optional[ModelConfigModel] = None + model_config: Optional[ModelConfig] = None outputs: Mapping[str, PartialArtifactConfiguration] = {} @@ -146,22 +145,6 @@ class StepConfigurationUpdate(StrictBaseModel): "name" ) - @property - def model_config(self) -> Optional["ModelConfig"]: - """Gets a ModelConfig object out of the model config model. - - This is a technical circular import resolver. - - Returns: - The model config object, if configured. - """ - if self.model_config_model is None: - return None - - from zenml.model.model_config import ModelConfig - - return ModelConfig.parse_obj(self.model_config_model.dict()) - class PartialStepConfiguration(StepConfigurationUpdate): """Class representing a partial step configuration.""" diff --git a/src/zenml/model/artifact_config.py b/src/zenml/model/artifact_config.py index dd9f2ae1656..1b723178aed 100644 --- a/src/zenml/model/artifact_config.py +++ b/src/zenml/model/artifact_config.py @@ -21,10 +21,6 @@ from zenml.enums import ModelStages from zenml.exceptions import StepContextError from zenml.logger import get_logger -from zenml.models.model_models import ( - ModelVersionArtifactFilterModel, - ModelVersionArtifactRequestModel, -) if TYPE_CHECKING: from zenml.model.model_config import ModelConfig @@ -87,7 +83,6 @@ def _model_config(self) -> "ModelConfig": name=self.model_name, version=self.model_version, create_new_model_version=False, - suppress_warnings=True, ) return on_the_fly_config @@ -134,6 +129,10 @@ def _link_to_model_version( is_deployment: Whether the artifact is a deployment. Defaults to False. """ from zenml.client import Client + from zenml.models.model_models import ( + ModelVersionArtifactFilterModel, + ModelVersionArtifactRequestModel, + ) # Create a ZenML client client = Client() diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index 4ce9264a07b..b6e310d5467 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -13,13 +13,21 @@ # permissions and limitations under the License. """ModelConfig user facing interface to pass into pipeline or step.""" -from typing import TYPE_CHECKING, Optional +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Union, +) -from pydantic import PrivateAttr +from pydantic import BaseModel, root_validator +from zenml.constants import RUNNING_MODEL_VERSION +from zenml.enums import ExecutionStatus, ModelStages from zenml.exceptions import EntityExistsError from zenml.logger import get_logger -from zenml.models.model_base_model import ModelConfigModel if TYPE_CHECKING: from zenml.models.model_models import ( @@ -30,24 +38,181 @@ logger = get_logger(__name__) -class ModelConfig(ModelConfigModel): - """ModelConfig class to pass into pipeline or step to set it into a model context. +class ModelConfig(BaseModel): + """ModelConfig class to pass into pipeline or step to set it into a model context.""" - name: The name of the model. - version: The model version name, number or stage is optional and points model context - to a specific version/stage, if skipped and `create_new_model_version` is False - - latest model version will be used. - version_description: The description of the model version. - create_new_model_version: Whether to create a new model version during execution - save_models_to_registry: Whether to save all ModelArtifacts to Model Registry, - if available in active stack. - delete_new_version_on_failure: Whether to delete failed runs with new versions for later recovery from it. - """ + name: str + license: Optional[str] + description: Optional[str] + audience: Optional[str] + use_cases: Optional[str] + limitations: Optional[str] + trade_offs: Optional[str] + ethic: Optional[str] + tags: Optional[List[str]] + version: Optional[Union[ModelStages, int, str]] + version_description: Optional[str] + create_new_model_version: bool = False + save_models_to_registry: bool = True + delete_new_version_on_failure: bool = True - _model: Optional["ModelResponseModel"] = PrivateAttr(default=None) - _model_version: Optional["ModelVersionResponseModel"] = PrivateAttr( - default=None - ) + model: Optional[Any] = None + model_version: Optional[Any] = None + + def __init__( + self, + name: str, + license: Optional[str] = None, + description: Optional[str] = None, + audience: Optional[str] = None, + use_cases: Optional[str] = None, + limitations: Optional[str] = None, + trade_offs: Optional[str] = None, + ethic: Optional[str] = None, + tags: Optional[List[str]] = None, + version: Optional[Union[ModelStages, int, str]] = None, + version_description: Optional[str] = None, + create_new_model_version: bool = False, + save_models_to_registry: bool = True, + delete_new_version_on_failure: bool = True, + **kwargs: Dict[str, Any], + ): + """ModelConfig class to pass into pipeline or step to set it into a model context. + + Args: + name: The name of the model. + license: The license model created under. + description: The description of the model. + audience: The target audience of the model. + use_cases: The use cases of the model. + limitations: The know limitations of the model. + trade_offs: The trade offs of the model. + ethic: The ethical implications of the model. + tags: Tags associated with the model. + version: The model version name, number or stage is optional and points model context + to a specific version/stage, if skipped and `create_new_model_version` is False - + latest model version will be used. + version_description: The description of the model version. + create_new_model_version: Whether to create a new model version during execution + save_models_to_registry: Whether to save all ModelArtifacts to Model Registry, + if available in active stack. + delete_new_version_on_failure: Whether to delete failed runs with new versions for later recovery from it. + kwargs: Other arguments. + """ + logger.error(f"INSTANTIATED MODEL CONFIG {name}/{version}!") + super().__init__( + name=name, + license=license, + description=description, + audience=audience, + use_cases=use_cases, + limitations=limitations, + trade_offs=trade_offs, + ethic=ethic, + tags=tags, + version=version, + version_description=version_description, + create_new_model_version=create_new_model_version, + save_models_to_registry=save_models_to_registry, + delete_new_version_on_failure=delete_new_version_on_failure, + ) + self.model = kwargs.get("model", None) + self.model_version = kwargs.get("model_version", None) + + class Config: + """Config class.""" + + smart_union = True + + @root_validator + def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Validate all in one. + + Args: + values: Dict of values. + + Returns: + Dict of validated values. + + Raises: + ValueError: If validation failed on one of the checks. + """ + create_new_model_version = values.get( + "create_new_model_version", False + ) + delete_new_version_on_failure = values.get( + "delete_new_version_on_failure", True + ) + if not delete_new_version_on_failure and not create_new_model_version: + logger.warning( + "Using `delete_new_version_on_failure=False` and `create_new_model_version=False` has no effect." + "Setting `delete_new_version_on_failure` to `True`." + ) + values["delete_new_version_on_failure"] = True + + version = values.get("version", None) + + if create_new_model_version: + misuse_message = ( + "`version` set to {set} cannot be used with `create_new_model_version`." + "You can leave it default or set to a non-stage and non-numeric string.\n" + "Examples:\n" + " - `version` set to 1 or '1' is interpreted as a version number\n" + " - `version` set to 'production' is interpreted as a stage\n" + " - `version` set to 'my_first_version_in_2023' is a valid version to be created\n" + " - `version` set to 'My Second Version!' is a valid version to be created\n" + ) + if isinstance(version, ModelStages) or version in [ + stage.value for stage in ModelStages + ]: + raise ValueError( + misuse_message.format(set="a `ModelStages` instance") + ) + if str(version).isnumeric(): + raise ValueError(misuse_message.format(set="a numeric value")) + if version is None: + logger.info( + "Creation of new model version was requested, but no version name was explicitly provided. " + f"Setting `version` to `{RUNNING_MODEL_VERSION}`." + ) + values["version"] = RUNNING_MODEL_VERSION + if version in [stage.value for stage in ModelStages]: + logger.info( + f"`version` `{version}` matches one of the possible `ModelStages` and will be fetched using stage." + ) + if str(version).isnumeric(): + logger.info( + f"`version` `{version}` is numeric and will be fetched using version number." + ) + return values + + def _validate_config_in_runtime(self) -> None: + """Validate that config doesn't conflict with runtime environment. + + Raises: + RuntimeError: If recovery not requested, but model version already exists. + RuntimeError: If there is unfinished pipeline run for requested new model version. + """ + try: + model_version = self._get_model_version() + for run_name, run in model_version.pipeline_runs.items(): + if run.status == ExecutionStatus.RUNNING: + raise RuntimeError( + f"New model version was requested, but pipeline run `{run_name}` " + f"is still running with version `{model_version.name}`." + ) + + if ( + self.create_new_model_version + and not self.delete_new_version_on_failure + ): + raise RuntimeError( + f"Cannot create version `{self.version}` " + f"for model `{self.name}` since it already exists" + ) + except KeyError: + pass + self.get_or_create_model_version() def get_or_create_model(self) -> "ModelResponseModel": """This method should get or create a model from Model Control Plane. @@ -57,15 +222,19 @@ def get_or_create_model(self) -> "ModelResponseModel": Returns: The model based on configuration. """ - if self._model is not None: - return self._model + from zenml.models.model_models import ModelResponseModel + + if self.model is not None: + return ModelResponseModel(**dict(self.model)) from zenml.client import Client from zenml.models.model_models import ModelRequestModel + logger.error(f"RETRIEVED MODEL {self.name}!") + zenml_client = Client() try: - self._model = zenml_client.get_model(model_name_or_id=self.name) + self.model = zenml_client.get_model(model_name_or_id=self.name) except KeyError: model_request = ModelRequestModel( name=self.name, @@ -82,15 +251,13 @@ def get_or_create_model(self) -> "ModelResponseModel": ) model_request = ModelRequestModel.parse_obj(model_request) try: - self._model = zenml_client.create_model(model=model_request) + self.model = zenml_client.create_model(model=model_request) logger.info(f"New model `{self.name}` was created implicitly.") except EntityExistsError: # this is backup logic, if model was created somehow in between get and create calls - self._model = zenml_client.get_model( - model_name_or_id=self.name - ) + self.model = zenml_client.get_model(model_name_or_id=self.name) - return self._model + return self.model def _create_model_version( self, model: "ModelResponseModel" @@ -103,12 +270,15 @@ def _create_model_version( Returns: The model version based on configuration. """ - if self._model_version is not None: - return self._model_version + if self.model_version is not None: + from zenml.models.model_models import ModelVersionResponseModel + + return ModelVersionResponseModel(**dict(self.model_version)) from zenml.client import Client from zenml.models.model_models import ModelVersionRequestModel + logger.error(f"CREATED VERSION {self.version}!") zenml_client = Client() model_version_request = ModelVersionRequestModel( user=zenml_client.active_user.id, @@ -123,14 +293,14 @@ def _create_model_version( model_name_or_id=self.name, model_version_name_or_number_or_id=self.version, ) - self._model_version = mv + self.model_version = mv except KeyError: - self._model_version = zenml_client.create_model_version( + self.model_version = zenml_client.create_model_version( model_version=mv_request ) logger.info(f"New model version `{self.version}` was created.") - return self._model_version + return self.model_version def _get_model_version(self) -> "ModelVersionResponseModel": """This method gets a model version from Model Control Plane. @@ -138,25 +308,28 @@ def _get_model_version(self) -> "ModelVersionResponseModel": Returns: The model version based on configuration. """ - if self._model_version is not None: - return self._model_version + if self.model_version is not None: + from zenml.models.model_models import ModelVersionResponseModel + + return ModelVersionResponseModel(**dict(self.model_version)) from zenml.client import Client + logger.error(f"RETRIEVED VERSION {self.version}!") zenml_client = Client() if self.version is None: # raise if not found - self._model_version = zenml_client.get_model_version( + self.model_version = zenml_client.get_model_version( model_name_or_id=self.name ) else: # by version name or stage or number # raise if not found - self._model_version = zenml_client.get_model_version( + self.model_version = zenml_client.get_model_version( model_name_or_id=self.name, model_version_name_or_number_or_id=self.version, ) - return self._model_version + return self.model_version def get_or_create_model_version(self) -> "ModelVersionResponseModel": """This method should get or create a model and a model version from Model Control Plane. @@ -184,7 +357,7 @@ def get_or_create_model_version(self) -> "ModelVersionResponseModel": mv = self._get_model_version() return mv - def _merge_with_config(self, model_config: ModelConfigModel) -> None: + def _merge(self, model_config: "ModelConfig") -> None: self.license = self.license or model_config.license self.description = self.description or model_config.description self.audience = self.audience or model_config.audience diff --git a/src/zenml/models/model_base_model.py b/src/zenml/models/model_base_model.py index e568dd31827..69bb1105bc1 100644 --- a/src/zenml/models/model_base_model.py +++ b/src/zenml/models/model_base_model.py @@ -13,14 +13,10 @@ # permissions and limitations under the License. """Model base model to support Model Control Plane feature.""" -from typing import Any, Dict, List, Optional, Union +from typing import List, Optional -from pydantic import BaseModel, Field, root_validator +from pydantic import BaseModel, Field -from zenml.constants import ( - RUNNING_MODEL_VERSION, -) -from zenml.enums import ModelStages from zenml.logger import get_logger from zenml.models.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH @@ -65,99 +61,3 @@ class ModelBaseModel(BaseModel): tags: Optional[List[str]] = Field( title="Tags associated with the model", ) - - -class ModelConfigModel(ModelBaseModel): - """ModelConfig class to pass into pipeline or step to set it into a model context. - - name: The name of the model. - version: The model version name, number or stage is optional and points model context - to a specific version/stage, if skipped and `create_new_model_version` is False - - latest model version will be used. - version_description: The description of the model version. - create_new_model_version: Whether to create a new model version during execution - save_models_to_registry: Whether to save all ModelArtifacts to Model Registry, - if available in active stack. - delete_new_version_on_failure: Whether to delete failed runs with new versions for later recovery from it. - suppress_warnings: Whether to suppress warnings during validation. - """ - - version: Optional[Union[ModelStages, int, str]] - version_description: Optional[str] - create_new_model_version: bool = False - save_models_to_registry: bool = True - delete_new_version_on_failure: bool = True - suppress_warnings: bool = False - - class Config: - """Config class.""" - - smart_union = True - - @root_validator - def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: - """Validate all in one. - - Args: - values: Dict of values. - - Returns: - Dict of validated values. - - Raises: - ValueError: If validation failed on one of the checks. - """ - create_new_model_version = values.get( - "create_new_model_version", False - ) - delete_new_version_on_failure = values.get( - "delete_new_version_on_failure", True - ) - suppress_warnings = values.get("suppress_warnings", False) - if not delete_new_version_on_failure and not create_new_model_version: - if not suppress_warnings: - logger.warning( - "Using `delete_new_version_on_failure=False` and `create_new_model_version=False` has no effect." - "Setting `delete_new_version_on_failure` to `True`." - ) - values["delete_new_version_on_failure"] = True - - version = values.get("version", None) - - if create_new_model_version: - misuse_message = ( - "`version` set to {set} cannot be used with `create_new_model_version`." - "You can leave it default or set to a non-stage and non-numeric string.\n" - "Examples:\n" - " - `version` set to 1 or '1' is interpreted as a version number\n" - " - `version` set to 'production' is interpreted as a stage\n" - " - `version` set to 'my_first_version_in_2023' is a valid version to be created\n" - " - `version` set to 'My Second Version!' is a valid version to be created\n" - ) - if isinstance(version, ModelStages) or version in [ - stage.value for stage in ModelStages - ]: - raise ValueError( - misuse_message.format(set="a `ModelStages` instance") - ) - if str(version).isnumeric(): - raise ValueError(misuse_message.format(set="a numeric value")) - if version is None: - if not suppress_warnings: - logger.info( - "Creation of new model version was requested, but no version name was explicitly provided. " - f"Setting `version` to `{RUNNING_MODEL_VERSION}`." - ) - values["version"] = RUNNING_MODEL_VERSION - if ( - version in [stage.value for stage in ModelStages] - and not suppress_warnings - ): - logger.info( - f"`version` `{version}` matches one of the possible `ModelStages` and will be fetched using stage." - ) - if str(version).isnumeric() and not suppress_warnings: - logger.info( - f"`version` `{version}` is numeric and will be fetched using version number." - ) - return values diff --git a/src/zenml/new/pipelines/model_utils.py b/src/zenml/new/pipelines/model_utils.py index a9328f525f6..2ff509adaaf 100644 --- a/src/zenml/new/pipelines/model_utils.py +++ b/src/zenml/new/pipelines/model_utils.py @@ -18,7 +18,6 @@ from pydantic import BaseModel, PrivateAttr from zenml.model.model_config import ModelConfig -from zenml.models.model_base_model import ModelConfigModel class NewModelVersionRequest(BaseModel): @@ -57,7 +56,7 @@ def model_config(self) -> ModelConfig: def update_request( self, - model_config: ModelConfigModel, + model_config: ModelConfig, requester: "NewModelVersionRequest.Requester", ) -> None: """Update from Model Config Model object in place. @@ -71,7 +70,7 @@ def update_request( """ self.requesters.append(requester) if self._model_config is None: - self._model_config = ModelConfig.parse_obj(model_config) + self._model_config = model_config if self._model_config.version != model_config.version: raise ValueError( @@ -79,4 +78,4 @@ def update_request( "Since a new model version is requested for this model, all `version` names must match or left default." ) - self._model_config._merge_with_config(model_config) + self._model_config._merge(model_config) diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index 68a9c4d050b..ea2a77831d4 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -53,7 +53,7 @@ from zenml.config.pipeline_spec import PipelineSpec from zenml.config.schedule import Schedule from zenml.config.step_configurations import StepConfigurationUpdate -from zenml.enums import ExecutionStatus, StackComponentType +from zenml.enums import StackComponentType from zenml.hooks.hook_validators import resolve_and_validate_hook from zenml.logger import get_logger from zenml.models import ( @@ -66,7 +66,6 @@ PipelineRunResponseModel, ScheduleRequestModel, ) -from zenml.models.model_base_model import ModelConfigModel from zenml.models.pipeline_build_models import ( PipelineBuildBaseModel, ) @@ -354,8 +353,6 @@ def configure( # string of on_success hook function to be used for this pipeline success_hook_source = resolve_and_validate_hook(on_success) - if model_config: - model_config.suppress_warnings = True values = dict_utils.remove_none_values( { "enable_cache": enable_cache, @@ -366,18 +363,12 @@ def configure( "extra": extra, "failure_hook_source": failure_hook_source, "success_hook_source": success_hook_source, - "model_config_model": ModelConfigModel.parse_obj( - model_config.dict() - ) - if model_config is not None - else None, + "model_config": model_config, } ) if not self.__suppress_warnings_flag__: to_be_reapplied = [] for param_, value_ in values.items(): - if param_ == "model_config_model": - param_ = "model_config" if ( param_ in PipelineRunConfiguration.__fields__ and param_ in self._from_config_file @@ -827,10 +818,10 @@ def get_new_version_requests( ] = defaultdict(NewModelVersionRequest) all_steps_have_own_config = True for step in deployment.step_configurations.values(): - step_model_config = step.config.model_config_model + step_model_config = step.config.model_config all_steps_have_own_config = ( all_steps_have_own_config - and step.config.model_config_model is not None + and step.config.model_config is not None ) if ( step_model_config @@ -844,7 +835,7 @@ def get_new_version_requests( ) if not all_steps_have_own_config: pipeline_model_config = ( - deployment.pipeline_configuration.model_config_model + deployment.pipeline_configuration.model_config ) if ( pipeline_model_config @@ -858,7 +849,7 @@ def get_new_version_requests( source="pipeline", name=self.name ), ) - elif deployment.pipeline_configuration.model_config_model is not None: + elif deployment.pipeline_configuration.model_config is not None: logger.warning( f"ModelConfig of pipeline `{self.name}` is overridden in all steps. " ) @@ -875,11 +866,6 @@ def _validate_new_version_requests( Args: new_versions_requested: A dict of new model version request objects. - - Raises: - RuntimeError: If there is unfinished pipeline run for requested model version. - RuntimeError: If recovery not requested, but model version already exists. - """ for model_name, data in new_versions_requested.items(): if len(data.requesters) > 1: @@ -888,25 +874,7 @@ def _validate_new_version_requests( f"{data.requesters}\n We recommend that `create_new_model_version` is configured " "only in one place of the pipeline." ) - try: - model_version = data.model_config._get_model_version() - - for run_name, run in model_version.pipeline_runs.items(): - if run.status == ExecutionStatus.RUNNING: - raise RuntimeError( - f"New model version was requested, but pipeline run `{run_name}` " - f"is still running with version `{model_version.name}`." - ) - if ( - data.model_config.version - and data.model_config.delete_new_version_on_failure - ): - raise RuntimeError( - f"Cannot create version `{data.model_config.version}` " - f"for model `{data.model_config.name}` since it already exists" - ) - except KeyError: - pass + data.model_config._validate_config_in_runtime() def update_new_versions_requests( self, @@ -928,7 +896,7 @@ def update_new_versions_requests( for step_name in deployment.step_configurations: step_model_config = deployment.step_configurations[ step_name - ].config.model_config_model + ].config.model_config if ( step_model_config is not None and step_model_config.name in new_version_requests @@ -937,9 +905,7 @@ def update_new_versions_requests( step_model_config.name ].model_config.version step_model_config.create_new_model_version = True - pipeline_model_config = ( - deployment.pipeline_configuration.model_config_model - ) + pipeline_model_config = deployment.pipeline_configuration.model_config if ( pipeline_model_config is not None and pipeline_model_config.name in new_version_requests diff --git a/src/zenml/steps/base_step.py b/src/zenml/steps/base_step.py index fa453251014..e9c75f167c4 100644 --- a/src/zenml/steps/base_step.py +++ b/src/zenml/steps/base_step.py @@ -699,7 +699,6 @@ def configure( """ from zenml.config.step_configurations import StepConfigurationUpdate from zenml.hooks.hook_validators import resolve_and_validate_hook - from zenml.models.model_base_model import ModelConfigModel if name: logger.warning("Configuring the name of a step is deprecated.") @@ -748,8 +747,6 @@ def _convert_to_tuple(value: Any) -> Tuple[Source, ...]: if isinstance(parameters, BaseParameters): parameters = parameters.dict() - if model_config: - model_config.suppress_warnings = True values = dict_utils.remove_none_values( { "enable_cache": enable_cache, @@ -764,11 +761,7 @@ def _convert_to_tuple(value: Any) -> Tuple[Source, ...]: "extra": extra, "failure_hook_source": failure_hook_source, "success_hook_source": success_hook_source, - "model_config_model": ModelConfigModel.parse_obj( - model_config.dict() - ) - if model_config is not None - else None, + "model_config": model_config, } ) config = StepConfigurationUpdate(**values) diff --git a/src/zenml/utils/pydantic_utils.py b/src/zenml/utils/pydantic_utils.py index e02c8127cdc..9e352b02a0d 100644 --- a/src/zenml/utils/pydantic_utils.py +++ b/src/zenml/utils/pydantic_utils.py @@ -54,8 +54,12 @@ def update_model( update_dict = update else: update_dict = update.dict(exclude_unset=True) + if "model_config" in update_dict: + update_dict["model_config"] = getattr(update, "model_config") original_dict = original.dict(exclude_unset=True) + if "model_config" in original_dict: + original_dict["model_config"] = getattr(original, "model_config") if recursive: values = dict_utils.recursive_update(original_dict, update_dict) else: diff --git a/tests/unit/model/test_model_config.py b/tests/unit/model/test_model_config.py new file mode 100644 index 00000000000..ba4b6943c75 --- /dev/null +++ b/tests/unit/model/test_model_config.py @@ -0,0 +1,66 @@ +from unittest.mock import patch + +import pytest + +from zenml.enums import ModelStages +from zenml.model import ModelConfig + + +@pytest.mark.parametrize( + "version_name,create_new_model_version,delete_new_version_on_failure,logger", + [ + [None, False, False, "warning"], + [None, True, False, "info"], + ["staging", False, False, "info"], + ["1", False, False, "info"], + [1, False, False, "info"], + ], + ids=[ + "No new version, but recovery", + "Default running version", + "Pick model by text stage", + "Pick model by text version number", + "Pick model by integer version number", + ], +) +def test_init_warns( + version_name, + create_new_model_version, + delete_new_version_on_failure, + logger, +): + with patch(f"zenml.model.model_config.logger.{logger}") as logger: + ModelConfig( + name="foo", + version=version_name, + create_new_model_version=create_new_model_version, + delete_new_version_on_failure=delete_new_version_on_failure, + ) + logger.assert_called_once() + + +@pytest.mark.parametrize( + "version_name,create_new_model_version", + [ + [1, True], + ["1", True], + [ModelStages.PRODUCTION, True], + ["production", True], + ], + ids=[ + "Version number as integer and new version request", + "Version number as string and new version request", + "Version stage as instance and new version request", + "Version stage as string and new version request", + ], +) +def test_init_raises( + version_name, + create_new_model_version, +): + with pytest.raises(ValueError): + ModelConfig( + name="foo", + version=version_name, + create_new_model_version=create_new_model_version, + ) diff --git a/tests/unit/models/test_model_models.py b/tests/unit/models/test_model_models.py index 4e2f5a89c36..9bb562fe360 100644 --- a/tests/unit/models/test_model_models.py +++ b/tests/unit/models/test_model_models.py @@ -19,8 +19,6 @@ import pytest from tests.unit.steps.test_external_artifact import MockZenmlClient -from zenml.enums import ModelStages -from zenml.models.model_base_model import ModelConfigModel from zenml.models.model_models import ( ModelResponseModel, ModelVersionResponseModel, @@ -181,63 +179,3 @@ def test_getters( step_name=query_step, version=query_version, ) - - -@pytest.mark.parametrize( - "version_name,create_new_model_version,delete_new_version_on_failure,logger", - [ - [None, False, False, "warning"], - [None, True, False, "info"], - ["staging", False, False, "info"], - ["1", False, False, "info"], - [1, False, False, "info"], - ], - ids=[ - "No new version, but recovery", - "Default running version", - "Pick model by text stage", - "Pick model by text version number", - "Pick model by integer version number", - ], -) -def test_init_warns( - version_name, - create_new_model_version, - delete_new_version_on_failure, - logger, -): - with patch(f"zenml.models.model_base_model.logger.{logger}") as logger: - ModelConfigModel( - name="foo", - version=version_name, - create_new_model_version=create_new_model_version, - delete_new_version_on_failure=delete_new_version_on_failure, - ) - logger.assert_called_once() - - -@pytest.mark.parametrize( - "version_name,create_new_model_version", - [ - [1, True], - ["1", True], - [ModelStages.PRODUCTION, True], - ["production", True], - ], - ids=[ - "Version number as integer and new version request", - "Version number as string and new version request", - "Version stage as instance and new version request", - "Version stage as string and new version request", - ], -) -def test_init_raises( - version_name, - create_new_model_version, -): - with pytest.raises(ValueError): - ModelConfigModel( - name="foo", - version=version_name, - create_new_model_version=create_new_model_version, - ) From 82d69d34e3ded039dc52fe33405225268d949402 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 12:52:51 +0200 Subject: [PATCH 02/41] revert config names for backward compatibility --- src/zenml/config/pipeline_configurations.py | 2 +- src/zenml/config/step_configurations.py | 2 +- src/zenml/new/pipelines/pipeline.py | 16 ++++++---- src/zenml/new/steps/step_context.py | 8 ++--- src/zenml/orchestrators/step_launcher.py | 4 +-- src/zenml/steps/base_step.py | 2 +- .../pipelines/test_pipeline_config.py | 32 +++++++++---------- 7 files changed, 34 insertions(+), 32 deletions(-) diff --git a/src/zenml/config/pipeline_configurations.py b/src/zenml/config/pipeline_configurations.py index 5797ee563ab..4c2e5df295f 100644 --- a/src/zenml/config/pipeline_configurations.py +++ b/src/zenml/config/pipeline_configurations.py @@ -40,7 +40,7 @@ class PipelineConfigurationUpdate(StrictBaseModel): extra: Dict[str, Any] = {} failure_hook_source: Optional[Source] = None success_hook_source: Optional[Source] = None - model_config: Optional[ModelConfig] = None + model_config_model: Optional[ModelConfig] = None _convert_source = convert_source_validator( "failure_hook_source", "success_hook_source" diff --git a/src/zenml/config/step_configurations.py b/src/zenml/config/step_configurations.py index cbb95d5c5ce..d6cdb5eeb2d 100644 --- a/src/zenml/config/step_configurations.py +++ b/src/zenml/config/step_configurations.py @@ -134,7 +134,7 @@ class StepConfigurationUpdate(StrictBaseModel): extra: Dict[str, Any] = {} failure_hook_source: Optional[Source] = None success_hook_source: Optional[Source] = None - model_config: Optional[ModelConfig] = None + model_config_model: Optional[ModelConfig] = None outputs: Mapping[str, PartialArtifactConfiguration] = {} diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index ea2a77831d4..10cb28d58ac 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -363,7 +363,7 @@ def configure( "extra": extra, "failure_hook_source": failure_hook_source, "success_hook_source": success_hook_source, - "model_config": model_config, + "model_config_model": model_config, } ) if not self.__suppress_warnings_flag__: @@ -818,10 +818,10 @@ def get_new_version_requests( ] = defaultdict(NewModelVersionRequest) all_steps_have_own_config = True for step in deployment.step_configurations.values(): - step_model_config = step.config.model_config + step_model_config = step.config.model_config_model all_steps_have_own_config = ( all_steps_have_own_config - and step.config.model_config is not None + and step.config.model_config_model is not None ) if ( step_model_config @@ -835,7 +835,7 @@ def get_new_version_requests( ) if not all_steps_have_own_config: pipeline_model_config = ( - deployment.pipeline_configuration.model_config + deployment.pipeline_configuration.model_config_model ) if ( pipeline_model_config @@ -849,7 +849,7 @@ def get_new_version_requests( source="pipeline", name=self.name ), ) - elif deployment.pipeline_configuration.model_config is not None: + elif deployment.pipeline_configuration.model_config_model is not None: logger.warning( f"ModelConfig of pipeline `{self.name}` is overridden in all steps. " ) @@ -896,7 +896,7 @@ def update_new_versions_requests( for step_name in deployment.step_configurations: step_model_config = deployment.step_configurations[ step_name - ].config.model_config + ].config.model_config_model if ( step_model_config is not None and step_model_config.name in new_version_requests @@ -905,7 +905,9 @@ def update_new_versions_requests( step_model_config.name ].model_config.version step_model_config.create_new_model_version = True - pipeline_model_config = deployment.pipeline_configuration.model_config + pipeline_model_config = ( + deployment.pipeline_configuration.model_config_model + ) if ( pipeline_model_config is not None and pipeline_model_config.name in new_version_requests diff --git a/src/zenml/new/steps/step_context.py b/src/zenml/new/steps/step_context.py index 56b6ebd85cd..b2e6573b1eb 100644 --- a/src/zenml/new/steps/step_context.py +++ b/src/zenml/new/steps/step_context.py @@ -218,10 +218,10 @@ def model_config(self) -> "ModelConfig": Raises: StepContextError: If the `ModelConfig` object is not set in `@step` or `@pipeline`. """ - if self.step_run.config.model_config is not None: - return self.step_run.config.model_config - if self.pipeline_run.config.model_config is not None: - return self.pipeline_run.config.model_config + if self.step_run.config.model_config_model is not None: + return self.step_run.config.model_config_model + if self.pipeline_run.config.model_config_model is not None: + return self.pipeline_run.config.model_config_model raise StepContextError( f"Unable to get ModelConfig in step '{self.step_name}' of pipeline " f"run '{self.pipeline_run.id}': It was not set in `@step` or `@pipeline`." diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index b3072aef022..efc44a74c83 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -331,8 +331,8 @@ def _prepare( model_config = ( self._deployment.step_configurations[ step_run.name - ].config.model_config - or self._deployment.pipeline_configuration.model_config + ].config.model_config_model + or self._deployment.pipeline_configuration.model_config_model ) input_artifacts, parent_step_ids = input_utils.resolve_step_inputs( step=self._step, diff --git a/src/zenml/steps/base_step.py b/src/zenml/steps/base_step.py index e9c75f167c4..b005161f7c1 100644 --- a/src/zenml/steps/base_step.py +++ b/src/zenml/steps/base_step.py @@ -761,7 +761,7 @@ def _convert_to_tuple(value: Any) -> Tuple[Source, ...]: "extra": extra, "failure_hook_source": failure_hook_source, "success_hook_source": success_hook_source, - "model_config": model_config, + "model_config_model": model_config, } ) config = StepConfigurationUpdate(**values) diff --git a/tests/integration/functional/pipelines/test_pipeline_config.py b/tests/integration/functional/pipelines/test_pipeline_config.py index 6173da1a060..135f2c5c0c4 100644 --- a/tests/integration/functional/pipelines/test_pipeline_config.py +++ b/tests/integration/functional/pipelines/test_pipeline_config.py @@ -132,7 +132,7 @@ def assert_model_config_pipeline(): assert_model_config_step() p = assert_model_config_pipeline.with_options(config_path=str(config_path)) - assert p.configuration.model_config.name == "bar" + assert p.configuration.model_config_model.name == "bar" with patch("zenml.new.pipelines.pipeline.logger.warning") as warning: p.configure( @@ -154,24 +154,24 @@ def assert_model_config_pipeline(): ) warning.assert_called_once() - assert p.configuration.model_config is not None - assert p.configuration.model_config.name == "foo" - assert p.configuration.model_config.version == RUNNING_MODEL_VERSION - assert p.configuration.model_config.create_new_model_version - assert not p.configuration.model_config.delete_new_version_on_failure - assert p.configuration.model_config.description == "description" - assert p.configuration.model_config.license == "MIT" - assert p.configuration.model_config.audience == "audience" - assert p.configuration.model_config.use_cases == "use_cases" - assert p.configuration.model_config.limitations == "limitations" - assert p.configuration.model_config.trade_offs == "trade_offs" - assert p.configuration.model_config.ethic == "ethic" - assert p.configuration.model_config.tags == ["tag"] + assert p.configuration.model_config_model is not None + assert p.configuration.model_config_model.name == "foo" + assert p.configuration.model_config_model.version == RUNNING_MODEL_VERSION + assert p.configuration.model_config_model.create_new_model_version + assert not p.configuration.model_config_model.delete_new_version_on_failure + assert p.configuration.model_config_model.description == "description" + assert p.configuration.model_config_model.license == "MIT" + assert p.configuration.model_config_model.audience == "audience" + assert p.configuration.model_config_model.use_cases == "use_cases" + assert p.configuration.model_config_model.limitations == "limitations" + assert p.configuration.model_config_model.trade_offs == "trade_offs" + assert p.configuration.model_config_model.ethic == "ethic" + assert p.configuration.model_config_model.tags == ["tag"] assert ( - p.configuration.model_config.version_description + p.configuration.model_config_model.version_description == "version_description" ) - assert p.configuration.model_config.save_models_to_registry + assert p.configuration.model_config_model.save_models_to_registry with pytest.raises(AssertionError): p() From f9ba41cb59ecc829450be57c65d027a9f6bf3e10 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 12:56:15 +0200 Subject: [PATCH 03/41] revert utils change --- src/zenml/model/model_config.py | 1 + src/zenml/utils/pydantic_utils.py | 4 ---- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index b6e310d5467..d96e519439e 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -194,6 +194,7 @@ def _validate_config_in_runtime(self) -> None: RuntimeError: If there is unfinished pipeline run for requested new model version. """ try: + logger.error(f"VALIDATED MODEL {self.name}!") model_version = self._get_model_version() for run_name, run in model_version.pipeline_runs.items(): if run.status == ExecutionStatus.RUNNING: diff --git a/src/zenml/utils/pydantic_utils.py b/src/zenml/utils/pydantic_utils.py index 9e352b02a0d..e02c8127cdc 100644 --- a/src/zenml/utils/pydantic_utils.py +++ b/src/zenml/utils/pydantic_utils.py @@ -54,12 +54,8 @@ def update_model( update_dict = update else: update_dict = update.dict(exclude_unset=True) - if "model_config" in update_dict: - update_dict["model_config"] = getattr(update, "model_config") original_dict = original.dict(exclude_unset=True) - if "model_config" in original_dict: - original_dict["model_config"] = getattr(original, "model_config") if recursive: values = dict_utils.recursive_update(original_dict, update_dict) else: From 8e6f82af3ecb0f0871f08638320b7d90795e8027 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 13:46:16 +0200 Subject: [PATCH 04/41] fix ExternalArtifact --- .../artifacts/external_artifact_config.py | 2 ++ .../steps/test_external_artifact.py | 29 ++++++++++++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/zenml/artifacts/external_artifact_config.py b/src/zenml/artifacts/external_artifact_config.py index 90d279a8d73..4c32493472e 100644 --- a/src/zenml/artifacts/external_artifact_config.py +++ b/src/zenml/artifacts/external_artifact_config.py @@ -107,6 +107,8 @@ def _get_artifact_from_model( or self.model_name != model_config.name or self.model_version != model_config.version ): + from zenml.model.model_config import ModelConfig + model_config = ModelConfig( name=self.model_name, version=self.model_version, diff --git a/tests/integration/functional/steps/test_external_artifact.py b/tests/integration/functional/steps/test_external_artifact.py index 98164c80f29..2f6b92f77d0 100644 --- a/tests/integration/functional/steps/test_external_artifact.py +++ b/tests/integration/functional/steps/test_external_artifact.py @@ -62,6 +62,25 @@ def consumer_pipeline( ) +@pipeline(name="bar", enable_cache=False) +def consumer_pipeline_with_external_artifact_from_another_model( + model_artifact_version: int, + model_artifact_pipeline_name: str = None, + model_artifact_step_name: str = None, +): + consumer( + ExternalArtifact( + model_name="foo", + model_version=1, + model_artifact_name="predictions", + model_artifact_version=model_artifact_version, + model_artifact_pipeline_name=model_artifact_pipeline_name, + model_artifact_step_name=model_artifact_step_name, + ), + model_artifact_version, + ) + + @pipeline( name="bar", enable_cache=False, @@ -72,7 +91,15 @@ def two_step_producer_pipeline(): producer(1) -def test_exchange_of_model_artifacts_between_pipelines(): +@pytest.mark.parametrize( + "consumer_pipeline", + [ + consumer_pipeline, + consumer_pipeline_with_external_artifact_from_another_model, + ], + ids=["model context given", "no model context"], +) +def test_exchange_of_model_artifacts_between_pipelines(consumer_pipeline): """Test that ExternalArtifact helps to exchange data from Model between pipelines.""" with model_killer(): producer_pipeline.with_options( From 6247d03a7147a3acfaec1d0fb163638bad468d96 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 13:49:44 +0200 Subject: [PATCH 05/41] backward compatibility --- src/zenml/new/pipelines/pipeline.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index 10cb28d58ac..938411d4a48 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -369,6 +369,8 @@ def configure( if not self.__suppress_warnings_flag__: to_be_reapplied = [] for param_, value_ in values.items(): + if param_ == "model_config_model": + param_ = "model_config" if ( param_ in PipelineRunConfiguration.__fields__ and param_ in self._from_config_file From 90efe1bba8de9292eeef0edaebe5de4fc740a803 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 14:06:27 +0200 Subject: [PATCH 06/41] `model_config_model` -> `model_config` --- src/zenml/config/pipeline_configurations.py | 2 +- src/zenml/config/step_configurations.py | 2 +- src/zenml/new/pipelines/pipeline.py | 14 ++-- src/zenml/new/steps/step_context.py | 8 +-- src/zenml/orchestrators/step_launcher.py | 4 +- ...bb9_rename_model_config_model_to_model_.py | 64 +++++++++++++++++++ .../pipelines/test_pipeline_config.py | 32 +++++----- 7 files changed, 94 insertions(+), 32 deletions(-) create mode 100644 src/zenml/zen_stores/migrations/versions/4f66af55fbb9_rename_model_config_model_to_model_.py diff --git a/src/zenml/config/pipeline_configurations.py b/src/zenml/config/pipeline_configurations.py index 4c2e5df295f..5797ee563ab 100644 --- a/src/zenml/config/pipeline_configurations.py +++ b/src/zenml/config/pipeline_configurations.py @@ -40,7 +40,7 @@ class PipelineConfigurationUpdate(StrictBaseModel): extra: Dict[str, Any] = {} failure_hook_source: Optional[Source] = None success_hook_source: Optional[Source] = None - model_config_model: Optional[ModelConfig] = None + model_config: Optional[ModelConfig] = None _convert_source = convert_source_validator( "failure_hook_source", "success_hook_source" diff --git a/src/zenml/config/step_configurations.py b/src/zenml/config/step_configurations.py index d6cdb5eeb2d..cbb95d5c5ce 100644 --- a/src/zenml/config/step_configurations.py +++ b/src/zenml/config/step_configurations.py @@ -134,7 +134,7 @@ class StepConfigurationUpdate(StrictBaseModel): extra: Dict[str, Any] = {} failure_hook_source: Optional[Source] = None success_hook_source: Optional[Source] = None - model_config_model: Optional[ModelConfig] = None + model_config: Optional[ModelConfig] = None outputs: Mapping[str, PartialArtifactConfiguration] = {} diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index 938411d4a48..aa2df595159 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -820,10 +820,10 @@ def get_new_version_requests( ] = defaultdict(NewModelVersionRequest) all_steps_have_own_config = True for step in deployment.step_configurations.values(): - step_model_config = step.config.model_config_model + step_model_config = step.config.model_config all_steps_have_own_config = ( all_steps_have_own_config - and step.config.model_config_model is not None + and step.config.model_config is not None ) if ( step_model_config @@ -837,7 +837,7 @@ def get_new_version_requests( ) if not all_steps_have_own_config: pipeline_model_config = ( - deployment.pipeline_configuration.model_config_model + deployment.pipeline_configuration.model_config ) if ( pipeline_model_config @@ -851,7 +851,7 @@ def get_new_version_requests( source="pipeline", name=self.name ), ) - elif deployment.pipeline_configuration.model_config_model is not None: + elif deployment.pipeline_configuration.model_config is not None: logger.warning( f"ModelConfig of pipeline `{self.name}` is overridden in all steps. " ) @@ -898,7 +898,7 @@ def update_new_versions_requests( for step_name in deployment.step_configurations: step_model_config = deployment.step_configurations[ step_name - ].config.model_config_model + ].config.model_config if ( step_model_config is not None and step_model_config.name in new_version_requests @@ -907,9 +907,7 @@ def update_new_versions_requests( step_model_config.name ].model_config.version step_model_config.create_new_model_version = True - pipeline_model_config = ( - deployment.pipeline_configuration.model_config_model - ) + pipeline_model_config = deployment.pipeline_configuration.model_config if ( pipeline_model_config is not None and pipeline_model_config.name in new_version_requests diff --git a/src/zenml/new/steps/step_context.py b/src/zenml/new/steps/step_context.py index b2e6573b1eb..56b6ebd85cd 100644 --- a/src/zenml/new/steps/step_context.py +++ b/src/zenml/new/steps/step_context.py @@ -218,10 +218,10 @@ def model_config(self) -> "ModelConfig": Raises: StepContextError: If the `ModelConfig` object is not set in `@step` or `@pipeline`. """ - if self.step_run.config.model_config_model is not None: - return self.step_run.config.model_config_model - if self.pipeline_run.config.model_config_model is not None: - return self.pipeline_run.config.model_config_model + if self.step_run.config.model_config is not None: + return self.step_run.config.model_config + if self.pipeline_run.config.model_config is not None: + return self.pipeline_run.config.model_config raise StepContextError( f"Unable to get ModelConfig in step '{self.step_name}' of pipeline " f"run '{self.pipeline_run.id}': It was not set in `@step` or `@pipeline`." diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index efc44a74c83..b3072aef022 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -331,8 +331,8 @@ def _prepare( model_config = ( self._deployment.step_configurations[ step_run.name - ].config.model_config_model - or self._deployment.pipeline_configuration.model_config_model + ].config.model_config + or self._deployment.pipeline_configuration.model_config ) input_artifacts, parent_step_ids = input_utils.resolve_step_inputs( step=self._step, diff --git a/src/zenml/zen_stores/migrations/versions/4f66af55fbb9_rename_model_config_model_to_model_.py b/src/zenml/zen_stores/migrations/versions/4f66af55fbb9_rename_model_config_model_to_model_.py new file mode 100644 index 00000000000..6300867134c --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/4f66af55fbb9_rename_model_config_model_to_model_.py @@ -0,0 +1,64 @@ +"""rename model_config_model to model_config in pipeline and step configs [4f66af55fbb9]. + +Revision ID: 4f66af55fbb9 +Revises: 0.45.2 +Create Date: 2023-10-17 13:57:35.810054 + +""" +from alembic import op +from sqlalchemy.sql import text + +# revision identifiers, used by Alembic. +revision = "4f66af55fbb9" +down_revision = "0.45.2" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + # ### commands auto generated by Alembic - please adjust! ### + connection = op.get_bind() + + update_config_fields = text( + """ + UPDATE pipeline_deployment + SET pipeline_configuration = REPLACE( + pipeline_configuration, + '"model_config_model"', + '"model_config"' + ), + step_configurations = REPLACE( + step_configurations, + '"model_config_model"', + '"model_config"' + ) + """ + ) + connection.execute(update_config_fields) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + # ### commands auto generated by Alembic - please adjust! ### + connection = op.get_bind() + + update_config_fields = text( + """ + UPDATE pipeline_deployment + SET pipeline_configuration = REPLACE( + pipeline_configuration, + '"model_config"', + '"model_config_model"' + ), + step_configurations = REPLACE( + step_configurations, + '"model_config"', + '"model_config_model"' + ) + """ + ) + connection.execute(update_config_fields) + + # ### end Alembic commands ### diff --git a/tests/integration/functional/pipelines/test_pipeline_config.py b/tests/integration/functional/pipelines/test_pipeline_config.py index 135f2c5c0c4..6173da1a060 100644 --- a/tests/integration/functional/pipelines/test_pipeline_config.py +++ b/tests/integration/functional/pipelines/test_pipeline_config.py @@ -132,7 +132,7 @@ def assert_model_config_pipeline(): assert_model_config_step() p = assert_model_config_pipeline.with_options(config_path=str(config_path)) - assert p.configuration.model_config_model.name == "bar" + assert p.configuration.model_config.name == "bar" with patch("zenml.new.pipelines.pipeline.logger.warning") as warning: p.configure( @@ -154,24 +154,24 @@ def assert_model_config_pipeline(): ) warning.assert_called_once() - assert p.configuration.model_config_model is not None - assert p.configuration.model_config_model.name == "foo" - assert p.configuration.model_config_model.version == RUNNING_MODEL_VERSION - assert p.configuration.model_config_model.create_new_model_version - assert not p.configuration.model_config_model.delete_new_version_on_failure - assert p.configuration.model_config_model.description == "description" - assert p.configuration.model_config_model.license == "MIT" - assert p.configuration.model_config_model.audience == "audience" - assert p.configuration.model_config_model.use_cases == "use_cases" - assert p.configuration.model_config_model.limitations == "limitations" - assert p.configuration.model_config_model.trade_offs == "trade_offs" - assert p.configuration.model_config_model.ethic == "ethic" - assert p.configuration.model_config_model.tags == ["tag"] + assert p.configuration.model_config is not None + assert p.configuration.model_config.name == "foo" + assert p.configuration.model_config.version == RUNNING_MODEL_VERSION + assert p.configuration.model_config.create_new_model_version + assert not p.configuration.model_config.delete_new_version_on_failure + assert p.configuration.model_config.description == "description" + assert p.configuration.model_config.license == "MIT" + assert p.configuration.model_config.audience == "audience" + assert p.configuration.model_config.use_cases == "use_cases" + assert p.configuration.model_config.limitations == "limitations" + assert p.configuration.model_config.trade_offs == "trade_offs" + assert p.configuration.model_config.ethic == "ethic" + assert p.configuration.model_config.tags == ["tag"] assert ( - p.configuration.model_config_model.version_description + p.configuration.model_config.version_description == "version_description" ) - assert p.configuration.model_config_model.save_models_to_registry + assert p.configuration.model_config.save_models_to_registry with pytest.raises(AssertionError): p() From 7056f96b250755e72323729a366b114e4ca4232c Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 14:41:22 +0200 Subject: [PATCH 07/41] cleanup after rename --- src/zenml/new/pipelines/pipeline.py | 4 +--- src/zenml/steps/base_step.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index aa2df595159..ea2a77831d4 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -363,14 +363,12 @@ def configure( "extra": extra, "failure_hook_source": failure_hook_source, "success_hook_source": success_hook_source, - "model_config_model": model_config, + "model_config": model_config, } ) if not self.__suppress_warnings_flag__: to_be_reapplied = [] for param_, value_ in values.items(): - if param_ == "model_config_model": - param_ = "model_config" if ( param_ in PipelineRunConfiguration.__fields__ and param_ in self._from_config_file diff --git a/src/zenml/steps/base_step.py b/src/zenml/steps/base_step.py index b005161f7c1..e9c75f167c4 100644 --- a/src/zenml/steps/base_step.py +++ b/src/zenml/steps/base_step.py @@ -761,7 +761,7 @@ def _convert_to_tuple(value: Any) -> Tuple[Source, ...]: "extra": extra, "failure_hook_source": failure_hook_source, "success_hook_source": success_hook_source, - "model_config_model": model_config, + "model_config": model_config, } ) config = StepConfigurationUpdate(**values) From 33a0bc064bd98eb4970bd9a4580fefe6937034de Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 15:10:38 +0200 Subject: [PATCH 08/41] reuse model config in artifact config --- src/zenml/model/artifact_config.py | 42 ++++++++---------------- src/zenml/orchestrators/step_launcher.py | 4 ++- src/zenml/orchestrators/step_runner.py | 4 ++- 3 files changed, 20 insertions(+), 30 deletions(-) diff --git a/src/zenml/model/artifact_config.py b/src/zenml/model/artifact_config.py index 1b723178aed..6c5d12aefc1 100644 --- a/src/zenml/model/artifact_config.py +++ b/src/zenml/model/artifact_config.py @@ -24,7 +24,6 @@ if TYPE_CHECKING: from zenml.model.model_config import ModelConfig - from zenml.models import ModelResponseModel, ModelVersionResponseModel logger = get_logger(__name__) @@ -95,27 +94,10 @@ def _model_config(self) -> "ModelConfig": # Return the model from the context return model_config - @property - def _model(self) -> "ModelResponseModel": - """Get the `ModelResponseModel`. - - Returns: - ModelResponseModel: The fetched or created model. - """ - return self._model_config.get_or_create_model() - - @property - def _model_version(self) -> "ModelVersionResponseModel": - """Get the `ModelVersionResponseModel`. - - Returns: - ModelVersionResponseModel: The model version. - """ - return self._model_config.get_or_create_model_version() - def _link_to_model_version( self, artifact_uuid: UUID, + model_config: "ModelConfig", is_model_object: bool = False, is_deployment: bool = False, ) -> None: @@ -125,6 +107,7 @@ def _link_to_model_version( Args: artifact_uuid: The UUID of the artifact to link. + model_config: The model configuration from caller. is_model_object: Whether the artifact is a model object. Defaults to False. is_deployment: Whether the artifact is a deployment. Defaults to False. """ @@ -137,6 +120,8 @@ def _link_to_model_version( # Create a ZenML client client = Client() + model_version = model_config._get_model_version() + artifact_name = self.artifact_name if artifact_name is None: artifact = client.zen_store.get_artifact(artifact_id=artifact_uuid) @@ -148,8 +133,8 @@ def _link_to_model_version( workspace=client.active_workspace.id, name=artifact_name, artifact=artifact_uuid, - model=self._model.id, - model_version=self._model_version.id, + model=model_version.model.id, + model_version=model_version.id, is_model_object=is_model_object, is_deployment=is_deployment, overwrite=self.overwrite, @@ -163,8 +148,8 @@ def _link_to_model_version( user_id=client.active_user.id, workspace_id=client.active_workspace.id, name=artifact_name, - model_id=self._model.id, - model_version_id=self._model_version.id, + model_id=model_version.model.id, + model_version_id=model_version.id, only_artifacts=not (is_model_object or is_deployment), only_deployments=is_deployment, only_model_objects=is_model_object, @@ -177,8 +162,8 @@ def _link_to_model_version( f"Existing artifact link(s) `{artifact_name}` found and will be deleted." ) client.zen_store.delete_model_version_artifact_link( - model_name_or_id=self._model.id, - model_version_name_or_id=self._model_version.id, + model_name_or_id=model_version.model.id, + model_version_name_or_id=model_version.id, model_version_artifact_link_name_or_id=artifact_name, ) else: @@ -188,16 +173,17 @@ def _link_to_model_version( client.zen_store.create_model_version_artifact_link(request) def link_to_model( - self, - artifact_uuid: UUID, + self, artifact_uuid: UUID, model_config: "ModelConfig" ) -> None: """Link artifact to the model version. Args: - artifact_uuid (UUID): The UUID of the artifact to link. + artifact_uuid: The UUID of the artifact to link. + model_config: The model configuration from caller. """ self._link_to_model_version( artifact_uuid, + model_config=model_config, is_model_object=self.IS_MODEL_ARTIFACT, is_deployment=self.IS_DEPLOYMENT_ARTIFACT, ) diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index b3072aef022..558fc444b68 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -430,7 +430,9 @@ def _link_cached_artifacts_to_model_version( self._deployment.pipeline_configuration.name ) artifact_config_._step_name = self._step_name - artifact_config_.link_to_model(output_) + artifact_config_.link_to_model( + artifact_uuid=output_, model_config=model_config + ) def _run_step( self, diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index 8752c0f82bb..5b936c6bb58 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -654,7 +654,9 @@ def _link_artifacts_to_model( get_step_context().pipeline.name ) artifact_config._step_name = get_step_context().step_run.name - artifact_config.link_to_model(artifact_uuid=artifact_uuid) + artifact_config.link_to_model( + artifact_uuid=artifact_uuid, model_config=mc + ) def _get_model_versions_from_artifacts( self, From ab9d117c0117cc9970642b9be23c6e6b79e678c2 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 15:26:49 +0200 Subject: [PATCH 09/41] warm up all model configs on pipeline level --- src/zenml/model/model_config.py | 24 +++++++------- src/zenml/new/pipelines/pipeline.py | 50 ++++++++++++++++------------- 2 files changed, 39 insertions(+), 35 deletions(-) diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index d96e519439e..d60b4fd7b07 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -196,21 +196,19 @@ def _validate_config_in_runtime(self) -> None: try: logger.error(f"VALIDATED MODEL {self.name}!") model_version = self._get_model_version() - for run_name, run in model_version.pipeline_runs.items(): - if run.status == ExecutionStatus.RUNNING: + if self.create_new_model_version: + for run_name, run in model_version.pipeline_runs.items(): + if run.status == ExecutionStatus.RUNNING: + raise RuntimeError( + f"New model version was requested, but pipeline run `{run_name}` " + f"is still running with version `{model_version.name}`." + ) + + if not self.delete_new_version_on_failure: raise RuntimeError( - f"New model version was requested, but pipeline run `{run_name}` " - f"is still running with version `{model_version.name}`." + f"Cannot create version `{self.version}` " + f"for model `{self.name}` since it already exists" ) - - if ( - self.create_new_model_version - and not self.delete_new_version_on_failure - ): - raise RuntimeError( - f"Cannot create version `{self.version}` " - f"for model `{self.name}` since it already exists" - ) except KeyError: pass self.get_or_create_model_version() diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index ea2a77831d4..2fda43ce6f1 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -816,6 +816,7 @@ def get_new_version_requests( new_versions_requested: Dict[ str, NewModelVersionRequest ] = defaultdict(NewModelVersionRequest) + other_model_configs: List["ModelConfig"] = [] all_steps_have_own_config = True for step in deployment.step_configurations.values(): step_model_config = step.config.model_config @@ -823,32 +824,34 @@ def get_new_version_requests( all_steps_have_own_config and step.config.model_config is not None ) - if ( - step_model_config - and step_model_config.create_new_model_version - ): - new_versions_requested[step_model_config.name].update_request( - step_model_config, - NewModelVersionRequest.Requester( - source="step", name=step.config.name - ), - ) + if step_model_config: + if step_model_config.create_new_model_version: + new_versions_requested[ + step_model_config.name + ].update_request( + step_model_config, + NewModelVersionRequest.Requester( + source="step", name=step.config.name + ), + ) + else: + other_model_configs.append(step_model_config) if not all_steps_have_own_config: pipeline_model_config = ( deployment.pipeline_configuration.model_config ) - if ( - pipeline_model_config - and pipeline_model_config.create_new_model_version - ): - new_versions_requested[ - pipeline_model_config.name - ].update_request( - pipeline_model_config, - NewModelVersionRequest.Requester( - source="pipeline", name=self.name - ), - ) + if pipeline_model_config: + if pipeline_model_config.create_new_model_version: + new_versions_requested[ + pipeline_model_config.name + ].update_request( + pipeline_model_config, + NewModelVersionRequest.Requester( + source="pipeline", name=self.name + ), + ) + else: + other_model_configs.append(pipeline_model_config) elif deployment.pipeline_configuration.model_config is not None: logger.warning( f"ModelConfig of pipeline `{self.name}` is overridden in all steps. " @@ -856,6 +859,9 @@ def get_new_version_requests( self._validate_new_version_requests(new_versions_requested) + for other_model_config in other_model_configs: + other_model_config._validate_config_in_runtime() + return new_versions_requested def _validate_new_version_requests( From 84c756ff651f5d8e01b578a50ac1999f2527ea3f Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 15:32:07 +0200 Subject: [PATCH 10/41] avoid multi warning in ModelConfig --- src/zenml/model/model_config.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index d60b4fd7b07..1c7847702ba 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -58,6 +58,7 @@ class ModelConfig(BaseModel): model: Optional[Any] = None model_version: Optional[Any] = None + user_not_yet_warned: bool = True def __init__( self, @@ -118,6 +119,7 @@ def __init__( ) self.model = kwargs.get("model", None) self.model_version = kwargs.get("model_version", None) + self.user_not_yet_warned = kwargs.get("user_not_yet_warned", True) class Config: """Config class.""" @@ -143,11 +145,13 @@ def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: delete_new_version_on_failure = values.get( "delete_new_version_on_failure", True ) + user_not_yet_warned = values.get("user_not_yet_warned", True) if not delete_new_version_on_failure and not create_new_model_version: - logger.warning( - "Using `delete_new_version_on_failure=False` and `create_new_model_version=False` has no effect." - "Setting `delete_new_version_on_failure` to `True`." - ) + if user_not_yet_warned: + logger.warning( + "Using `delete_new_version_on_failure=False` and `create_new_model_version=False` has no effect." + "Setting `delete_new_version_on_failure` to `True`." + ) values["delete_new_version_on_failure"] = True version = values.get("version", None) @@ -171,19 +175,24 @@ def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: if str(version).isnumeric(): raise ValueError(misuse_message.format(set="a numeric value")) if version is None: - logger.info( - "Creation of new model version was requested, but no version name was explicitly provided. " - f"Setting `version` to `{RUNNING_MODEL_VERSION}`." - ) + if user_not_yet_warned: + logger.info( + "Creation of new model version was requested, but no version name was explicitly provided. " + f"Setting `version` to `{RUNNING_MODEL_VERSION}`." + ) values["version"] = RUNNING_MODEL_VERSION - if version in [stage.value for stage in ModelStages]: + if ( + version in [stage.value for stage in ModelStages] + and user_not_yet_warned + ): logger.info( f"`version` `{version}` matches one of the possible `ModelStages` and will be fetched using stage." ) - if str(version).isnumeric(): + if str(version).isnumeric() and user_not_yet_warned: logger.info( f"`version` `{version}` is numeric and will be fetched using version number." ) + values["user_not_yet_warned"] = False return values def _validate_config_in_runtime(self) -> None: From ffcdc342bb9d554e42fe4da60ba0ad15c3258050 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 15:41:05 +0200 Subject: [PATCH 11/41] properly pass values to validator --- src/zenml/model/model_config.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index 1c7847702ba..d2ea4f857a5 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -116,10 +116,10 @@ def __init__( create_new_model_version=create_new_model_version, save_models_to_registry=save_models_to_registry, delete_new_version_on_failure=delete_new_version_on_failure, + model=kwargs.get("model", None), + model_version=kwargs.get("model_version", None), + user_not_yet_warned=kwargs.get("user_not_yet_warned", True), ) - self.model = kwargs.get("model", None) - self.model_version = kwargs.get("model_version", None) - self.user_not_yet_warned = kwargs.get("user_not_yet_warned", True) class Config: """Config class.""" From 7f0040498b58761436e3ad1d01af055e14fd6a6d Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 15:41:24 +0200 Subject: [PATCH 12/41] avoid instantiate in parse config file --- src/zenml/new/pipelines/pipeline.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index 2fda43ce6f1..48adb75564d 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -1128,7 +1128,7 @@ def _compile( integration_registry.activate_integrations() - self._from_config_file = self._parse_config_file( + self._parse_config_file( config_path=config_path, matcher=list(PipelineRunConfiguration.__fields__.keys()), ) @@ -1365,15 +1365,12 @@ def __exit__(self, *args: Any) -> None: def _parse_config_file( self, config_path: Optional[str], matcher: List[str] - ) -> Dict[str, Any]: - """Parses the given configuration file. + ) -> None: + """Parses the given configuration file and sets `self._from_config_file`. Args: config_path: Path to a yaml configuration file. matcher: List of keys to match in the configuration file. - - Returns: - The parsed configuration file as a dictionary. """ _from_config_file: Dict[str, Any] = {} if config_path: @@ -1390,10 +1387,11 @@ def _parse_config_file( if "model_config" in _from_config_file: from zenml.model.model_config import ModelConfig - _from_config_file["model_config"] = ModelConfig.parse_obj( - _from_config_file["model_config"] + _from_config_file["model_config"] = self._from_config_file.get( + "model_config", + ModelConfig.parse_obj(_from_config_file["model_config"]), ) - return _from_config_file + self._from_config_file = _from_config_file def with_options( self, @@ -1432,7 +1430,7 @@ def with_options( """ pipeline_copy = self.copy() - pipeline_copy._from_config_file = self._parse_config_file( + pipeline_copy._parse_config_file( config_path=config_path, matcher=inspect.getfullargspec(self.configure)[0], ) From cc8cd0166fe01925dd00c97d4f0b0dbd25ad08c9 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 15:54:35 +0200 Subject: [PATCH 13/41] avoid instantiate in parse config file --- src/zenml/new/pipelines/pipeline.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index 48adb75564d..85378b35730 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -1385,12 +1385,16 @@ def _parse_config_file( ) if "model_config" in _from_config_file: - from zenml.model.model_config import ModelConfig + if "model_config" in self._from_config_file: + _from_config_file["model_config"] = self._from_config_file[ + "model_config" + ] + else: + from zenml.model.model_config import ModelConfig - _from_config_file["model_config"] = self._from_config_file.get( - "model_config", - ModelConfig.parse_obj(_from_config_file["model_config"]), - ) + _from_config_file["model_config"] = ModelConfig.parse_obj( + _from_config_file["model_config"] + ) self._from_config_file = _from_config_file def with_options( From 9e7fe6c6a162225f3ec026ca75d5f83e8ac4298b Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 15:59:07 +0200 Subject: [PATCH 14/41] remove debug statements --- src/zenml/model/model_config.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index d2ea4f857a5..4b25e625681 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -100,7 +100,6 @@ def __init__( delete_new_version_on_failure: Whether to delete failed runs with new versions for later recovery from it. kwargs: Other arguments. """ - logger.error(f"INSTANTIATED MODEL CONFIG {name}/{version}!") super().__init__( name=name, license=license, @@ -203,7 +202,6 @@ def _validate_config_in_runtime(self) -> None: RuntimeError: If there is unfinished pipeline run for requested new model version. """ try: - logger.error(f"VALIDATED MODEL {self.name}!") model_version = self._get_model_version() if self.create_new_model_version: for run_name, run in model_version.pipeline_runs.items(): @@ -238,8 +236,6 @@ def get_or_create_model(self) -> "ModelResponseModel": from zenml.client import Client from zenml.models.model_models import ModelRequestModel - logger.error(f"RETRIEVED MODEL {self.name}!") - zenml_client = Client() try: self.model = zenml_client.get_model(model_name_or_id=self.name) @@ -286,7 +282,6 @@ def _create_model_version( from zenml.client import Client from zenml.models.model_models import ModelVersionRequestModel - logger.error(f"CREATED VERSION {self.version}!") zenml_client = Client() model_version_request = ModelVersionRequestModel( user=zenml_client.active_user.id, @@ -323,7 +318,6 @@ def _get_model_version(self) -> "ModelVersionResponseModel": from zenml.client import Client - logger.error(f"RETRIEVED VERSION {self.version}!") zenml_client = Client() if self.version is None: # raise if not found From f5bbdc0b4790d2bcd3823640e138c96db5e5139c Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 16:03:45 +0200 Subject: [PATCH 15/41] update misleading logging --- src/zenml/model/artifact_config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/zenml/model/artifact_config.py b/src/zenml/model/artifact_config.py index 6c5d12aefc1..ac3499e7655 100644 --- a/src/zenml/model/artifact_config.py +++ b/src/zenml/model/artifact_config.py @@ -153,6 +153,8 @@ def _link_to_model_version( only_artifacts=not (is_model_object or is_deployment), only_deployments=is_deployment, only_model_objects=is_model_object, + pipeline_name=self._pipeline_name, + step_name=self._step_name, ) ) if len(existing_links): From 019059ffaeccf371d6f67ee03a0ab1144a7783ea Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 16:13:51 +0200 Subject: [PATCH 16/41] improve logging of cached linking --- src/zenml/orchestrators/step_launcher.py | 26 +++++++++++++----------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index 558fc444b68..9d8b1713e52 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -414,18 +414,20 @@ def _link_cached_artifacts_to_model_version( for output_name_, output_ in step_run.outputs.items(): if output_name_ in output_annotations: annotation = output_annotations.get(output_name_, None) - artifact_config = ( - annotation.artifact_config - if annotation and annotation.artifact_config is not None - else ArtifactConfig() - ) - artifact_config_ = artifact_config.copy() - artifact_config_.model_name = ( - artifact_config.model_name or model_version.model.name - ) - artifact_config_.model_version = ( - artifact_config_.model_version or model_version.name - ) + if annotation and annotation.artifact_config is not None: + artifact_config_ = annotation.artifact_config.copy() + else: + artifact_config_ = ArtifactConfig( + artifact_name=output_name_, + model_name=model_version.model.name, + model_version=model_version.name, + ) + logger.info( + f"Linking artifact `{artifact_config_.artifact_name}` to " + f"model `{artifact_config_.model_name}` version " + f"`{artifact_config_.model_version}` implicitly." + ) + artifact_config_._pipeline_name = ( self._deployment.pipeline_configuration.name ) From b177f4c8082588ea82d961ef8dbfcf54c556d41a Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 17:23:05 +0200 Subject: [PATCH 17/41] allow latest version in cli links listing --- src/zenml/cli/model.py | 39 ++++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/src/zenml/cli/model.py b/src/zenml/cli/model.py index f1052403040..018ed48e50d 100644 --- a/src/zenml/cli/model.py +++ b/src/zenml/cli/model.py @@ -468,7 +468,16 @@ def _print_artifacts_links_generic( """ model_version = Client().get_model_version( model_name_or_id=model_name_or_id, - model_version_name_or_number_or_id=model_version_name_or_number_or_id, + model_version_name_or_number_or_id=None + if model_version_name_or_number_or_id == "0" + else model_version_name_or_number_or_id, + ) + type_ = ( + "artifacts" + if only_artifacts + else "deployments" + if only_deployments + else "model objects" ) if ( @@ -476,16 +485,13 @@ def _print_artifacts_links_generic( or (only_deployments and not model_version.deployment_ids) or (only_model_objects and not model_version.model_object_ids) ): - _type = ( - "artifacts" - if only_artifacts - else "deployments" - if only_deployments - else "model objects" - ) - cli_utils.declare(f"No {_type} linked to the model version found.") + cli_utils.declare(f"No {type_} linked to the model version found.") return + cli_utils.title( + f"{type_} linked to the model version `{model_version.name}[{model_version.number}]`:" + ) + links = Client().list_model_version_artifact_links( ModelVersionArtifactFilterModel( model_id=model_version.model.id, @@ -515,7 +521,7 @@ def _print_artifacts_links_generic( help="List artifacts linked to a model version.", ) @click.argument("model_name_or_id") -@click.argument("model_version_name_or_number_or_id") +@click.argument("model_version_name_or_number_or_id", default="0") @cli_utils.list_options(ModelVersionArtifactFilterModel) def list_model_version_artifacts( model_name_or_id: str, @@ -527,6 +533,7 @@ def list_model_version_artifacts( Args: model_name_or_id: The ID or name of the model containing version. model_version_name_or_number_or_id: The name, number or ID of the model version. + Or use 0 for the latest version. **kwargs: Keyword arguments to filter models. """ _print_artifacts_links_generic( @@ -542,7 +549,7 @@ def list_model_version_artifacts( help="List model objects linked to a model version.", ) @click.argument("model_name_or_id") -@click.argument("model_version_name_or_number_or_id") +@click.argument("model_version_name_or_number_or_id", default="0") @cli_utils.list_options(ModelVersionArtifactFilterModel) def list_model_version_model_objects( model_name_or_id: str, @@ -554,6 +561,7 @@ def list_model_version_model_objects( Args: model_name_or_id: The ID or name of the model containing version. model_version_name_or_number_or_id: The name, number or ID of the model version. + Or use 0 for the latest version. **kwargs: Keyword arguments to filter models. """ _print_artifacts_links_generic( @@ -569,7 +577,7 @@ def list_model_version_model_objects( help="List deployments linked to a model version.", ) @click.argument("model_name_or_id") -@click.argument("model_version_name_or_number_or_id") +@click.argument("model_version_name_or_number_or_id", default="0") @cli_utils.list_options(ModelVersionArtifactFilterModel) def list_model_version_deployments( model_name_or_id: str, @@ -581,6 +589,7 @@ def list_model_version_deployments( Args: model_name_or_id: The ID or name of the model containing version. model_version_name_or_number_or_id: The name, number or ID of the model version. + Or use 0 for the latest version. **kwargs: Keyword arguments to filter models. """ _print_artifacts_links_generic( @@ -596,7 +605,7 @@ def list_model_version_deployments( help="List pipeline runs of a model version.", ) @click.argument("model_name_or_id") -@click.argument("model_version_name_or_number_or_id") +@click.argument("model_version_name_or_number_or_id", default="0") @cli_utils.list_options(ModelVersionPipelineRunFilterModel) def list_model_version_pipeline_runs( model_name_or_id: str, @@ -608,6 +617,7 @@ def list_model_version_pipeline_runs( Args: model_name_or_id: The ID or name of the model containing version. model_version_name_or_number_or_id: The name, number or ID of the model version. + Or use 0 for the latest version. **kwargs: Keyword arguments to filter models. """ model_version = Client().get_model_version( @@ -618,6 +628,9 @@ def list_model_version_pipeline_runs( if not model_version.pipeline_run_ids: cli_utils.declare("No pipeline runs attached to model version found.") return + cli_utils.title( + f"Pipeline runs linked to the model version `{model_version.name}[{model_version.number}]`:" + ) links = Client().list_model_version_pipeline_run_links( ModelVersionPipelineRunFilterModel( From 468c99989ca6db6ebeb024ae9a56feee2c40421a Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 17:25:02 +0200 Subject: [PATCH 18/41] Update src/zenml/model/model_config.py Co-authored-by: Alex Strick van Linschoten --- src/zenml/model/model_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index 4b25e625681..ebd4fb90cb1 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -86,7 +86,7 @@ def __init__( description: The description of the model. audience: The target audience of the model. use_cases: The use cases of the model. - limitations: The know limitations of the model. + limitations: The known limitations of the model. trade_offs: The trade offs of the model. ethic: The ethical implications of the model. tags: Tags associated with the model. From 14b3ec4df1cdbdd14b5210c0dfaebc75f1df3852 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 17:25:08 +0200 Subject: [PATCH 19/41] Update src/zenml/model/model_config.py Co-authored-by: Alex Strick van Linschoten --- src/zenml/model/model_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index ebd4fb90cb1..912eda13732 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -82,7 +82,7 @@ def __init__( Args: name: The name of the model. - license: The license model created under. + license: The license under which the model is created. description: The description of the model. audience: The target audience of the model. use_cases: The use cases of the model. From b469058dc50f2f46c165b1dfc2b07b78fd98fd84 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 17:25:13 +0200 Subject: [PATCH 20/41] Update src/zenml/model/model_config.py Co-authored-by: Alex Strick van Linschoten --- src/zenml/model/model_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index 912eda13732..ebb532dcb4d 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -87,7 +87,7 @@ def __init__( audience: The target audience of the model. use_cases: The use cases of the model. limitations: The known limitations of the model. - trade_offs: The trade offs of the model. + trade_offs: The tradeoffs of the model. ethic: The ethical implications of the model. tags: Tags associated with the model. version: The model version name, number or stage is optional and points model context From ce5ee8908b7fb1d966f08f87e3381a2b274db91a Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 17:25:21 +0200 Subject: [PATCH 21/41] Update src/zenml/model/model_config.py Co-authored-by: Alex Strick van Linschoten --- src/zenml/model/model_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index ebb532dcb4d..a29ea03f635 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -91,7 +91,7 @@ def __init__( ethic: The ethical implications of the model. tags: Tags associated with the model. version: The model version name, number or stage is optional and points model context - to a specific version/stage, if skipped and `create_new_model_version` is False - + to a specific version/stage. If skipped and `create_new_model_version` is False - latest model version will be used. version_description: The description of the model version. create_new_model_version: Whether to create a new model version during execution From 397766ec8e9f83048ed5d2f84ed7f97a6f5274f1 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 18:20:30 +0200 Subject: [PATCH 22/41] unique test file name --- .../model/{test_model_config.py => test_model_config_init.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/unit/model/{test_model_config.py => test_model_config_init.py} (100%) diff --git a/tests/unit/model/test_model_config.py b/tests/unit/model/test_model_config_init.py similarity index 100% rename from tests/unit/model/test_model_config.py rename to tests/unit/model/test_model_config_init.py From eee10523374d755052182ab11fb4bfbeb52307ea Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 17 Oct 2023 18:20:50 +0200 Subject: [PATCH 23/41] ethic->ethics --- src/zenml/cli/model.py | 4 ++-- src/zenml/model/model_config.py | 12 ++++++------ src/zenml/models/model_base_model.py | 2 +- src/zenml/models/model_models.py | 2 +- src/zenml/orchestrators/step_runner.py | 16 +++++++++++++--- ...bb9_rename_model_config_model_to_model_.py | 19 +++++++++++++++++-- src/zenml/zen_stores/schemas/model_schemas.py | 6 +++--- .../pipelines/test_pipeline_config.py | 8 ++++---- .../functional/zen_stores/utils.py | 2 +- 9 files changed, 48 insertions(+), 23 deletions(-) diff --git a/src/zenml/cli/model.py b/src/zenml/cli/model.py index 018ed48e50d..862acb42bd6 100644 --- a/src/zenml/cli/model.py +++ b/src/zenml/cli/model.py @@ -159,7 +159,7 @@ def register_model( audience=audience, use_cases=use_cases, trade_offs=tradeoffs, - ethic=ethical, + ethics=ethical, limitations=limitations, tags=tag, user=Client().active_user.id, @@ -265,7 +265,7 @@ def update_model( audience=audience, use_cases=use_cases, trade_offs=tradeoffs, - ethic=ethical, + ethics=ethical, limitations=limitations, tags=tag, user=Client().active_user.id, diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index a29ea03f635..2dc19a733fa 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -48,7 +48,7 @@ class ModelConfig(BaseModel): use_cases: Optional[str] limitations: Optional[str] trade_offs: Optional[str] - ethic: Optional[str] + ethics: Optional[str] tags: Optional[List[str]] version: Optional[Union[ModelStages, int, str]] version_description: Optional[str] @@ -69,7 +69,7 @@ def __init__( use_cases: Optional[str] = None, limitations: Optional[str] = None, trade_offs: Optional[str] = None, - ethic: Optional[str] = None, + ethics: Optional[str] = None, tags: Optional[List[str]] = None, version: Optional[Union[ModelStages, int, str]] = None, version_description: Optional[str] = None, @@ -88,7 +88,7 @@ def __init__( use_cases: The use cases of the model. limitations: The known limitations of the model. trade_offs: The tradeoffs of the model. - ethic: The ethical implications of the model. + ethics: The ethical implications of the model. tags: Tags associated with the model. version: The model version name, number or stage is optional and points model context to a specific version/stage. If skipped and `create_new_model_version` is False - @@ -108,7 +108,7 @@ def __init__( use_cases=use_cases, limitations=limitations, trade_offs=trade_offs, - ethic=ethic, + ethics=ethics, tags=tags, version=version, version_description=version_description, @@ -248,7 +248,7 @@ def get_or_create_model(self) -> "ModelResponseModel": use_cases=self.use_cases, limitations=self.limitations, trade_offs=self.trade_offs, - ethic=self.ethic, + ethics=self.ethics, tags=self.tags, user=zenml_client.active_user.id, workspace=zenml_client.active_workspace.id, @@ -366,7 +366,7 @@ def _merge(self, model_config: "ModelConfig") -> None: self.use_cases = self.use_cases or model_config.use_cases self.limitations = self.limitations or model_config.limitations self.trade_offs = self.trade_offs or model_config.trade_offs - self.ethic = self.ethic or model_config.ethic + self.ethics = self.ethics or model_config.ethics if model_config.tags is not None: self.tags = (self.tags or []) + model_config.tags diff --git a/src/zenml/models/model_base_model.py b/src/zenml/models/model_base_model.py index 69bb1105bc1..411a3b32aa8 100644 --- a/src/zenml/models/model_base_model.py +++ b/src/zenml/models/model_base_model.py @@ -54,7 +54,7 @@ class ModelBaseModel(BaseModel): title="The trade offs of the model", max_length=TEXT_FIELD_MAX_LENGTH, ) - ethic: Optional[str] = Field( + ethics: Optional[str] = Field( title="The ethical implications of the model", max_length=TEXT_FIELD_MAX_LENGTH, ) diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 475391094a4..4f83574ec7a 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -618,5 +618,5 @@ class ModelUpdateModel(BaseModel): use_cases: Optional[str] limitations: Optional[str] trade_offs: Optional[str] - ethic: Optional[str] + ethics: Optional[str] tags: Optional[List[str]] diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index 5b936c6bb58..00cfae4b007 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -627,9 +627,6 @@ def _link_artifacts_to_model( mc = get_step_context().model_config except StepContextError: mc = None - logger.warning( - "No model context found, unable to auto-link artifacts." - ) for artifact_name in artifact_ids: artifact_uuid = artifact_ids[artifact_name] @@ -647,6 +644,19 @@ def _link_artifacts_to_model( ) if artifact_config is not None: + if mc is None: + if artifact_config.model_name is None: + logger.warning( + "No model context found, unable to auto-link artifacts." + ) + return + else: + from zenml.model.model_config import ModelConfig + + mc = ModelConfig( + name=artifact_config.model_name, + version=artifact_config.model_version, + ) artifact_config.artifact_name = ( artifact_config.artifact_name or artifact_name ) diff --git a/src/zenml/zen_stores/migrations/versions/4f66af55fbb9_rename_model_config_model_to_model_.py b/src/zenml/zen_stores/migrations/versions/4f66af55fbb9_rename_model_config_model_to_model_.py index 6300867134c..9bbaaac9648 100644 --- a/src/zenml/zen_stores/migrations/versions/4f66af55fbb9_rename_model_config_model_to_model_.py +++ b/src/zenml/zen_stores/migrations/versions/4f66af55fbb9_rename_model_config_model_to_model_.py @@ -5,6 +5,7 @@ Create Date: 2023-10-17 13:57:35.810054 """ +import sqlalchemy as sa from alembic import op from sqlalchemy.sql import text @@ -18,6 +19,13 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + table_name="model", + column_name="ethic", + new_column_name="ethics", + existing_type=sa.TEXT(), + ) + connection = op.get_bind() update_config_fields = text( @@ -28,7 +36,7 @@ def upgrade() -> None: '"model_config_model"', '"model_config"' ), - step_configurations = REPLACE( + step_configurations = REPLACE( step_configurations, '"model_config_model"', '"model_config"' @@ -42,6 +50,13 @@ def upgrade() -> None: def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + table_name="model", + column_name="ethics", + new_column_name="ethic", + existing_type=sa.TEXT(), + ) + connection = op.get_bind() update_config_fields = text( @@ -52,7 +67,7 @@ def downgrade() -> None: '"model_config"', '"model_config_model"' ), - step_configurations = REPLACE( + step_configurations = REPLACE( step_configurations, '"model_config"', '"model_config_model"' diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 774bb1bc684..3078069c768 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -72,7 +72,7 @@ class ModelSchema(NamedSchema, table=True): use_cases: str = Field(sa_column=Column(TEXT, nullable=True)) limitations: str = Field(sa_column=Column(TEXT, nullable=True)) trade_offs: str = Field(sa_column=Column(TEXT, nullable=True)) - ethic: str = Field(sa_column=Column(TEXT, nullable=True)) + ethics: str = Field(sa_column=Column(TEXT, nullable=True)) tags: str = Field(sa_column=Column(TEXT, nullable=True)) model_versions: List["ModelVersionSchema"] = Relationship( back_populates="model", @@ -107,7 +107,7 @@ def from_request(cls, model_request: ModelRequestModel) -> "ModelSchema": use_cases=model_request.use_cases, limitations=model_request.limitations, trade_offs=model_request.trade_offs, - ethic=model_request.ethic, + ethics=model_request.ethics, tags=json.dumps(model_request.tags) if model_request.tags else None, @@ -132,7 +132,7 @@ def to_model(self) -> ModelResponseModel: use_cases=self.use_cases, limitations=self.limitations, trade_offs=self.trade_offs, - ethic=self.ethic, + ethics=self.ethics, tags=json.loads(self.tags) if self.tags else None, ) diff --git a/tests/integration/functional/pipelines/test_pipeline_config.py b/tests/integration/functional/pipelines/test_pipeline_config.py index 6173da1a060..5eccf78bfc3 100644 --- a/tests/integration/functional/pipelines/test_pipeline_config.py +++ b/tests/integration/functional/pipelines/test_pipeline_config.py @@ -37,7 +37,7 @@ def assert_model_config_step(): assert model_config.use_cases == "use_cases" assert model_config.limitations == "limitations" assert model_config.trade_offs == "trade_offs" - assert model_config.ethic == "ethic" + assert model_config.ethics == "ethics" assert model_config.tags == ["tag"] assert model_config.version_description == "version_description" assert model_config.save_models_to_registry @@ -62,7 +62,7 @@ def test_pipeline_with_model_config_from_yaml(clean_workspace, tmp_path): use_cases="use_cases", limitations="limitations", trade_offs="trade_offs", - ethic="ethic", + ethics="ethics", tags=["tag"], version_description="version_description", save_models_to_registry=True, @@ -146,7 +146,7 @@ def assert_model_config_pipeline(): use_cases="use_cases", limitations="limitations", trade_offs="trade_offs", - ethic="ethic", + ethics="ethics", tags=["tag"], version_description="version_description", save_models_to_registry=True, @@ -165,7 +165,7 @@ def assert_model_config_pipeline(): assert p.configuration.model_config.use_cases == "use_cases" assert p.configuration.model_config.limitations == "limitations" assert p.configuration.model_config.trade_offs == "trade_offs" - assert p.configuration.model_config.ethic == "ethic" + assert p.configuration.model_config.ethics == "ethics" assert p.configuration.model_config.tags == ["tag"] assert ( p.configuration.model_config.version_description diff --git a/tests/integration/functional/zen_stores/utils.py b/tests/integration/functional/zen_stores/utils.py index 8f8bf5a57ab..2d21acf2d41 100644 --- a/tests/integration/functional/zen_stores/utils.py +++ b/tests/integration/functional/zen_stores/utils.py @@ -945,7 +945,7 @@ def update_method( use_cases="all", limitations="none", trade_offs="secret", - ethic="all good", + ethics="all good", tags=["cool", "stuff"], ), update_model=ModelUpdateModel( From 53f254de9f9d693c9620cd24838200ef9139c170 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 18 Oct 2023 09:34:23 +0200 Subject: [PATCH 24/41] update `step_run` --- ...bb9_rename_model_config_model_to_model_.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/zenml/zen_stores/migrations/versions/4f66af55fbb9_rename_model_config_model_to_model_.py b/src/zenml/zen_stores/migrations/versions/4f66af55fbb9_rename_model_config_model_to_model_.py index 9bbaaac9648..4aeabb23532 100644 --- a/src/zenml/zen_stores/migrations/versions/4f66af55fbb9_rename_model_config_model_to_model_.py +++ b/src/zenml/zen_stores/migrations/versions/4f66af55fbb9_rename_model_config_model_to_model_.py @@ -44,6 +44,18 @@ def upgrade() -> None: """ ) connection.execute(update_config_fields) + + update_config_fields = text( + """ + UPDATE step_run + SET step_configuration = REPLACE( + step_configuration, + '"model_config_model"', + '"model_config"' + ) + """ + ) + connection.execute(update_config_fields) # ### end Alembic commands ### @@ -76,4 +88,16 @@ def downgrade() -> None: ) connection.execute(update_config_fields) + update_config_fields = text( + """ + UPDATE step_run + SET step_configuration = REPLACE( + step_configuration, + '"model_config"', + '"model_config_model"' + ) + """ + ) + connection.execute(update_config_fields) + # ### end Alembic commands ### From b117128bc2d128cf706f4b86a279fed498000e4f Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 18 Oct 2023 11:15:25 +0200 Subject: [PATCH 25/41] update logging mock in tests --- .../integration/functional/model/test_model_config.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/integration/functional/model/test_model_config.py b/tests/integration/functional/model/test_model_config.py index 20c613701fe..29709f1a162 100644 --- a/tests/integration/functional/model/test_model_config.py +++ b/tests/integration/functional/model/test_model_config.py @@ -173,14 +173,10 @@ def test_init_create_new_version_with_version_fails(self): def test_init_recovery_without_create_new_version_warns(self): """Test that use of `recovery` warn on `create_new_model_version` set to False.""" - with mock.patch( - "zenml.models.model_base_model.logger.warning" - ) as logger: + with mock.patch("zenml.model.model_config.logger.warning") as logger: ModelConfig(name=MODEL_NAME, delete_new_version_on_failure=False) logger.assert_called_once() - with mock.patch( - "zenml.models.model_base_model.logger.warning" - ) as logger: + with mock.patch("zenml.model.model_config.logger.warning") as logger: ModelConfig( name=MODEL_NAME, delete_new_version_on_failure=False, @@ -190,7 +186,7 @@ def test_init_recovery_without_create_new_version_warns(self): def test_init_stage_logic(self): """Test that if version is set to string contained in ModelStages user is informed about it.""" - with mock.patch("zenml.models.model_base_model.logger.info") as logger: + with mock.patch("zenml.model.model_config.logger.info") as logger: mc = ModelConfig( name=MODEL_NAME, version=ModelStages.PRODUCTION.value, From fe8b5b5552dd53b0ef68c2c51b84c8a3bef52716 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 18 Oct 2023 11:19:43 +0200 Subject: [PATCH 26/41] extend test case to link from cache --- .../functional/model/test_artifact_config.py | 64 +++++++++++-------- 1 file changed, 37 insertions(+), 27 deletions(-) diff --git a/tests/integration/functional/model/test_artifact_config.py b/tests/integration/functional/model/test_artifact_config.py index fadb67b8dd5..2fcc349571d 100644 --- a/tests/integration/functional/model/test_artifact_config.py +++ b/tests/integration/functional/model/test_artifact_config.py @@ -228,13 +228,13 @@ def multi_named_output_step_from_self() -> ( return 1, 2, 3 -@pipeline(enable_cache=False) -def multi_named_pipeline_from_self(): +@pipeline +def multi_named_pipeline_from_self(enable_cache: bool): """Multi output linking from Annotated.""" - multi_named_output_step_from_self() + multi_named_output_step_from_self.with_options(enable_cache=enable_cache)() -def test_link_multiple_named_outputs_with_self_context(): +def test_link_multiple_named_outputs_with_self_context_and_caching(): """Test multi output linking with context defined in Annotated.""" with model_killer(): client = Client() @@ -266,32 +266,42 @@ def test_link_multiple_named_outputs_with_self_context(): ) ) - multi_named_pipeline_from_self() + for run_count in range(1, 3): + multi_named_pipeline_from_self(run_count == 2) - al1 = client.list_model_version_artifact_links( - ModelVersionArtifactFilterModel( - user_id=user, - workspace_id=ws, - model_id=mv1.model.id, - model_version_id=mv1.id, + al1 = client.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( + user_id=user, + workspace_id=ws, + model_id=mv1.model.id, + model_version_id=mv1.id, + ) ) - ) - al2 = client.list_model_version_artifact_links( - ModelVersionArtifactFilterModel( - user_id=user, - workspace_id=ws, - model_id=mv2.model.id, - model_version_id=mv2.id, + al2 = client.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( + user_id=user, + workspace_id=ws, + model_id=mv2.model.id, + model_version_id=mv2.id, + ) ) - ) - assert al1.size == 2 - assert al2.size == 1 - - assert {al.name for al in al1} == { - "1", - "2", - } - assert al2[0].name == "3" + assert al1.size == 2 + assert al2.size == 1 + + assert {al.name for al in al1} == { + "1", + "2", + } + assert al2[0].name == "3" + + # clean-up links to test caching linkage + for mv, al in zip([mv1, mv2], [al1, al2]): + for al_ in al: + client.zen_store.delete_model_version_artifact_link( + model_name_or_id=mv.model.id, + model_version_name_or_id=mv.id, + model_version_artifact_link_name_or_id=al_.id, + ) @step(model_config=ModelConfig(name="step", version="step")) From 7aeb14464807e0e15690d2d40d3180d4b7703fd4 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 18 Oct 2023 11:20:21 +0200 Subject: [PATCH 27/41] fixing bugs from tests --- src/zenml/model/model_config.py | 19 ++------ src/zenml/orchestrators/step_launcher.py | 42 +++++++++------- src/zenml/orchestrators/step_runner.py | 61 +++++++++++++----------- 3 files changed, 64 insertions(+), 58 deletions(-) diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index 2dc19a733fa..72253e008e5 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -211,14 +211,13 @@ def _validate_config_in_runtime(self) -> None: f"is still running with version `{model_version.name}`." ) - if not self.delete_new_version_on_failure: + if self.delete_new_version_on_failure: raise RuntimeError( f"Cannot create version `{self.version}` " f"for model `{self.name}` since it already exists" ) except KeyError: - pass - self.get_or_create_model_version() + self.get_or_create_model_version() def get_or_create_model(self) -> "ModelResponseModel": """This method should get or create a model from Model Control Plane. @@ -274,11 +273,6 @@ def _create_model_version( Returns: The model version based on configuration. """ - if self.model_version is not None: - from zenml.models.model_models import ModelVersionResponseModel - - return ModelVersionResponseModel(**dict(self.model_version)) - from zenml.client import Client from zenml.models.model_models import ModelVersionRequestModel @@ -311,11 +305,6 @@ def _get_model_version(self) -> "ModelVersionResponseModel": Returns: The model version based on configuration. """ - if self.model_version is not None: - from zenml.models.model_models import ModelVersionResponseModel - - return ModelVersionResponseModel(**dict(self.model_version)) - from zenml.client import Client zenml_client = Client() @@ -368,7 +357,9 @@ def _merge(self, model_config: "ModelConfig") -> None: self.trade_offs = self.trade_offs or model_config.trade_offs self.ethics = self.ethics or model_config.ethics if model_config.tags is not None: - self.tags = (self.tags or []) + model_config.tags + self.tags = list( + {t for t in self.tags or []}.union(set(model_config.tags)) + ) self.delete_new_version_on_failure &= ( model_config.delete_new_version_on_failure diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index 9d8b1713e52..061a506f179 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -381,11 +381,10 @@ def _prepare( output_name: artifact.id for output_name, artifact in cached_outputs.items() } - if model_config: - self._link_cached_artifacts_to_model_version( - model_config=model_config, - step_run=step_run, - ) + self._link_cached_artifacts_to_model_version( + model_config_from_context=model_config, + step_run=step_run, + ) step_run.status = ExecutionStatus.CACHED step_run.end_time = step_run.start_time @@ -393,7 +392,7 @@ def _prepare( def _link_cached_artifacts_to_model_version( self, - model_config: "ModelConfig", + model_config_from_context: Optional["ModelConfig"], step_run: StepRunRequestModel, ) -> None: """Links the output artifacts of the cached step to the model version in Control Plane. @@ -406,7 +405,6 @@ def _link_cached_artifacts_to_model_version( from zenml.steps.base_step import BaseStep from zenml.steps.utils import parse_return_type_annotations - model_version = model_config.get_or_create_model_version() step_instance = BaseStep.load_from_source(self._step.spec.source) output_annotations = parse_return_type_annotations( step_instance.entrypoint @@ -418,23 +416,33 @@ def _link_cached_artifacts_to_model_version( artifact_config_ = annotation.artifact_config.copy() else: artifact_config_ = ArtifactConfig( - artifact_name=output_name_, - model_name=model_version.model.name, - model_version=model_version.name, + artifact_name=output_name_ ) logger.info( f"Linking artifact `{artifact_config_.artifact_name}` to " f"model `{artifact_config_.model_name}` version " f"`{artifact_config_.model_version}` implicitly." ) + if artifact_config_.model_name is None: + model_config = model_config_from_context + else: + from zenml.model.model_config import ModelConfig - artifact_config_._pipeline_name = ( - self._deployment.pipeline_configuration.name - ) - artifact_config_._step_name = self._step_name - artifact_config_.link_to_model( - artifact_uuid=output_, model_config=model_config - ) + model_config = ModelConfig( + name=artifact_config_.model_name, + version=artifact_config_.model_version, + ) + if model_config: + model_config.get_or_create_model_version() + + artifact_config_._pipeline_name = ( + self._deployment.pipeline_configuration.name + ) + artifact_config_._step_name = self._step_name + artifact_config_.link_to_model( + artifact_uuid=output_, + model_config=model_config, + ) def _run_step( self, diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index 00cfae4b007..a40d768108f 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -624,48 +624,55 @@ def _link_artifacts_to_model( from zenml.model.artifact_config import ArtifactConfig try: - mc = get_step_context().model_config + model_config_from_context = get_step_context().model_config except StepContextError: - mc = None + model_config_from_context = None for artifact_name in artifact_ids: artifact_uuid = artifact_ids[artifact_name] - artifact_config = ( + artifact_config_ = ( get_step_context()._get_output(artifact_name).artifact_config ) - if artifact_config is None and mc is not None: - artifact_config = ArtifactConfig( - model_name=mc.name, - model_version=mc.version, - artifact_name=artifact_name, - ) - logger.info( - f"Linking artifact `{artifact_name}` to model `{mc.name}` version `{mc.version}` implicitly." - ) + if artifact_config_ is None: + if model_config_from_context is not None: + artifact_config_ = ArtifactConfig( + model_name=model_config_from_context.name, + model_version=model_config_from_context.version, + artifact_name=artifact_name, + ) + logger.info( + f"Linking artifact `{artifact_name}` to model `{model_config_from_context.name}` version `{model_config_from_context.version}` implicitly." + ) + else: + artifact_config_ = artifact_config_.copy() - if artifact_config is not None: - if mc is None: - if artifact_config.model_name is None: + if artifact_config_ is not None: + if model_config_from_context is None: + if artifact_config_.model_name is None: logger.warning( "No model context found, unable to auto-link artifacts." ) return - else: - from zenml.model.model_config import ModelConfig + if artifact_config_.model_name is not None: + from zenml.model.model_config import ModelConfig - mc = ModelConfig( - name=artifact_config.model_name, - version=artifact_config.model_version, - ) - artifact_config.artifact_name = ( - artifact_config.artifact_name or artifact_name + model_config = ModelConfig( + name=artifact_config_.model_name, + version=artifact_config_.model_version, + ) + else: + model_config = model_config_from_context + + artifact_config_.artifact_name = ( + artifact_config_.artifact_name or artifact_name ) - artifact_config._pipeline_name = ( + artifact_config_._pipeline_name = ( get_step_context().pipeline.name ) - artifact_config._step_name = get_step_context().step_run.name - artifact_config.link_to_model( - artifact_uuid=artifact_uuid, model_config=mc + artifact_config_._step_name = get_step_context().step_run.name + artifact_config_.link_to_model( + artifact_uuid=artifact_uuid, + model_config=model_config, ) def _get_model_versions_from_artifacts( From 42b7b4cfaf3d78ada019476b34375ec1e337b1a1 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 18 Oct 2023 11:23:18 +0200 Subject: [PATCH 28/41] linting --- src/zenml/orchestrators/step_launcher.py | 2 +- src/zenml/orchestrators/step_runner.py | 17 +++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index 061a506f179..e5eae10f17e 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -398,7 +398,7 @@ def _link_cached_artifacts_to_model_version( """Links the output artifacts of the cached step to the model version in Control Plane. Args: - model_config: The model config of the current step. + model_config_from_context: The model config of the current step. step_run: The step to run. """ from zenml.model.artifact_config import ArtifactConfig diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index a40d768108f..1cf9c661700 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -653,15 +653,16 @@ def _link_artifacts_to_model( "No model context found, unable to auto-link artifacts." ) return - if artifact_config_.model_name is not None: - from zenml.model.model_config import ModelConfig - - model_config = ModelConfig( - name=artifact_config_.model_name, - version=artifact_config_.model_version, - ) else: - model_config = model_config_from_context + if artifact_config_.model_name is not None: + from zenml.model.model_config import ModelConfig + + model_config = ModelConfig( + name=artifact_config_.model_name, + version=artifact_config_.model_version, + ) + else: + model_config = model_config_from_context artifact_config_.artifact_name = ( artifact_config_.artifact_name or artifact_name From a6f53a1c3883c12ff2021263d98d9dd5456e2631 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 18 Oct 2023 14:04:01 +0200 Subject: [PATCH 29/41] fix bug --- src/zenml/orchestrators/step_runner.py | 31 +++++++++++++------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index 1cf9c661700..02c8a45459b 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -623,16 +623,17 @@ def _link_artifacts_to_model( """ from zenml.model.artifact_config import ArtifactConfig + context = get_step_context() try: - model_config_from_context = get_step_context().model_config + model_config_from_context = context.model_config except StepContextError: model_config_from_context = None for artifact_name in artifact_ids: artifact_uuid = artifact_ids[artifact_name] - artifact_config_ = ( - get_step_context()._get_output(artifact_name).artifact_config - ) + artifact_config_ = context._get_output( + artifact_name + ).artifact_config if artifact_config_ is None: if model_config_from_context is not None: artifact_config_ = ArtifactConfig( @@ -647,6 +648,7 @@ def _link_artifacts_to_model( artifact_config_ = artifact_config_.copy() if artifact_config_ is not None: + model_config = None if model_config_from_context is None: if artifact_config_.model_name is None: logger.warning( @@ -664,17 +666,16 @@ def _link_artifacts_to_model( else: model_config = model_config_from_context - artifact_config_.artifact_name = ( - artifact_config_.artifact_name or artifact_name - ) - artifact_config_._pipeline_name = ( - get_step_context().pipeline.name - ) - artifact_config_._step_name = get_step_context().step_run.name - artifact_config_.link_to_model( - artifact_uuid=artifact_uuid, - model_config=model_config, - ) + if model_config: + artifact_config_.artifact_name = ( + artifact_config_.artifact_name or artifact_name + ) + artifact_config_._pipeline_name = context.pipeline.name + artifact_config_._step_name = context.step_run.name + artifact_config_.link_to_model( + artifact_uuid=artifact_uuid, + model_config=model_config, + ) def _get_model_versions_from_artifacts( self, From 68cfd1665ecd52d56d60aca0007eef1a40705bbc Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 18 Oct 2023 14:04:18 +0200 Subject: [PATCH 30/41] redesign model_config warnings logic --- src/zenml/model/model_config.py | 131 ++++++--------------- tests/unit/model/test_model_config_init.py | 2 - 2 files changed, 37 insertions(+), 96 deletions(-) diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index 72253e008e5..11c1b63e721 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -39,7 +39,26 @@ class ModelConfig(BaseModel): - """ModelConfig class to pass into pipeline or step to set it into a model context.""" + """ModelConfig class to pass into pipeline or step to set it into a model context. + + name: The name of the model. + license: The license under which the model is created. + description: The description of the model. + audience: The target audience of the model. + use_cases: The use cases of the model. + limitations: The known limitations of the model. + trade_offs: The tradeoffs of the model. + ethics: The ethical implications of the model. + tags: Tags associated with the model. + version: The model version name, number or stage is optional and points model context + to a specific version/stage. If skipped and `create_new_model_version` is False - + latest model version will be used. + version_description: The description of the model version. + create_new_model_version: Whether to create a new model version during execution + save_models_to_registry: Whether to save all ModelArtifacts to Model Registry, + if available in active stack. + delete_new_version_on_failure: Whether to delete failed runs with new versions for later recovery from it. + """ name: str license: Optional[str] @@ -56,69 +75,7 @@ class ModelConfig(BaseModel): save_models_to_registry: bool = True delete_new_version_on_failure: bool = True - model: Optional[Any] = None - model_version: Optional[Any] = None - user_not_yet_warned: bool = True - - def __init__( - self, - name: str, - license: Optional[str] = None, - description: Optional[str] = None, - audience: Optional[str] = None, - use_cases: Optional[str] = None, - limitations: Optional[str] = None, - trade_offs: Optional[str] = None, - ethics: Optional[str] = None, - tags: Optional[List[str]] = None, - version: Optional[Union[ModelStages, int, str]] = None, - version_description: Optional[str] = None, - create_new_model_version: bool = False, - save_models_to_registry: bool = True, - delete_new_version_on_failure: bool = True, - **kwargs: Dict[str, Any], - ): - """ModelConfig class to pass into pipeline or step to set it into a model context. - - Args: - name: The name of the model. - license: The license under which the model is created. - description: The description of the model. - audience: The target audience of the model. - use_cases: The use cases of the model. - limitations: The known limitations of the model. - trade_offs: The tradeoffs of the model. - ethics: The ethical implications of the model. - tags: Tags associated with the model. - version: The model version name, number or stage is optional and points model context - to a specific version/stage. If skipped and `create_new_model_version` is False - - latest model version will be used. - version_description: The description of the model version. - create_new_model_version: Whether to create a new model version during execution - save_models_to_registry: Whether to save all ModelArtifacts to Model Registry, - if available in active stack. - delete_new_version_on_failure: Whether to delete failed runs with new versions for later recovery from it. - kwargs: Other arguments. - """ - super().__init__( - name=name, - license=license, - description=description, - audience=audience, - use_cases=use_cases, - limitations=limitations, - trade_offs=trade_offs, - ethics=ethics, - tags=tags, - version=version, - version_description=version_description, - create_new_model_version=create_new_model_version, - save_models_to_registry=save_models_to_registry, - delete_new_version_on_failure=delete_new_version_on_failure, - model=kwargs.get("model", None), - model_version=kwargs.get("model_version", None), - user_not_yet_warned=kwargs.get("user_not_yet_warned", True), - ) + suppress_class_validation_warning: bool = False class Config: """Config class.""" @@ -141,18 +98,9 @@ def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: create_new_model_version = values.get( "create_new_model_version", False ) - delete_new_version_on_failure = values.get( - "delete_new_version_on_failure", True + suppress_class_validation_warning = values.get( + "suppress_class_validation_warning", False ) - user_not_yet_warned = values.get("user_not_yet_warned", True) - if not delete_new_version_on_failure and not create_new_model_version: - if user_not_yet_warned: - logger.warning( - "Using `delete_new_version_on_failure=False` and `create_new_model_version=False` has no effect." - "Setting `delete_new_version_on_failure` to `True`." - ) - values["delete_new_version_on_failure"] = True - version = values.get("version", None) if create_new_model_version: @@ -174,7 +122,7 @@ def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: if str(version).isnumeric(): raise ValueError(misuse_message.format(set="a numeric value")) if version is None: - if user_not_yet_warned: + if not suppress_class_validation_warning: logger.info( "Creation of new model version was requested, but no version name was explicitly provided. " f"Setting `version` to `{RUNNING_MODEL_VERSION}`." @@ -182,16 +130,16 @@ def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["version"] = RUNNING_MODEL_VERSION if ( version in [stage.value for stage in ModelStages] - and user_not_yet_warned + and not suppress_class_validation_warning ): logger.info( f"`version` `{version}` matches one of the possible `ModelStages` and will be fetched using stage." ) - if str(version).isnumeric() and user_not_yet_warned: + if str(version).isnumeric() and not suppress_class_validation_warning: logger.info( f"`version` `{version}` is numeric and will be fetched using version number." ) - values["user_not_yet_warned"] = False + values["suppress_class_validation_warning"] = True return values def _validate_config_in_runtime(self) -> None: @@ -227,17 +175,12 @@ def get_or_create_model(self) -> "ModelResponseModel": Returns: The model based on configuration. """ - from zenml.models.model_models import ModelResponseModel - - if self.model is not None: - return ModelResponseModel(**dict(self.model)) - from zenml.client import Client from zenml.models.model_models import ModelRequestModel zenml_client = Client() try: - self.model = zenml_client.get_model(model_name_or_id=self.name) + model = zenml_client.get_model(model_name_or_id=self.name) except KeyError: model_request = ModelRequestModel( name=self.name, @@ -254,13 +197,13 @@ def get_or_create_model(self) -> "ModelResponseModel": ) model_request = ModelRequestModel.parse_obj(model_request) try: - self.model = zenml_client.create_model(model=model_request) + model = zenml_client.create_model(model=model_request) logger.info(f"New model `{self.name}` was created implicitly.") except EntityExistsError: # this is backup logic, if model was created somehow in between get and create calls - self.model = zenml_client.get_model(model_name_or_id=self.name) + model = zenml_client.get_model(model_name_or_id=self.name) - return self.model + return model def _create_model_version( self, model: "ModelResponseModel" @@ -290,14 +233,14 @@ def _create_model_version( model_name_or_id=self.name, model_version_name_or_number_or_id=self.version, ) - self.model_version = mv + model_version = mv except KeyError: - self.model_version = zenml_client.create_model_version( + model_version = zenml_client.create_model_version( model_version=mv_request ) logger.info(f"New model version `{self.version}` was created.") - return self.model_version + return model_version def _get_model_version(self) -> "ModelVersionResponseModel": """This method gets a model version from Model Control Plane. @@ -310,17 +253,17 @@ def _get_model_version(self) -> "ModelVersionResponseModel": zenml_client = Client() if self.version is None: # raise if not found - self.model_version = zenml_client.get_model_version( + model_version = zenml_client.get_model_version( model_name_or_id=self.name ) else: # by version name or stage or number # raise if not found - self.model_version = zenml_client.get_model_version( + model_version = zenml_client.get_model_version( model_name_or_id=self.name, model_version_name_or_number_or_id=self.version, ) - return self.model_version + return model_version def get_or_create_model_version(self) -> "ModelVersionResponseModel": """This method should get or create a model and a model version from Model Control Plane. diff --git a/tests/unit/model/test_model_config_init.py b/tests/unit/model/test_model_config_init.py index ba4b6943c75..adb68efa578 100644 --- a/tests/unit/model/test_model_config_init.py +++ b/tests/unit/model/test_model_config_init.py @@ -9,14 +9,12 @@ @pytest.mark.parametrize( "version_name,create_new_model_version,delete_new_version_on_failure,logger", [ - [None, False, False, "warning"], [None, True, False, "info"], ["staging", False, False, "info"], ["1", False, False, "info"], [1, False, False, "info"], ], ids=[ - "No new version, but recovery", "Default running version", "Pick model by text stage", "Pick model by text version number", From ada92758c712e8c829ec41fca281fdec74dc42c1 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 18 Oct 2023 14:22:42 +0200 Subject: [PATCH 31/41] renaming --- src/zenml/cli/model.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/zenml/cli/model.py b/src/zenml/cli/model.py index 862acb42bd6..73ebe0dc0a1 100644 --- a/src/zenml/cli/model.py +++ b/src/zenml/cli/model.py @@ -451,7 +451,7 @@ def update_model_version( def _print_artifacts_links_generic( model_name_or_id: str, model_version_name_or_number_or_id: str, - only_artifacts: bool = False, + only_artifact_objects: bool = False, only_deployments: bool = False, only_model_objects: bool = False, **kwargs: Any, @@ -461,7 +461,7 @@ def _print_artifacts_links_generic( Args: model_name_or_id: The ID or name of the model containing version. model_version_name_or_number_or_id: The name, number or ID of the model version. - only_artifacts: If set, only print artifacts. + only_artifact_objects: If set, only print artifacts. only_deployments: If set, only print deployments. only_model_objects: If set, only print model objects. **kwargs: Keyword arguments to filter models. @@ -474,14 +474,14 @@ def _print_artifacts_links_generic( ) type_ = ( "artifacts" - if only_artifacts + if only_artifact_objects else "deployments" if only_deployments else "model objects" ) if ( - (only_artifacts and not model_version.artifact_object_ids) + (only_artifact_objects and not model_version.artifact_object_ids) or (only_deployments and not model_version.deployment_ids) or (only_model_objects and not model_version.model_object_ids) ): @@ -496,7 +496,7 @@ def _print_artifacts_links_generic( ModelVersionArtifactFilterModel( model_id=model_version.model.id, model_version_id=model_version.id, - only_artifacts=only_artifacts, + only_artifacts=only_artifact_objects, only_deployments=only_deployments, only_model_objects=only_model_objects, **kwargs, @@ -539,7 +539,7 @@ def list_model_version_artifacts( _print_artifacts_links_generic( model_name_or_id=model_name_or_id, model_version_name_or_number_or_id=model_version_name_or_number_or_id, - only_artifacts=True, + only_artifact_objects=True, **kwargs, ) From 76c39de8905975f657fd16a5d4c6f3d5f107a931 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 18 Oct 2023 14:59:04 +0200 Subject: [PATCH 32/41] fix bug --- src/zenml/orchestrators/step_runner.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index 02c8a45459b..0338cc47263 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -655,16 +655,16 @@ def _link_artifacts_to_model( "No model context found, unable to auto-link artifacts." ) return - else: - if artifact_config_.model_name is not None: - from zenml.model.model_config import ModelConfig - model_config = ModelConfig( - name=artifact_config_.model_name, - version=artifact_config_.model_version, - ) - else: - model_config = model_config_from_context + if artifact_config_.model_name is not None: + from zenml.model.model_config import ModelConfig + + model_config = ModelConfig( + name=artifact_config_.model_name, + version=artifact_config_.model_version, + ) + else: + model_config = model_config_from_context if model_config: artifact_config_.artifact_name = ( From 287b80ed4d1a54e1cf8568f2c33d459651430907 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 18 Oct 2023 14:59:15 +0200 Subject: [PATCH 33/41] remove outdated test --- .../functional/model/test_model_config.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/integration/functional/model/test_model_config.py b/tests/integration/functional/model/test_model_config.py index 29709f1a162..3a38052af91 100644 --- a/tests/integration/functional/model/test_model_config.py +++ b/tests/integration/functional/model/test_model_config.py @@ -171,19 +171,6 @@ def test_init_create_new_version_with_version_fails(self): assert mc.create_new_model_version assert mc.version == RUNNING_MODEL_VERSION - def test_init_recovery_without_create_new_version_warns(self): - """Test that use of `recovery` warn on `create_new_model_version` set to False.""" - with mock.patch("zenml.model.model_config.logger.warning") as logger: - ModelConfig(name=MODEL_NAME, delete_new_version_on_failure=False) - logger.assert_called_once() - with mock.patch("zenml.model.model_config.logger.warning") as logger: - ModelConfig( - name=MODEL_NAME, - delete_new_version_on_failure=False, - create_new_model_version=True, - ) - logger.assert_not_called() - def test_init_stage_logic(self): """Test that if version is set to string contained in ModelStages user is informed about it.""" with mock.patch("zenml.model.model_config.logger.info") as logger: From 6c88495b2fbc1e6afbc602b66ab30222b4fdb1c0 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 18 Oct 2023 14:59:36 +0200 Subject: [PATCH 34/41] improve validator warnings --- src/zenml/config/compiler.py | 31 +++++++++++++++++++---------- src/zenml/model/model_config.py | 26 ++++++++++++++++-------- src/zenml/new/pipelines/pipeline.py | 30 +++++++++++++++++----------- 3 files changed, 56 insertions(+), 31 deletions(-) diff --git a/src/zenml/config/compiler.py b/src/zenml/config/compiler.py index 28cdc68707d..c099351b7a7 100644 --- a/src/zenml/config/compiler.py +++ b/src/zenml/config/compiler.py @@ -84,6 +84,8 @@ def compile( Returns: The compiled pipeline deployment and spec """ + from zenml.model.model_config import ModelConfig + logger.debug("Compiling pipeline `%s`.", pipeline.name) # Copy the pipeline before we apply any run-level configurations, so # we don't mess with the pipeline object/step objects in any way @@ -101,7 +103,8 @@ def compile( stack=stack, ) with pipeline.__suppress_configure_warnings__(): - pipeline.configure(settings=pipeline_settings, merge=False) + with ModelConfig.__suppress_validation_warnings__(): + pipeline.configure(settings=pipeline_settings, merge=False) settings_to_passdown = { key: settings @@ -196,16 +199,19 @@ def _apply_run_configuration( KeyError: If the run configuration contains options for a non-existent step. """ + from zenml.model.model_config import ModelConfig + with pipeline.__suppress_configure_warnings__(): - pipeline.configure( - enable_cache=config.enable_cache, - enable_artifact_metadata=config.enable_artifact_metadata, - enable_artifact_visualization=config.enable_artifact_visualization, - enable_step_logs=config.enable_step_logs, - settings=config.settings, - extra=config.extra, - model_config=config.model_config, - ) + with ModelConfig.__suppress_validation_warnings__(): + pipeline.configure( + enable_cache=config.enable_cache, + enable_artifact_metadata=config.enable_artifact_metadata, + enable_artifact_visualization=config.enable_artifact_visualization, + enable_step_logs=config.enable_step_logs, + settings=config.settings, + extra=config.extra, + model_config=config.model_config, + ) for invocation_id in config.steps: if invocation_id not in pipeline.invocations: @@ -246,6 +252,8 @@ def _apply_stack_default_settings( pipeline: The pipeline to which to apply the default settings. stack: The stack containing potential default settings. """ + from zenml.model.model_config import ModelConfig + pipeline_settings = pipeline.configuration.settings for component in stack.components.values(): @@ -266,7 +274,8 @@ def _apply_stack_default_settings( pipeline_settings[settings_key] = default_settings with pipeline.__suppress_configure_warnings__(): - pipeline.configure(settings=pipeline_settings, merge=False) + with ModelConfig.__suppress_validation_warnings__(): + pipeline.configure(settings=pipeline_settings, merge=False) def _get_default_settings( self, diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index 11c1b63e721..242ca260e12 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -13,10 +13,13 @@ # permissions and limitations under the License. """ModelConfig user facing interface to pass into pipeline or step.""" +from contextlib import contextmanager from typing import ( TYPE_CHECKING, Any, + ClassVar, Dict, + Iterator, List, Optional, Union, @@ -75,13 +78,21 @@ class ModelConfig(BaseModel): save_models_to_registry: bool = True delete_new_version_on_failure: bool = True - suppress_class_validation_warning: bool = False + __SUPPRESS_VALIDATION_WARNINGS__: ClassVar[bool] = False class Config: """Config class.""" smart_union = True + @classmethod + @contextmanager + def __suppress_validation_warnings__(cls) -> Iterator[Any]: + """Suppress validation warning.""" + cls.__SUPPRESS_VALIDATION_WARNINGS__ = True + yield + cls.__SUPPRESS_VALIDATION_WARNINGS__ = False + @root_validator def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Validate all in one. @@ -98,9 +109,6 @@ def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: create_new_model_version = values.get( "create_new_model_version", False ) - suppress_class_validation_warning = values.get( - "suppress_class_validation_warning", False - ) version = values.get("version", None) if create_new_model_version: @@ -122,7 +130,7 @@ def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: if str(version).isnumeric(): raise ValueError(misuse_message.format(set="a numeric value")) if version is None: - if not suppress_class_validation_warning: + if not cls.__SUPPRESS_VALIDATION_WARNINGS__: logger.info( "Creation of new model version was requested, but no version name was explicitly provided. " f"Setting `version` to `{RUNNING_MODEL_VERSION}`." @@ -130,16 +138,18 @@ def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["version"] = RUNNING_MODEL_VERSION if ( version in [stage.value for stage in ModelStages] - and not suppress_class_validation_warning + and not cls.__SUPPRESS_VALIDATION_WARNINGS__ ): logger.info( f"`version` `{version}` matches one of the possible `ModelStages` and will be fetched using stage." ) - if str(version).isnumeric() and not suppress_class_validation_warning: + if ( + str(version).isnumeric() + and not cls.__SUPPRESS_VALIDATION_WARNINGS__ + ): logger.info( f"`version` `{version}` is numeric and will be fetched using version number." ) - values["suppress_class_validation_warning"] = True return values def _validate_config_in_runtime(self) -> None: diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index 85378b35730..7eb18bc9cbf 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -148,6 +148,8 @@ def __init__( function (e.g. `module.my_function`). model_config: Model(Version) configuration for this step as `ModelConfig` instance. """ + from zenml.model.model_config import ModelConfig + self._invocations: Dict[str, StepInvocation] = {} self._run_args: Dict[str, Any] = {} @@ -156,17 +158,18 @@ def __init__( ) self._from_config_file: Dict[str, Any] = {} with self.__suppress_configure_warnings__(): - self.configure( - enable_cache=enable_cache, - enable_artifact_metadata=enable_artifact_metadata, - enable_artifact_visualization=enable_artifact_visualization, - enable_step_logs=enable_step_logs, - settings=settings, - extra=extra, - on_failure=on_failure, - on_success=on_success, - model_config=model_config, - ) + with ModelConfig.__suppress_validation_warnings__(): + self.configure( + enable_cache=enable_cache, + enable_artifact_metadata=enable_artifact_metadata, + enable_artifact_visualization=enable_artifact_visualization, + enable_step_logs=enable_step_logs, + settings=settings, + extra=extra, + on_failure=on_failure, + on_success=on_success, + model_config=model_config, + ) self.entrypoint = entrypoint self._parameters: Dict[str, Any] = {} @@ -1432,6 +1435,8 @@ def with_options( Returns: The copied pipeline instance. """ + from zenml.model.model_config import ModelConfig + pipeline_copy = self.copy() pipeline_copy._parse_config_file( @@ -1443,7 +1448,8 @@ def with_options( ) with pipeline_copy.__suppress_configure_warnings__(): - pipeline_copy.configure(**pipeline_copy._from_config_file) + with ModelConfig.__suppress_validation_warnings__(): + pipeline_copy.configure(**pipeline_copy._from_config_file) run_args = dict_utils.remove_none_values( { From 4f52d76a1ef100aab634dba95a1f6a7dab3335d2 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 18 Oct 2023 15:03:28 +0200 Subject: [PATCH 35/41] improve validator warnings --- src/zenml/new/pipelines/pipeline.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index 7eb18bc9cbf..eb94f566c94 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -1128,6 +1128,7 @@ def _compile( """ # Activating the built-in integrations to load all materializers from zenml.integrations.registry import integration_registry + from zenml.model.model_config import ModelConfig integration_registry.activate_integrations() @@ -1142,7 +1143,8 @@ def _compile( update = PipelineRunConfiguration.parse_obj(new_values) # Update with the values in code so they take precedence - run_config = pydantic_utils.update_model(run_config, update=update) + with ModelConfig.__suppress_validation_warnings__(): + run_config = pydantic_utils.update_model(run_config, update=update) deployment, pipeline_spec = Compiler().compile( pipeline=self, From 098207fada64a11549749cbbc98e3b5edf63dde2 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 18 Oct 2023 15:09:33 +0200 Subject: [PATCH 36/41] improve validator warnings --- src/zenml/config/compiler.py | 28 ++++++++++--------------- src/zenml/model/model_config.py | 27 ++++++++---------------- src/zenml/new/pipelines/pipeline.py | 32 ++++++++++++----------------- 3 files changed, 33 insertions(+), 54 deletions(-) diff --git a/src/zenml/config/compiler.py b/src/zenml/config/compiler.py index c099351b7a7..ea495d7b0b7 100644 --- a/src/zenml/config/compiler.py +++ b/src/zenml/config/compiler.py @@ -84,7 +84,6 @@ def compile( Returns: The compiled pipeline deployment and spec """ - from zenml.model.model_config import ModelConfig logger.debug("Compiling pipeline `%s`.", pipeline.name) # Copy the pipeline before we apply any run-level configurations, so @@ -103,8 +102,7 @@ def compile( stack=stack, ) with pipeline.__suppress_configure_warnings__(): - with ModelConfig.__suppress_validation_warnings__(): - pipeline.configure(settings=pipeline_settings, merge=False) + pipeline.configure(settings=pipeline_settings, merge=False) settings_to_passdown = { key: settings @@ -199,19 +197,17 @@ def _apply_run_configuration( KeyError: If the run configuration contains options for a non-existent step. """ - from zenml.model.model_config import ModelConfig with pipeline.__suppress_configure_warnings__(): - with ModelConfig.__suppress_validation_warnings__(): - pipeline.configure( - enable_cache=config.enable_cache, - enable_artifact_metadata=config.enable_artifact_metadata, - enable_artifact_visualization=config.enable_artifact_visualization, - enable_step_logs=config.enable_step_logs, - settings=config.settings, - extra=config.extra, - model_config=config.model_config, - ) + pipeline.configure( + enable_cache=config.enable_cache, + enable_artifact_metadata=config.enable_artifact_metadata, + enable_artifact_visualization=config.enable_artifact_visualization, + enable_step_logs=config.enable_step_logs, + settings=config.settings, + extra=config.extra, + model_config=config.model_config, + ) for invocation_id in config.steps: if invocation_id not in pipeline.invocations: @@ -252,7 +248,6 @@ def _apply_stack_default_settings( pipeline: The pipeline to which to apply the default settings. stack: The stack containing potential default settings. """ - from zenml.model.model_config import ModelConfig pipeline_settings = pipeline.configuration.settings @@ -274,8 +269,7 @@ def _apply_stack_default_settings( pipeline_settings[settings_key] = default_settings with pipeline.__suppress_configure_warnings__(): - with ModelConfig.__suppress_validation_warnings__(): - pipeline.configure(settings=pipeline_settings, merge=False) + pipeline.configure(settings=pipeline_settings, merge=False) def _get_default_settings( self, diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index 242ca260e12..3b0bf0b4a78 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -13,13 +13,10 @@ # permissions and limitations under the License. """ModelConfig user facing interface to pass into pipeline or step.""" -from contextlib import contextmanager from typing import ( TYPE_CHECKING, Any, - ClassVar, Dict, - Iterator, List, Optional, Union, @@ -78,21 +75,13 @@ class ModelConfig(BaseModel): save_models_to_registry: bool = True delete_new_version_on_failure: bool = True - __SUPPRESS_VALIDATION_WARNINGS__: ClassVar[bool] = False + suppress_class_validation_warnings: bool = False class Config: """Config class.""" smart_union = True - @classmethod - @contextmanager - def __suppress_validation_warnings__(cls) -> Iterator[Any]: - """Suppress validation warning.""" - cls.__SUPPRESS_VALIDATION_WARNINGS__ = True - yield - cls.__SUPPRESS_VALIDATION_WARNINGS__ = False - @root_validator def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Validate all in one. @@ -109,6 +98,9 @@ def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: create_new_model_version = values.get( "create_new_model_version", False ) + suppress_class_validation_warnings = values.get( + "suppress_class_validation_warnings", False + ) version = values.get("version", None) if create_new_model_version: @@ -130,7 +122,7 @@ def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: if str(version).isnumeric(): raise ValueError(misuse_message.format(set="a numeric value")) if version is None: - if not cls.__SUPPRESS_VALIDATION_WARNINGS__: + if not suppress_class_validation_warnings: logger.info( "Creation of new model version was requested, but no version name was explicitly provided. " f"Setting `version` to `{RUNNING_MODEL_VERSION}`." @@ -138,18 +130,17 @@ def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["version"] = RUNNING_MODEL_VERSION if ( version in [stage.value for stage in ModelStages] - and not cls.__SUPPRESS_VALIDATION_WARNINGS__ + and not suppress_class_validation_warnings ): logger.info( f"`version` `{version}` matches one of the possible `ModelStages` and will be fetched using stage." ) - if ( - str(version).isnumeric() - and not cls.__SUPPRESS_VALIDATION_WARNINGS__ - ): + if str(version).isnumeric() and not suppress_class_validation_warnings: logger.info( f"`version` `{version}` is numeric and will be fetched using version number." ) + values["suppress_class_validation_warnings"] = True + cls.__fields_set__.add("suppress_class_validation_warnings") return values def _validate_config_in_runtime(self) -> None: diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index eb94f566c94..4efb05ec38d 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -148,7 +148,6 @@ def __init__( function (e.g. `module.my_function`). model_config: Model(Version) configuration for this step as `ModelConfig` instance. """ - from zenml.model.model_config import ModelConfig self._invocations: Dict[str, StepInvocation] = {} self._run_args: Dict[str, Any] = {} @@ -158,18 +157,17 @@ def __init__( ) self._from_config_file: Dict[str, Any] = {} with self.__suppress_configure_warnings__(): - with ModelConfig.__suppress_validation_warnings__(): - self.configure( - enable_cache=enable_cache, - enable_artifact_metadata=enable_artifact_metadata, - enable_artifact_visualization=enable_artifact_visualization, - enable_step_logs=enable_step_logs, - settings=settings, - extra=extra, - on_failure=on_failure, - on_success=on_success, - model_config=model_config, - ) + self.configure( + enable_cache=enable_cache, + enable_artifact_metadata=enable_artifact_metadata, + enable_artifact_visualization=enable_artifact_visualization, + enable_step_logs=enable_step_logs, + settings=settings, + extra=extra, + on_failure=on_failure, + on_success=on_success, + model_config=model_config, + ) self.entrypoint = entrypoint self._parameters: Dict[str, Any] = {} @@ -1128,7 +1126,6 @@ def _compile( """ # Activating the built-in integrations to load all materializers from zenml.integrations.registry import integration_registry - from zenml.model.model_config import ModelConfig integration_registry.activate_integrations() @@ -1143,8 +1140,7 @@ def _compile( update = PipelineRunConfiguration.parse_obj(new_values) # Update with the values in code so they take precedence - with ModelConfig.__suppress_validation_warnings__(): - run_config = pydantic_utils.update_model(run_config, update=update) + run_config = pydantic_utils.update_model(run_config, update=update) deployment, pipeline_spec = Compiler().compile( pipeline=self, @@ -1437,7 +1433,6 @@ def with_options( Returns: The copied pipeline instance. """ - from zenml.model.model_config import ModelConfig pipeline_copy = self.copy() @@ -1450,8 +1445,7 @@ def with_options( ) with pipeline_copy.__suppress_configure_warnings__(): - with ModelConfig.__suppress_validation_warnings__(): - pipeline_copy.configure(**pipeline_copy._from_config_file) + pipeline_copy.configure(**pipeline_copy._from_config_file) run_args = dict_utils.remove_none_values( { From 7e002310f66a11175f1926498ac9e0f41b4358a7 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 18 Oct 2023 15:48:23 +0200 Subject: [PATCH 37/41] improve validator warnings --- src/zenml/model/model_config.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index 3b0bf0b4a78..4fc00954e35 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -82,7 +82,7 @@ class Config: smart_union = True - @root_validator + @root_validator(pre=True) def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Validate all in one. @@ -140,7 +140,6 @@ def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: f"`version` `{version}` is numeric and will be fetched using version number." ) values["suppress_class_validation_warnings"] = True - cls.__fields_set__.add("suppress_class_validation_warnings") return values def _validate_config_in_runtime(self) -> None: From f0ac4bff4f66e7609b833b1f7bf4c91c63bf6d3a Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 18 Oct 2023 16:14:11 +0200 Subject: [PATCH 38/41] resolve alembic branches --- .../4f66af55fbb9_rename_model_config_model_to_model_.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/zen_stores/migrations/versions/4f66af55fbb9_rename_model_config_model_to_model_.py b/src/zenml/zen_stores/migrations/versions/4f66af55fbb9_rename_model_config_model_to_model_.py index 4aeabb23532..13e3b296a2c 100644 --- a/src/zenml/zen_stores/migrations/versions/4f66af55fbb9_rename_model_config_model_to_model_.py +++ b/src/zenml/zen_stores/migrations/versions/4f66af55fbb9_rename_model_config_model_to_model_.py @@ -11,7 +11,7 @@ # revision identifiers, used by Alembic. revision = "4f66af55fbb9" -down_revision = "0.45.2" +down_revision = "729263e47b55" branch_labels = None depends_on = None From cc7064315cdd44fd3a30465e65f3227d687bf98b Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 18 Oct 2023 16:23:36 +0200 Subject: [PATCH 39/41] remove excessive init --- src/zenml/orchestrators/step_runner.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index 0338cc47263..d9385ab3a6b 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -637,8 +637,6 @@ def _link_artifacts_to_model( if artifact_config_ is None: if model_config_from_context is not None: artifact_config_ = ArtifactConfig( - model_name=model_config_from_context.name, - model_version=model_config_from_context.version, artifact_name=artifact_name, ) logger.info( From 1a3fc4f1e48cc802bf14cdade7426f260a0a8bee Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 18 Oct 2023 16:27:05 +0200 Subject: [PATCH 40/41] fix runs cli --- src/zenml/cli/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/zenml/cli/model.py b/src/zenml/cli/model.py index 73ebe0dc0a1..dc5aa2a8e78 100644 --- a/src/zenml/cli/model.py +++ b/src/zenml/cli/model.py @@ -622,7 +622,9 @@ def list_model_version_pipeline_runs( """ model_version = Client().get_model_version( model_name_or_id=model_name_or_id, - model_version_name_or_number_or_id=model_version_name_or_number_or_id, + model_version_name_or_number_or_id=None + if model_version_name_or_number_or_id == "0" + else model_version_name_or_number_or_id, ) if not model_version.pipeline_run_ids: From 6372de0b951eee6a14f072c60231918f9c96569c Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 18 Oct 2023 16:41:03 +0200 Subject: [PATCH 41/41] lint --- src/zenml/config/compiler.py | 3 --- src/zenml/new/pipelines/pipeline.py | 2 -- 2 files changed, 5 deletions(-) diff --git a/src/zenml/config/compiler.py b/src/zenml/config/compiler.py index ea495d7b0b7..28cdc68707d 100644 --- a/src/zenml/config/compiler.py +++ b/src/zenml/config/compiler.py @@ -84,7 +84,6 @@ def compile( Returns: The compiled pipeline deployment and spec """ - logger.debug("Compiling pipeline `%s`.", pipeline.name) # Copy the pipeline before we apply any run-level configurations, so # we don't mess with the pipeline object/step objects in any way @@ -197,7 +196,6 @@ def _apply_run_configuration( KeyError: If the run configuration contains options for a non-existent step. """ - with pipeline.__suppress_configure_warnings__(): pipeline.configure( enable_cache=config.enable_cache, @@ -248,7 +246,6 @@ def _apply_stack_default_settings( pipeline: The pipeline to which to apply the default settings. stack: The stack containing potential default settings. """ - pipeline_settings = pipeline.configuration.settings for component in stack.components.values(): diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index 4efb05ec38d..85378b35730 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -148,7 +148,6 @@ def __init__( function (e.g. `module.my_function`). model_config: Model(Version) configuration for this step as `ModelConfig` instance. """ - self._invocations: Dict[str, StepInvocation] = {} self._run_args: Dict[str, Any] = {} @@ -1433,7 +1432,6 @@ def with_options( Returns: The copied pipeline instance. """ - pipeline_copy = self.copy() pipeline_copy._parse_config_file(