Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Increase reuse of ModelConfig #1954

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
7e4171a
flatten `ModelConfig` code
avishniakov Oct 17, 2023
82d69d3
revert config names for backward compatibility
avishniakov Oct 17, 2023
f9ba41c
revert utils change
avishniakov Oct 17, 2023
8e6f82a
fix ExternalArtifact
avishniakov Oct 17, 2023
6247d03
backward compatibility
avishniakov Oct 17, 2023
90efe1b
`model_config_model` -> `model_config`
avishniakov Oct 17, 2023
7056f96
cleanup after rename
avishniakov Oct 17, 2023
33a0bc0
reuse model config in artifact config
avishniakov Oct 17, 2023
ab9d117
warm up all model configs on pipeline level
avishniakov Oct 17, 2023
84c756f
avoid multi warning in ModelConfig
avishniakov Oct 17, 2023
ffcdc34
properly pass values to validator
avishniakov Oct 17, 2023
7f00404
avoid instantiate in parse config file
avishniakov Oct 17, 2023
cc8cd01
avoid instantiate in parse config file
avishniakov Oct 17, 2023
9e7fe6c
remove debug statements
avishniakov Oct 17, 2023
f5bbdc0
update misleading logging
avishniakov Oct 17, 2023
019059f
improve logging of cached linking
avishniakov Oct 17, 2023
b177f4c
allow latest version in cli links listing
avishniakov Oct 17, 2023
468c999
Update src/zenml/model/model_config.py
avishniakov Oct 17, 2023
14b3ec4
Update src/zenml/model/model_config.py
avishniakov Oct 17, 2023
b469058
Update src/zenml/model/model_config.py
avishniakov Oct 17, 2023
ce5ee89
Update src/zenml/model/model_config.py
avishniakov Oct 17, 2023
397766e
unique test file name
avishniakov Oct 17, 2023
eee1052
ethic->ethics
avishniakov Oct 17, 2023
53f254d
update `step_run`
avishniakov Oct 18, 2023
b117128
update logging mock in tests
avishniakov Oct 18, 2023
fe8b5b5
extend test case to link from cache
avishniakov Oct 18, 2023
7aeb144
fixing bugs from tests
avishniakov Oct 18, 2023
42b7b4c
linting
avishniakov Oct 18, 2023
a6f53a1
fix bug
avishniakov Oct 18, 2023
68cfd16
redesign model_config warnings logic
avishniakov Oct 18, 2023
ada9275
renaming
avishniakov Oct 18, 2023
6fc148a
Merge branch 'develop' into feature/OSS-2510-reduce-the-number-of-re-…
avishniakov Oct 18, 2023
76c39de
fix bug
avishniakov Oct 18, 2023
287b80e
remove outdated test
avishniakov Oct 18, 2023
6c88495
improve validator warnings
avishniakov Oct 18, 2023
4f52d76
improve validator warnings
avishniakov Oct 18, 2023
098207f
improve validator warnings
avishniakov Oct 18, 2023
7e00231
improve validator warnings
avishniakov Oct 18, 2023
42e103f
Merge branch 'develop' into feature/OSS-2510-reduce-the-number-of-re-…
avishniakov Oct 18, 2023
66caf18
Merge branch 'develop' into feature/OSS-2510-reduce-the-number-of-re-…
avishniakov Oct 18, 2023
f0ac4bf
resolve alembic branches
avishniakov Oct 18, 2023
cc70643
remove excessive init
avishniakov Oct 18, 2023
1a3fc4f
fix runs cli
avishniakov Oct 18, 2023
6372de0
lint
avishniakov Oct 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions src/zenml/artifacts/external_artifact_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@ def _get_artifact_from_model(
RuntimeError: If `model_artifact_name` is set, but `model_name` is empty and
model configuration is missing in @step and @pipeline.
"""
from zenml.model.model_config import ModelConfig

if self.model_name is None:
if model_config is None:
raise RuntimeError(
Expand All @@ -104,13 +102,18 @@ def _get_artifact_from_model(
)
self.model_name = model_config.name
self.model_version = model_config.version

_model_config = ModelConfig(
name=self.model_name,
version=self.model_version,
suppress_warnings=True,
)
model_version = _model_config._get_model_version()
if (
model_config is None
or self.model_name != model_config.name
or self.model_version != model_config.version
):
from zenml.model.model_config import ModelConfig

model_config = ModelConfig(
name=self.model_name,
version=self.model_version,
)
model_version = model_config._get_model_version()

for artifact_getter in [
model_version.get_artifact_object,
Expand Down
57 changes: 36 additions & 21 deletions src/zenml/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def register_model(
audience=audience,
use_cases=use_cases,
trade_offs=tradeoffs,
ethic=ethical,
ethics=ethical,
limitations=limitations,
tags=tag,
user=Client().active_user.id,
Expand Down Expand Up @@ -265,7 +265,7 @@ def update_model(
audience=audience,
use_cases=use_cases,
trade_offs=tradeoffs,
ethic=ethical,
ethics=ethical,
limitations=limitations,
tags=tag,
user=Client().active_user.id,
Expand Down Expand Up @@ -451,7 +451,7 @@ def update_model_version(
def _print_artifacts_links_generic(
model_name_or_id: str,
model_version_name_or_number_or_id: str,
only_artifacts: bool = False,
only_artifact_objects: bool = False,
only_deployments: bool = False,
only_model_objects: bool = False,
**kwargs: Any,
Expand All @@ -461,36 +461,42 @@ def _print_artifacts_links_generic(
Args:
model_name_or_id: The ID or name of the model containing version.
model_version_name_or_number_or_id: The name, number or ID of the model version.
only_artifacts: If set, only print artifacts.
only_artifact_objects: If set, only print artifacts.
only_deployments: If set, only print deployments.
only_model_objects: If set, only print model objects.
**kwargs: Keyword arguments to filter models.
"""
model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
model_version_name_or_number_or_id=None
if model_version_name_or_number_or_id == "0"
else model_version_name_or_number_or_id,
)
type_ = (
"artifacts"
if only_artifact_objects
else "deployments"
if only_deployments
else "model objects"
)

