Skip to content

Commit

Permalink
add ModelConfig to step deco
Browse files Browse the repository at this point in the history
  • Loading branch information
avishniakov committed Sep 18, 2023
1 parent 2b3aec5 commit ba9fd6a
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 6 deletions.
2 changes: 2 additions & 0 deletions src/zenml/config/step_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from zenml.config.source import Source, convert_source_validator
from zenml.config.strict_base_model import StrictBaseModel
from zenml.logger import get_logger
from zenml.model import ModelConfig
from zenml.utils import deprecation_utils

if TYPE_CHECKING:
Expand Down Expand Up @@ -131,6 +132,7 @@ class StepConfigurationUpdate(StrictBaseModel):
extra: Dict[str, Any] = {}
failure_hook_source: Optional[Source] = None
success_hook_source: Optional[Source] = None
model: Optional["ModelConfig"] = None

outputs: Mapping[str, PartialArtifactConfiguration] = {}

Expand Down
27 changes: 22 additions & 5 deletions src/zenml/model/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,15 @@

from pydantic import Field, validator

from zenml.client import Client
from zenml.logger import get_logger
from zenml.model.model_stages import ModelStages
from zenml.models import (
from zenml.models.model_models import (
ModelBaseModel,
ModelRequestModel,
ModelResponseModel,
ModelVersionRequestModel,
ModelVersionResponseModel,
)

logger = get_logger(__name__)


class ModelConfig(ModelBaseModel):
"""ModelConfig class to pass into pipeline or step to set it into a model context.
Expand Down Expand Up @@ -70,6 +66,10 @@ def _validate_recovery(
) -> bool:
if recovery:
if not values.get("create_new_model_version", False):
from zenml.logger import get_logger

logger = get_logger(__name__)

logger.warning(
"Using `recovery` flag without `create_new_model_version=True` makes no effect"
)
Expand All @@ -80,6 +80,10 @@ def _validate_version(
cls, version: Union[str, ModelStages]
) -> Union[str, ModelStages]:
if isinstance(version, str) and version in ModelStages._members():
from zenml.logger import get_logger

logger = get_logger(__name__)

logger.warning(
f"Version `{version}` matches one of the possible `ModelStages`, if you want to fetch "
"model version by its' stage make sure to pass in instance of `ModelStages`."
Expand All @@ -99,6 +103,8 @@ def _get_request_params(
],
**kwargs: Any,
) -> Dict[str, Any]:
from zenml.client import Client

zenml_client = Client()
request_params = {
k: v
Expand All @@ -122,12 +128,18 @@ def get_or_create_model(self) -> ModelResponseModel:
Returns:
The model based on configuration.
"""
from zenml.client import Client

zenml_client = Client()
try:
model = zenml_client.zen_store.get_model(
model_name_or_id=self.name
)
except KeyError:
from zenml.logger import get_logger

logger = get_logger(__name__)

model_request = ModelRequestModel.parse_obj(
self._get_request_params(ModelRequestModel)
)
Expand All @@ -149,6 +161,11 @@ def _get_or_create_model_version(
Returns:
The model version based on configuration.
"""
from zenml.client import Client
from zenml.logger import get_logger

logger = get_logger(__name__)

zenml_client = Client()
# if specific version requested
if not self.create_new_model_version:
Expand Down
5 changes: 5 additions & 0 deletions src/zenml/new/steps/step_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from zenml.config.base_settings import SettingsOrDict
from zenml.config.source import Source
from zenml.materializers.base_materializer import BaseMaterializer
from zenml.model import ModelConfig
from zenml.steps import BaseStep

MaterializerClassOrSource = Union[str, Source, Type[BaseMaterializer]]
Expand Down Expand Up @@ -66,6 +67,7 @@ def step(
extra: Optional[Dict[str, Any]] = None,
on_failure: Optional["HookSpecification"] = None,
on_success: Optional["HookSpecification"] = None,
model: Optional["ModelConfig"] = None,
) -> Callable[["F"], "BaseStep"]:
...

Expand All @@ -85,6 +87,7 @@ def step(
extra: Optional[Dict[str, Any]] = None,
on_failure: Optional["HookSpecification"] = None,
on_success: Optional["HookSpecification"] = None,
model: Optional["ModelConfig"] = None,
) -> Union["BaseStep", Callable[["F"], "BaseStep"]]:
"""Decorator to create a ZenML step.
Expand Down Expand Up @@ -114,6 +117,7 @@ def step(
on_success: Callback function in event of success of the step. Can be a
function with no arguments, or a source path to such a function
(e.g. `module.my_function`).
model: Model(Version) configuration for this step as `ModelConfig` instance.
Returns:
The step instance.
Expand Down Expand Up @@ -145,6 +149,7 @@ def inner_decorator(func: "F") -> "BaseStep":
extra=extra,
on_failure=on_failure,
on_success=on_success,
model=model,
)

return step_instance
Expand Down
17 changes: 16 additions & 1 deletion src/zenml/steps/base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@

if TYPE_CHECKING:
from zenml.config.base_settings import SettingsOrDict
from zenml.model import ModelConfig

ParametersOrDict = Union["BaseParameters", Dict[str, Any]]
MaterializerClassOrSource = Union[str, Source, Type["BaseMaterializer"]]
Expand Down Expand Up @@ -133,6 +134,7 @@ def __init__(
extra: Optional[Dict[str, Any]] = None,
on_failure: Optional["HookSpecification"] = None,
on_success: Optional["HookSpecification"] = None,
model: Optional["ModelConfig"] = None,
**kwargs: Any,
) -> None:
"""Initializes a step.
Expand Down Expand Up @@ -161,6 +163,7 @@ def __init__(
on_success: Callback function in event of success of the step. Can
be a function with no arguments, or a source path to such a
function (e.g. `module.my_function`).
model: Model(Version) configuration for this step as `ModelConfig` instance.
**kwargs: Keyword arguments passed to the step.
"""
self._upstream_steps: Set["BaseStep"] = set()
Expand Down Expand Up @@ -200,12 +203,17 @@ def __init__(
if enable_artifact_visualization is not False
else "disabled",
)

logger.debug(
"Step '%s': logs %s.",
name,
"enabled" if enable_step_logs is not False else "disabled",
)
if model is not None:
logger.debug(
"Step '%s': Is in Model context %s.",
name,
{"model": model.name, "version": model.version},
)

self._configuration = PartialStepConfiguration(
name=name,
Expand All @@ -223,6 +231,7 @@ def __init__(
extra=extra,
on_failure=on_failure,
on_success=on_success,
model=model,
)
self._verify_and_apply_init_params(*args, **kwargs)

Expand Down Expand Up @@ -631,6 +640,7 @@ def configure(
extra: Optional[Dict[str, Any]] = None,
on_failure: Optional["HookSpecification"] = None,
on_success: Optional["HookSpecification"] = None,
model: Optional["ModelConfig"] = None,
merge: bool = True,
) -> T:
"""Configures the step.
Expand Down Expand Up @@ -668,6 +678,7 @@ def configure(
on_success: Callback function in event of success of the step. Can
be a function with no arguments, or a source path to such a
function (e.g. `module.my_function`).
model: Model(Version) configuration for this step as `ModelConfig` instance.
merge: If `True`, will merge the given dictionary configurations
like `parameters` and `settings` with existing
configurations. If `False` the given configurations will
Expand Down Expand Up @@ -740,6 +751,7 @@ def _convert_to_tuple(value: Any) -> Tuple[Source, ...]:
"extra": extra,
"failure_hook_source": failure_hook_source,
"success_hook_source": success_hook_source,
"model": model,
}
)
config = StepConfigurationUpdate(**values)
Expand All @@ -762,6 +774,7 @@ def with_options(
extra: Optional[Dict[str, Any]] = None,
on_failure: Optional["HookSpecification"] = None,
on_success: Optional["HookSpecification"] = None,
model: Optional["ModelConfig"] = None,
merge: bool = True,
) -> "BaseStep":
"""Copies the step and applies the given configurations.
Expand All @@ -788,6 +801,7 @@ def with_options(
on_success: Callback function in event of success of the step. Can
be a function with no arguments, or a source path to such a
function (e.g. `module.my_function`).
model: Model(Version) configuration for this step as `ModelConfig` instance.
merge: If `True`, will merge the given dictionary configurations
like `parameters` and `settings` with existing
configurations. If `False` the given configurations will
Expand All @@ -811,6 +825,7 @@ def with_options(
extra=extra,
on_failure=on_failure,
on_success=on_success,
model=model,
merge=merge,
)
return step_copy
Expand Down
4 changes: 4 additions & 0 deletions src/zenml/steps/step_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from zenml.config.base_settings import SettingsOrDict
from zenml.config.source import Source
from zenml.materializers.base_materializer import BaseMaterializer
from zenml.model import ModelConfig

MaterializerClassOrSource = Union[str, "Source", Type["BaseMaterializer"]]
HookSpecification = Union[str, "Source", FunctionType]
Expand Down Expand Up @@ -106,6 +107,7 @@ def step(
extra: Optional[Dict[str, Any]] = None,
on_failure: Optional["HookSpecification"] = None,
on_success: Optional["HookSpecification"] = None,
model: Optional["ModelConfig"] = None,
) -> Callable[[F], Type[BaseStep]]:
...

Expand All @@ -125,6 +127,7 @@ def step(
extra: Optional[Dict[str, Any]] = None,
on_failure: Optional["HookSpecification"] = None,
on_success: Optional["HookSpecification"] = None,
model: Optional["ModelConfig"] = None,
) -> Union[Type[BaseStep], Callable[[F], Type[BaseStep]]]:
"""Outer decorator function for the creation of a ZenML step.
Expand Down Expand Up @@ -157,6 +160,7 @@ def step(
on_success: Callback function in event of success of the step. Can be a
function with no arguments, or a source path to such a function
(e.g. `module.my_function`).
model: Model(Version) configuration for this step as `ModelConfig` instance.
Returns:
The inner decorator which creates the step class based on the
Expand Down

0 comments on commit ba9fd6a

Please sign in to comment.