Skip to content

Commit

Permalink
New log_metadata function, new oneof filtering, additional `run_m…
Browse files Browse the repository at this point in the history
…etadata` filtering (#3182)

* Initial commit, nuking all metadata responses and seeing what breaks

* Removed last remnant of LazyLoader

* Reintroducing the lazy loaders.

* Add LazyRunMetadataResponse to EntrypointFunctionDefinition

* Test for lazy loaders works now

* Fixed tests, reformatted

* Use updated template

* Auto-update of Starter template

* Updated more templates

* Fixed failing test

* Fixed step run schemas

* Auto-update of E2E template

* Auto-update of NLP template

* Fixed tests, removed additional .value access

* Further fixing

* Fixed linting issues

* Reformatted

* Linted, formatted and tested again

* Typing

* Maybe fix everything

* Apply some feedback

* new operation

* new log_metadata function

* changes to the base filters

* new filters

* adding log_metadata to __all__

* checkpoint with float casting

* adding tests

* final touches and formatting

* formatting

* moved the utils

* modified log metadata function

* checkpoint

* deprecating the old functions

* linting and final fixes

* better error message

* fixing the client method

* better error message

* consistent creation

* adjusting tests

* linting

* changes for step metadata

* more test adjustments

* testing unit tests

* linting

* fixing more tests

* fixing more tests

* more test fixes

* fixing the test

* fixing per comments

* added validation, constant error message

* linting

---------

Co-authored-by: AlexejPenner <[email protected]>
Co-authored-by: Andrei Vishniakov <[email protected]>
Co-authored-by: GitHub Actions <[email protected]>
Co-authored-by: Michael Schuster <[email protected]>
  • Loading branch information
5 people authored Nov 12, 2024
1 parent a624ab8 commit 69b6b80
Show file tree
Hide file tree
Showing 20 changed files with 709 additions and 68 deletions.
2 changes: 2 additions & 0 deletions src/zenml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from zenml.pipelines import get_pipeline_context, pipeline
from zenml.steps import step, get_step_context
from zenml.steps.utils import log_step_metadata
from zenml.utils.metadata_utils import log_metadata
from zenml.entrypoints import entrypoint

__all__ = [
Expand All @@ -56,6 +57,7 @@
"get_pipeline_context",
"get_step_context",
"load_artifact",
"log_metadata",
"log_artifact_metadata",
"log_model_metadata",
"log_step_metadata",
Expand Down
6 changes: 5 additions & 1 deletion src/zenml/artifacts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def log_artifact_metadata(
not provided, when being called inside a step that produces an
artifact named `artifact_name`, the metadata will be associated to
the corresponding newly created artifact. Or, if not provided when
being called outside of a step, or in a step that does not produce
being called outside a step, or in a step that does not produce
any artifact named `artifact_name`, the metadata will be associated
to the latest version of that artifact.
Expand All @@ -417,6 +417,10 @@ def log_artifact_metadata(
called inside a step with a single output, or, if neither an
artifact nor an output with the given name exists.
"""
logger.warning(
"The `log_artifact_metadata` function is deprecated and will soon be "
"removed. Please use `log_metadata` instead."
)
try:
step_context = get_step_context()
in_step_outputs = (artifact_name in step_context._outputs) or (
Expand Down
7 changes: 6 additions & 1 deletion src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3796,6 +3796,7 @@ def list_pipeline_runs(
templatable: Optional[bool] = None,
tag: Optional[str] = None,
user: Optional[Union[UUID, str]] = None,
run_metadata: Optional[Dict[str, str]] = None,
pipeline: Optional[Union[UUID, str]] = None,
code_repository: Optional[Union[UUID, str]] = None,
model: Optional[Union[UUID, str]] = None,
Expand Down Expand Up @@ -3835,6 +3836,7 @@ def list_pipeline_runs(
templatable: If the runs should be templatable or not.
tag: Tag to filter by.
user: The name/ID of the user to filter by.
run_metadata: The run_metadata of the run to filter by.
pipeline: The name/ID of the pipeline to filter by.
code_repository: Filter by code repository name/ID.
model: Filter by model name/ID.
Expand Down Expand Up @@ -3874,6 +3876,7 @@ def list_pipeline_runs(
tag=tag,
unlisted=unlisted,
user=user,
run_metadata=run_metadata,
pipeline=pipeline,
code_repository=code_repository,
stack=stack,
Expand Down Expand Up @@ -4194,7 +4197,7 @@ def get_artifact_version(
),
)
except RuntimeError:
pass # Cannot link to step run if called outside of a step
pass # Cannot link to step run if called outside a step
return artifact

def list_artifact_versions(
Expand Down Expand Up @@ -4222,6 +4225,7 @@ def list_artifact_versions(
user: Optional[Union[UUID, str]] = None,
model: Optional[Union[UUID, str]] = None,
pipeline_run: Optional[Union[UUID, str]] = None,
run_metadata: Optional[Dict[str, str]] = None,
tag: Optional[str] = None,
hydrate: bool = False,
) -> Page[ArtifactVersionResponse]:
Expand Down Expand Up @@ -4253,6 +4257,7 @@ def list_artifact_versions(
user: Filter by user name or ID.
model: Filter by model name or ID.
pipeline_run: Filter by pipeline run name or ID.
run_metadata: Filter by run metadata.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Expand Down
1 change: 1 addition & 0 deletions src/zenml/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ class GenericFilterOps(StrEnum):
CONTAINS = "contains"
STARTSWITH = "startswith"
ENDSWITH = "endswith"
ONEOF = "oneof"
GTE = "gte"
GT = "gt"
LTE = "lte"
Expand Down
5 changes: 5 additions & 0 deletions src/zenml/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def log_model_metadata(
ValueError: If no model name/version is provided and the function is not
called inside a step with configured `model` in decorator.
"""
logger.warning(
"The `log_model_metadata` function is deprecated and will soon be "
"removed. Please use `log_metadata` instead."
)

if model_name and model_version:
from zenml import Model

Expand Down
129 changes: 121 additions & 8 deletions src/zenml/models/v2/base/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# permissions and limitations under the License.
"""Base filter model definitions."""

import json
from abc import ABC, abstractmethod
from datetime import datetime
from typing import (
Expand All @@ -36,7 +37,7 @@
field_validator,
model_validator,
)
from sqlalchemy import asc, desc
from sqlalchemy import Float, and_, asc, cast, desc
from sqlmodel import SQLModel

from zenml.constants import (
Expand All @@ -63,6 +64,11 @@

AnyQuery = TypeVar("AnyQuery", bound=Any)

# Shared message for the `ValueError` raised when a `oneof` filter is given a
# value that is not a list.
ONEOF_ERROR = (
"When you are using the 'oneof:' filtering make sure that the "
"provided value is a json formatted list."
)


class Filter(BaseModel, ABC):
"""Filter for all fields.
Expand Down Expand Up @@ -171,8 +177,28 @@ class StrFilter(Filter):
GenericFilterOps.STARTSWITH,
GenericFilterOps.CONTAINS,
GenericFilterOps.ENDSWITH,
GenericFilterOps.ONEOF,
GenericFilterOps.GT,
GenericFilterOps.GTE,
GenericFilterOps.LT,
GenericFilterOps.LTE,
]

@model_validator(mode="after")
def check_value_if_operation_oneof(self) -> "StrFilter":
    """Ensure that a `oneof` filter carries a list as its value.

    Returns:
        The validated filter instance itself.

    Raises:
        ValueError: If the operation is `oneof` and the value is not a list.
    """
    uses_oneof = self.operation == GenericFilterOps.ONEOF
    if uses_oneof and not isinstance(self.value, list):
        raise ValueError(ONEOF_ERROR)
    return self

def generate_query_conditions_from_column(self, column: Any) -> Any:
"""Generate query conditions for a string column.
Expand All @@ -181,6 +207,9 @@ def generate_query_conditions_from_column(self, column: Any) -> Any:
Returns:
A list of query conditions.
Raises:
ValueError: the comparison of the column to a numeric value fails.
"""
if self.operation == GenericFilterOps.CONTAINS:
return column.like(f"%{self.value}%")
Expand All @@ -190,6 +219,40 @@ def generate_query_conditions_from_column(self, column: Any) -> Any:
return column.endswith(f"{self.value}")
if self.operation == GenericFilterOps.NOT_EQUALS:
return column != self.value
if self.operation == GenericFilterOps.ONEOF:
return column.in_(self.value)
if self.operation in {
GenericFilterOps.GT,
GenericFilterOps.LT,
GenericFilterOps.GTE,
GenericFilterOps.LTE,
}:
try:
numeric_column = cast(column, Float)

assert self.value is not None

if self.operation == GenericFilterOps.GT:
return and_(
numeric_column, numeric_column > float(self.value)
)
if self.operation == GenericFilterOps.LT:
return and_(
numeric_column, numeric_column < float(self.value)
)
if self.operation == GenericFilterOps.GTE:
return and_(
numeric_column, numeric_column >= float(self.value)
)
if self.operation == GenericFilterOps.LTE:
return and_(
numeric_column, numeric_column <= float(self.value)
)
except Exception as e:
raise ValueError(
f"Failed to compare the column '{column}' to the "
f"value '{self.value}' (must be numeric): {e}"
)

return column == self.value

Expand All @@ -211,6 +274,9 @@ def _remove_hyphens_from_value(cls, value: Any) -> Any:
if isinstance(value, str):
return value.replace("-", "")

if isinstance(value, list):
return [str(v).replace("-", "") for v in value]

return value

def generate_query_conditions_from_column(self, column: Any) -> Any:
Expand Down Expand Up @@ -588,6 +654,10 @@ def _resolve_operator(value: Any) -> Tuple[Any, GenericFilterOps]:
Returns:
A tuple of the filter value and the operator.
Raises:
ValueError: when we try to use the `oneof` operator with the wrong
value.
"""
operator = GenericFilterOps.EQUALS # Default operator
if isinstance(value, str):
Expand All @@ -598,6 +668,15 @@ def _resolve_operator(value: Any) -> Tuple[Any, GenericFilterOps]:
):
value = split_value[1]
operator = GenericFilterOps(split_value[0])

if operator == operator.ONEOF:
try:
value = json.loads(value)
if not isinstance(value, list):
raise ValueError
except ValueError:
raise ValueError(ONEOF_ERROR)

return value, operator

def generate_name_or_id_query_conditions(
Expand Down Expand Up @@ -648,8 +727,8 @@ def generate_name_or_id_query_conditions(

return or_(*conditions)

@staticmethod
def generate_custom_query_conditions_for_column(
self,
value: Any,
table: Type[SQLModel],
column: str,
Expand Down Expand Up @@ -833,16 +912,17 @@ def define_filter(

# Create str filters
if self.is_str_field(column):
return StrFilter(
operation=GenericFilterOps(operator),
return self._define_str_filter(
operator=GenericFilterOps(operator),
column=column,
value=value,
)

# Handle unsupported datatypes
logger.warning(
f"The Datatype {self._model_class.model_fields[column].annotation} might "
"not be supported for filtering. Defaulting to a string filter."
f"The Datatype {self._model_class.model_fields[column].annotation} "
"might not be supported for filtering. Defaulting to a string "
"filter."
)
return StrFilter(
operation=GenericFilterOps(operator),
Expand Down Expand Up @@ -1032,8 +1112,9 @@ def _define_uuid_filter(
"Invalid value passed as UUID query parameter."
) from e

# Cast the value to string for further comparisons.
value = str(value)
# For equality checks, ensure that the value is a valid UUID.
if operator == GenericFilterOps.ONEOF and not isinstance(value, list):
raise ValueError(ONEOF_ERROR)

# Generate the filter.
uuid_filter = UUIDFilter(
Expand All @@ -1043,6 +1124,38 @@ def _define_uuid_filter(
)
return uuid_filter

@staticmethod
def _define_str_filter(
    column: str, value: Any, operator: GenericFilterOps
) -> StrFilter:
    """Define a str filter for a given column.

    Args:
        column: The column to filter on.
        value: The value by which to filter. Must be a list when the
            `oneof` operator is used.
        operator: The operator to use for filtering.

    Returns:
        A Filter object.

    Raises:
        ValueError: If the `oneof` operator is used with a value that is
            not a list.
    """
    # The `oneof` operator is translated into an `IN` condition over the
    # given values, so anything other than a list is rejected up front.
    # The shared constant keeps the message identical to the one raised by
    # the UUID filter and the `StrFilter` validator.
    if operator == GenericFilterOps.ONEOF and not isinstance(value, list):
        raise ValueError(ONEOF_ERROR)

    # Generate the filter.
    return StrFilter(
        operation=GenericFilterOps(operator),
        column=column,
        value=value,
    )

@staticmethod
def _define_bool_filter(
column: str, value: Any, operator: GenericFilterOps
Expand Down
23 changes: 23 additions & 0 deletions src/zenml/models/v2/core/artifact_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,7 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter):
"user",
"model",
"pipeline_run",
"run_metadata",
]
artifact_id: Optional[Union[UUID, str]] = Field(
default=None,
Expand Down Expand Up @@ -545,6 +546,10 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter):
description="Name/ID of a pipeline run that is associated with this "
"artifact version.",
)
run_metadata: Optional[Dict[str, str]] = Field(
default=None,
description="The run_metadata to filter the artifact versions by.",
)

model_config = ConfigDict(protected_namespaces=())

Expand All @@ -564,6 +569,7 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]:
ModelSchema,
ModelVersionArtifactSchema,
PipelineRunSchema,
RunMetadataSchema,
StepRunInputArtifactSchema,
StepRunOutputArtifactSchema,
StepRunSchema,
Expand Down Expand Up @@ -645,6 +651,23 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]:
)
custom_filters.append(pipeline_run_filter)

if self.run_metadata is not None:
from zenml.enums import MetadataResourceTypes

for key, value in self.run_metadata.items():
additional_filter = and_(
RunMetadataSchema.resource_id == ArtifactVersionSchema.id,
RunMetadataSchema.resource_type
== MetadataResourceTypes.ARTIFACT_VERSION,
RunMetadataSchema.key == key,
self.generate_custom_query_conditions_for_column(
value=value,
table=RunMetadataSchema,
column="value",
),
)
custom_filters.append(additional_filter)

return custom_filters


Expand Down
Loading

0 comments on commit 69b6b80

Please sign in to comment.