if (
(only_artifacts and not model_version.artifact_object_ids)
(only_artifact_objects and not model_version.artifact_object_ids)
or (only_deployments and not model_version.deployment_ids)
or (only_model_objects and not model_version.model_object_ids)
):
_type = (
"artifacts"
if only_artifacts
else "deployments"
if only_deployments
else "model objects"
)
cli_utils.declare(f"No {_type} linked to the model version found.")
cli_utils.declare(f"No {type_} linked to the model version found.")
return

cli_utils.title(
f"{type_} linked to the model version `{model_version.name}[{model_version.number}]`:"
)

links = Client().list_model_version_artifact_links(
ModelVersionArtifactFilterModel(
model_id=model_version.model.id,
model_version_id=model_version.id,
only_artifacts=only_artifacts,
only_artifacts=only_artifact_objects,
only_deployments=only_deployments,
only_model_objects=only_model_objects,
**kwargs,
Expand All @@ -515,7 +521,7 @@ def _print_artifacts_links_generic(
help="List artifacts linked to a model version.",
)
@click.argument("model_name_or_id")
@click.argument("model_version_name_or_number_or_id")
@click.argument("model_version_name_or_number_or_id", default="0")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels a bit weird, why can't we use None as default / latest value instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cause it is strings - cli arguments and -1 which I planned to use is treated as a parameter name, so 0 seemed a solid choice. Moreover - user will not specify it directly, it will be set by just skipping this arg.
zenml model version artifacts my_model -> will return artifacts for latest in my_model

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't get it, why can't it be Optional[str] instead? I.e., why would this not work?

@click.argument("model_name_or_id")
@click.argument("model_version_name_or_number_or_id")
def list_model_version_artifacts(
    model_name_or_id: str,
    model_version_name_or_number_or_id: Optional[str] = None,
) -> None:
    ...

def _print_artifacts_links_generic(
    model_name_or_id: str,
    model_version_name_or_number_or_id: Optional[str] = None,
    ...
):
    ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a click limitation, as I see this. If I set default=None it is treated as no default, making arg mandatory. Changes on the level of functions args are not effective at all - @click.argument is the king here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yes, click.argument means mandatory, for optional inputs you need to use click.option

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nah, the option would lead to the argument passing as --model_version_name_or_number_or_id and not just zenml some_command some_argument [some_skipped_optional_argument] or I completely misuse click 🙂

@cli_utils.list_options(ModelVersionArtifactFilterModel)
def list_model_version_artifacts(
model_name_or_id: str,
Expand All @@ -527,12 +533,13 @@ def list_model_version_artifacts(
Args:
model_name_or_id: The ID or name of the model containing version.
model_version_name_or_number_or_id: The name, number or ID of the model version.
Or use 0 for the latest version.
**kwargs: Keyword arguments to filter models.
"""
_print_artifacts_links_generic(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
only_artifacts=True,
only_artifact_objects=True,
**kwargs,
)

Expand All @@ -542,7 +549,7 @@ def list_model_version_artifacts(
help="List model objects linked to a model version.",
)
@click.argument("model_name_or_id")
@click.argument("model_version_name_or_number_or_id")
@click.argument("model_version_name_or_number_or_id", default="0")
@cli_utils.list_options(ModelVersionArtifactFilterModel)
def list_model_version_model_objects(
model_name_or_id: str,
Expand All @@ -554,6 +561,7 @@ def list_model_version_model_objects(
Args:
model_name_or_id: The ID or name of the model containing version.
model_version_name_or_number_or_id: The name, number or ID of the model version.
Or use 0 for the latest version.
**kwargs: Keyword arguments to filter models.
"""
_print_artifacts_links_generic(
Expand All @@ -569,7 +577,7 @@ def list_model_version_model_objects(
help="List deployments linked to a model version.",
)
@click.argument("model_name_or_id")
@click.argument("model_version_name_or_number_or_id")
@click.argument("model_version_name_or_number_or_id", default="0")
@cli_utils.list_options(ModelVersionArtifactFilterModel)
def list_model_version_deployments(
model_name_or_id: str,
Expand All @@ -581,6 +589,7 @@ def list_model_version_deployments(
Args:
model_name_or_id: The ID or name of the model containing version.
model_version_name_or_number_or_id: The name, number or ID of the model version.
Or use 0 for the latest version.
**kwargs: Keyword arguments to filter models.
"""
_print_artifacts_links_generic(
Expand All @@ -596,7 +605,7 @@ def list_model_version_deployments(
help="List pipeline runs of a model version.",
)
@click.argument("model_name_or_id")
@click.argument("model_version_name_or_number_or_id")
@click.argument("model_version_name_or_number_or_id", default="0")
@cli_utils.list_options(ModelVersionPipelineRunFilterModel)
def list_model_version_pipeline_runs(
model_name_or_id: str,
Expand All @@ -608,16 +617,22 @@ def list_model_version_pipeline_runs(
Args:
model_name_or_id: The ID or name of the model containing version.
model_version_name_or_number_or_id: The name, number or ID of the model version.
Or use 0 for the latest version.
**kwargs: Keyword arguments to filter models.
"""
model_version = Client().get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
model_version_name_or_number_or_id=None
if model_version_name_or_number_or_id == "0"
else model_version_name_or_number_or_id,
)

if not model_version.pipeline_run_ids:
cli_utils.declare("No pipeline runs attached to model version found.")
return
cli_utils.title(
f"Pipeline runs linked to the model version `{model_version.name}[{model_version.number}]`:"
)

links = Client().list_model_version_pipeline_run_links(
ModelVersionPipelineRunFilterModel(
Expand Down
21 changes: 2 additions & 19 deletions src/zenml/config/pipeline_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@
from zenml.config.constants import DOCKER_SETTINGS_KEY
from zenml.config.source import Source, convert_source_validator
from zenml.config.strict_base_model import StrictBaseModel
from zenml.models.model_base_model import ModelConfigModel
from zenml.model.model_config import ModelConfig

if TYPE_CHECKING:
from zenml.config import DockerSettings
from zenml.model.model_config import ModelConfig

from zenml.config.base_settings import BaseSettings, SettingsOrDict

Expand All @@ -41,28 +40,12 @@ class PipelineConfigurationUpdate(StrictBaseModel):
extra: Dict[str, Any] = {}
failure_hook_source: Optional[Source] = None
success_hook_source: Optional[Source] = None
model_config_model: Optional[ModelConfigModel] = None
model_config: Optional[ModelConfig] = None

_convert_source = convert_source_validator(
"failure_hook_source", "success_hook_source"
)

@property
def model_config(self) -> Optional["ModelConfig"]:
"""Gets a ModelConfig object out of the model config model.

This is a technical circular import resolver.

Returns:
The model config object, if configured.
"""
if self.model_config_model is None:
return None

from zenml.model.model_config import ModelConfig

return ModelConfig.parse_obj(self.model_config_model.dict())


class PipelineConfiguration(PipelineConfigurationUpdate):
"""Pipeline configuration class."""
Expand Down
21 changes: 2 additions & 19 deletions src/zenml/config/step_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,11 @@
from zenml.config.source import Source, convert_source_validator
from zenml.config.strict_base_model import StrictBaseModel
from zenml.logger import get_logger
from zenml.models.model_base_model import ModelConfigModel
from zenml.model.model_config import ModelConfig
from zenml.utils import deprecation_utils

if TYPE_CHECKING:
from zenml.config import DockerSettings, ResourceSettings
from zenml.model.model_config import ModelConfig

logger = get_logger(__name__)

Expand Down Expand Up @@ -135,7 +134,7 @@ class StepConfigurationUpdate(StrictBaseModel):
extra: Dict[str, Any] = {}
failure_hook_source: Optional[Source] = None
success_hook_source: Optional[Source] = None
model_config_model: Optional[ModelConfigModel] = None
model_config: Optional[ModelConfig] = None

outputs: Mapping[str, PartialArtifactConfiguration] = {}

Expand All @@ -146,22 +145,6 @@ class StepConfigurationUpdate(StrictBaseModel):
"name"
)

@property
def model_config(self) -> Optional["ModelConfig"]:
"""Gets a ModelConfig object out of the model config model.

This is a technical circular import resolver.

Returns:
The model config object, if configured.
"""
if self.model_config_model is None:
return None

from zenml.model.model_config import ModelConfig

return ModelConfig.parse_obj(self.model_config_model.dict())


class PartialStepConfiguration(StepConfigurationUpdate):
"""Class representing a partial step configuration."""
Expand Down
Loading
Loading