diff --git a/docs/book/how-to/handle-data-artifacts/registering-existing-data.md b/docs/book/how-to/handle-data-artifacts/registering-existing-data.md
new file mode 100644
index 00000000000..e7d8e5c352c
--- /dev/null
+++ b/docs/book/how-to/handle-data-artifacts/registering-existing-data.md
@@ -0,0 +1,393 @@
+---
+description: Learn how to register external data as a ZenML artifact for future use.
+---
+
+# Register Existing Data as a ZenML Artifact
+
+Many modern machine learning frameworks create their own data as a byproduct of model training or other processes. In such cases, there is no need to read and materialize those data assets just to pack them into a ZenML artifact; instead, it is beneficial to register those data assets as-is in ZenML for future use.
+
+## Register Existing Folder as a ZenML Artifact
+
+If the data created externally is a folder, you can register the whole folder as a ZenML artifact and later make use of it in subsequent steps or other pipelines.
+
+```python
+import os
+from uuid import uuid4
+from pathlib import Path
+
+from zenml.client import Client
+from zenml import register_artifact
+
+prefix = Client().active_stack.artifact_store.path
+test_file_name = "test_file.txt"
+preexisting_folder = os.path.join(prefix, f"my_test_folder_{uuid4()}")
+preexisting_file = os.path.join(preexisting_folder, test_file_name)
+
+# produce a folder with a file inside the artifact store boundaries
+os.mkdir(preexisting_folder)
+with open(preexisting_file, "w") as f:
+    f.write("test")
+
+# create an artifact from the preexisting folder
+register_artifact(
+    folder_or_file_uri=preexisting_folder,
+    name="my_folder_artifact"
+)
+
+# consume the artifact as a folder
+temp_artifact_folder_path = Client().get_artifact_version(name_id_or_prefix="my_folder_artifact").load()
+assert isinstance(temp_artifact_folder_path, Path)
+assert os.path.isdir(temp_artifact_folder_path)
+with open(os.path.join(temp_artifact_folder_path, test_file_name), "r") as f:
+    assert f.read() == "test"
+```
+
+{% hint style="info" %}
+The artifact produced from the preexisting data will have a `pathlib.Path` type once loaded or passed as input to another step. The path points to a temporary location in the executing environment and is ready for use as a normal local `Path` (it can be passed into `from_pretrained` or `open` functions, to name a few examples).
+{% endhint %}
+
+## Register Existing File as a ZenML Artifact
+
+If the data created externally is a file, you can register it as a ZenML artifact and later make use of it in subsequent steps or other pipelines.
+
+```python
+import os
+from uuid import uuid4
+from pathlib import Path
+
+from zenml.client import Client
+from zenml import register_artifact
+
+prefix = Client().active_stack.artifact_store.path
+test_file_name = "test_file.txt"
+preexisting_folder = os.path.join(prefix, f"my_test_folder_{uuid4()}")
+preexisting_file = os.path.join(preexisting_folder, test_file_name)
+
+# produce a file inside the artifact store boundaries
+os.mkdir(preexisting_folder)
+with open(preexisting_file, "w") as f:
+    f.write("test")
+
+# create an artifact from the preexisting file
+register_artifact(
+    folder_or_file_uri=preexisting_file,
+    name="my_file_artifact"
+)
+
+# consume the artifact as a file
+temp_artifact_file_path = Client().get_artifact_version(name_id_or_prefix="my_file_artifact").load()
+assert isinstance(temp_artifact_file_path, Path)
+assert not os.path.isdir(temp_artifact_file_path)
+with open(temp_artifact_file_path, "r") as f:
+    assert f.read() == "test"
+```
+
+## Register All Checkpoints of a PyTorch Lightning Training Run
+
+Now let's explore a PyTorch Lightning example that fits a model and stores the checkpoints in a remote location.
+
+```python
+import os
+from zenml.client import Client
+from zenml import register_artifact
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
+from uuid import uuid4
+
+# Define where the model data should be saved
+# use the active ArtifactStore
+prefix = Client().active_stack.artifact_store.path
+# keep the data of different runs separate using a uuid4 folder
+default_root_dir = os.path.join(prefix, uuid4().hex)
+
+# Define the model and fit it
+model = ...
+trainer = Trainer(
+    default_root_dir=default_root_dir,
+    callbacks=[
+        ModelCheckpoint(
+            every_n_epochs=1, save_top_k=-1, filename="checkpoint-{epoch:02d}"
+        )
+    ],
+)
+try:
+    trainer.fit(model)
+finally:
+    # We now link those checkpoints in ZenML as an artifact
+    # This will create a new artifact version
+    register_artifact(default_root_dir, name="all_my_model_checkpoints")
+```
+
+Even if an artifact is created and stored externally, it can be treated like any other artifact produced by ZenML steps - with all the functionalities described above!
+
+## Register Checkpoints of a PyTorch Lightning Training Run as Separate Artifact Versions
+
+To version each checkpoint (or other intermediate artifact) separately, you can extend the `ModelCheckpoint` callback to fit your needs. For example, such a custom implementation could look like the one below, where we extend the `on_train_epoch_end` method to register each checkpoint created during the training as a separate artifact version in ZenML.
+
+{% hint style="warning" %}
+To make the checkpoint files last, you need to set `save_top_k=-1`; otherwise, older checkpoints will be deleted, making the registered artifact versions unusable.
+{% endhint %}
+
+```python
+import os
+
+from zenml.client import Client
+from zenml import register_artifact
+from zenml import get_step_context
+from zenml.exceptions import StepContextError
+from zenml.logger import get_logger
+
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning import Trainer, LightningModule
+
+logger = get_logger(__name__)
+
+
+class ZenMLModelCheckpoint(ModelCheckpoint):
+    """A ModelCheckpoint that can be used with ZenML.
+
+    Used to store model checkpoints in ZenML as artifacts.
+    Provides a `default_root_dir` attribute to pass into the `Trainer`.
+ """ + + def __init__( + self, + artifact_name: str, + every_n_epochs: int = 1, + save_top_k: int = -1, + *args, + **kwargs, + ): + # get all needed info for the ZenML logic + try: + zenml_model = get_step_context().model + except StepContextError: + raise RuntimeError( + "`ZenMLModelCheckpoint` can only be called from within a step." + ) + model_name = zenml_model.name + filename = model_name + "_{epoch:02d}" + self.filename_format = model_name + "_epoch={epoch:02d}.ckpt" + self.artifact_name = artifact_name + + prefix = Client().active_stack.artifact_store.path + self.default_root_dir = os.path.join(prefix, str(zenml_model.version)) + logger.info(f"Model data will be stored in {self.default_root_dir}") + + super().__init__( + every_n_epochs=every_n_epochs, + save_top_k=save_top_k, + filename=filename, + *args, + **kwargs, + ) + + def on_train_epoch_end( + self, trainer: "Trainer", pl_module: "LightningModule" + ) -> None: + super().on_train_epoch_end(trainer, pl_module) + + # We now link those checkpoints in ZenML as an artifact + # This will create a new artifact version + register_artifact( + os.path.join( + self.dirpath, self.filename_format.format(epoch=trainer.current_epoch) + ), + self.artifact_name, + is_model_artifact=True, + ) +``` + +Below you can find a sophisticated example of a pipeline doing a Pytorch Lightning training with the artifacts linkage for checkpoint artifacts implemented as an extended Callback. + +
+
+<summary>PyTorch Lightning training with the checkpoints linkage: full example</summary>
+
+```python
+import os
+from typing import Annotated
+from pathlib import Path
+
+import numpy as np
+from zenml.client import Client
+from zenml import register_artifact
+from zenml import step, pipeline, get_step_context, Model
+from zenml.exceptions import StepContextError
+from zenml.logger import get_logger
+
+from torch.utils.data import DataLoader
+from torch.nn import ReLU, Linear, Sequential
+from torch.nn.functional import mse_loss
+from torch.optim import Adam
+from torch import rand
+from torchvision.datasets import MNIST
+from torchvision.transforms import ToTensor
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning import Trainer, LightningModule
+
+from zenml.new.pipelines.pipeline_context import get_pipeline_context
+
+logger = get_logger(__name__)
+
+
+class ZenMLModelCheckpoint(ModelCheckpoint):
+    """A ModelCheckpoint that can be used with ZenML.
+
+    Used to store model checkpoints in ZenML as artifacts.
+    Provides a `default_root_dir` attribute to pass into the `Trainer`.
+    """
+
+    def __init__(
+        self,
+        artifact_name: str,
+        every_n_epochs: int = 1,
+        save_top_k: int = -1,
+        *args,
+        **kwargs,
+    ):
+        # get all needed info for the ZenML logic
+        try:
+            zenml_model = get_step_context().model
+        except StepContextError:
+            raise RuntimeError(
+                "`ZenMLModelCheckpoint` can only be called from within a step."
+            )
+        model_name = zenml_model.name
+        filename = model_name + "_{epoch:02d}"
+        self.filename_format = model_name + "_epoch={epoch:02d}.ckpt"
+        self.artifact_name = artifact_name
+
+        prefix = Client().active_stack.artifact_store.path
+        self.default_root_dir = os.path.join(prefix, str(zenml_model.version))
+        logger.info(f"Model data will be stored in {self.default_root_dir}")
+
+        super().__init__(
+            every_n_epochs=every_n_epochs,
+            save_top_k=save_top_k,
+            filename=filename,
+            *args,
+            **kwargs,
+        )
+
+    def on_train_epoch_end(
+        self, trainer: "Trainer", pl_module: "LightningModule"
+    ) -> None:
+        super().on_train_epoch_end(trainer, pl_module)
+
+        # We now link those checkpoints in ZenML as an artifact
+        # This will create a new artifact version
+        register_artifact(
+            os.path.join(
+                self.dirpath, self.filename_format.format(epoch=trainer.current_epoch)
+            ),
+            self.artifact_name,
+            is_model_artifact=True,
+        )
+
+
+# define the LightningModule toy model
+class LitAutoEncoder(LightningModule):
+    def __init__(self, encoder, decoder):
+        super().__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+
+    def training_step(self, batch, batch_idx):
+        # training_step defines the train loop.
+ # it is independent of forward + x, _ = batch + x = x.view(x.size(0), -1) + z = self.encoder(x) + x_hat = self.decoder(z) + loss = mse_loss(x_hat, x) + # Logging to TensorBoard (if installed) by default + self.log("train_loss", loss) + return loss + + def configure_optimizers(self): + optimizer = Adam(self.parameters(), lr=1e-3) + return optimizer + + +@step +def get_data() -> DataLoader: + """Get the training data.""" + dataset = MNIST(os.getcwd(), download=True, transform=ToTensor()) + train_loader = DataLoader(dataset) + + return train_loader + + +@step +def get_model() -> LightningModule: + """Get the model to train.""" + encoder = Sequential(Linear(28 * 28, 64), ReLU(), Linear(64, 3)) + decoder = Sequential(Linear(3, 64), ReLU(), Linear(64, 28 * 28)) + model = LitAutoEncoder(encoder, decoder) + return model + + +@step +def train_model( + model: LightningModule, + train_loader: DataLoader, + epochs: int = 1, + artifact_name: str = "my_model_ckpts", +) -> None: + """Run the training loop.""" + # configure checkpointing + chkpt_cb = ZenMLModelCheckpoint(artifact_name=artifact_name) + + trainer = Trainer( + # pass default_root_dir from ZenML checkpoint to + # ensure that the data is accessible for the artifact + # store + default_root_dir=chkpt_cb.default_root_dir, + limit_train_batches=100, + max_epochs=epochs, + callbacks=[chkpt_cb], + ) + trainer.fit(model, train_loader) + + +@step +def predict( + checkpoint_file: Path, +) -> Annotated[np.ndarray, "predictions"]: + # load the model from the checkpoint + encoder = Sequential(Linear(28 * 28, 64), ReLU(), Linear(64, 3)) + decoder = Sequential(Linear(3, 64), ReLU(), Linear(64, 28 * 28)) + autoencoder = LitAutoEncoder.load_from_checkpoint( + checkpoint_file, encoder=encoder, decoder=decoder + ) + encoder = autoencoder.encoder + encoder.eval() + + # predict on fake batch + fake_image_batch = rand(4, 28 * 28, device=autoencoder.device) + embeddings = encoder(fake_image_batch) + if embeddings.device.type == "cpu": + return embeddings.detach().numpy() + else: + return embeddings.detach().cpu().numpy() + + +@pipeline(model=Model(name="LightningDemo")) +def train_pipeline(artifact_name: str = "my_model_ckpts"): + train_loader = get_data() + model = get_model() + train_model(model, train_loader, 10, artifact_name) + # pass in the latest checkpoint for predictions + predict( + get_pipeline_context().model.get_artifact(artifact_name), after=["train_model"] + ) + + +if __name__ == "__main__": + train_pipeline() +``` + +
+</details>
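+
+Once registered, every checkpoint version can be consumed like any other ZenML artifact. Below is a minimal sketch of fetching a checkpoint outside of a pipeline; the artifact name matches the example above, and the latest version is resolved by default:
+
+```python
+from pathlib import Path
+
+from zenml.client import Client
+
+# fetch the latest registered checkpoint; this copies the file
+# to a temporary local location and returns it as a `pathlib.Path`
+checkpoint_path = Client().get_artifact_version(
+    name_id_or_prefix="my_model_ckpts"
+).load()
+assert isinstance(checkpoint_path, Path)
+```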
\ No newline at end of file
diff --git a/docs/book/toc.md b/docs/book/toc.md
index 887d8d07eb5..54122a27ec6 100644
--- a/docs/book/toc.md
+++ b/docs/book/toc.md
@@ -136,6 +136,7 @@
 * [Load artifacts into memory](how-to/handle-data-artifacts/load-artifacts-into-memory.md)
 * [Skipping materialization](how-to/handle-data-artifacts/unmaterialized-artifacts.md)
 * [Passing artifacts between pipelines](how-to/handle-data-artifacts/passing-artifacts-between-pipelines.md)
+* [Register Existing Data as a ZenML Artifact](how-to/handle-data-artifacts/registering-existing-data.md)
 * [📊 Visualizing artifacts](how-to/visualize-artifacts/README.md)
 * [Default visualizations](how-to/visualize-artifacts/types-of-visualizations.md)
 * [Creating custom visualizations](how-to/visualize-artifacts/creating-custom-visualizations.md)
diff --git a/docs/book/user-guide/starter-guide/manage-artifacts.md b/docs/book/user-guide/starter-guide/manage-artifacts.md
index 340f6bc4171..7ad8e950e7d 100644
--- a/docs/book/user-guide/starter-guide/manage-artifacts.md
+++ b/docs/book/user-guide/starter-guide/manage-artifacts.md
@@ -301,6 +301,51 @@ Even if an artifact is created externally, it can be treated like any other arti
 It is also possible to use these functions inside your ZenML steps. However, it is usually cleaner to return the artifacts as outputs of your step to save them, or to use External Artifacts to load them instead.
 {% endhint %}
 
+### Linking existing data as a ZenML artifact
+
+Sometimes, data is produced completely outside of ZenML and is conveniently stored in some storage location. A good example of this is the checkpoint files created as a side-effect of training deep learning models. Such intermediate data can be quite large, and there is no good reason to move it around again and again if it can be produced directly within the artifact store boundaries and later simply linked to become a ZenML artifact.
+Let's explore a PyTorch Lightning example that fits a model and stores the checkpoints in a remote location.
+
+```python
+import os
+from zenml.client import Client
+from zenml import register_artifact
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
+from uuid import uuid4
+
+# Define where the model data should be saved
+# use the active ArtifactStore
+prefix = Client().active_stack.artifact_store.path
+# keep the data of different runs separate using a uuid4 folder
+default_root_dir = os.path.join(prefix, uuid4().hex)
+
+# Define the model and fit it
+model = ...
+trainer = Trainer(
+    default_root_dir=default_root_dir,
+    callbacks=[
+        ModelCheckpoint(
+            every_n_epochs=1, save_top_k=-1, filename="checkpoint-{epoch:02d}"
+        )
+    ],
+)
+try:
+    trainer.fit(model)
+finally:
+    # We now link those checkpoints in ZenML as an artifact
+    # This will create a new artifact version
+    register_artifact(default_root_dir, name="all_my_model_checkpoints")
+```
+
+{% hint style="info" %}
+The artifact produced from the preexisting data will have a `pathlib.Path` type once loaded or passed as input to another step.
+{% endhint %}
+
+Even if an artifact is created and stored externally, it can be treated like any other artifact produced by ZenML steps - with all the functionalities described above!
+
+For more details and use cases, check out the detailed docs page [Register Existing Data as a ZenML Artifact](../../how-to/handle-data-artifacts/registering-existing-data.md).
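+
+As a quick preview of the consumption side, the registered checkpoints can later be loaded like any other artifact. A minimal sketch, assuming the run above has completed (the artifact name matches the example; `load()` returns a local `pathlib.Path` copy of the checkpoints folder):
+
+```python
+from zenml.client import Client
+
+# load the latest version of the registered checkpoints folder
+checkpoints_path = Client().get_artifact_version(
+    name_id_or_prefix="all_my_model_checkpoints"
+).load()
+```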
+
 ## Logging metadata for an artifact
 
 One of the most useful ways of interacting with artifacts in ZenML is the ability to associate metadata with them. [As mentioned before](../../how-to/build-pipelines/fetching-pipelines.md#artifact-information), artifact metadata is an arbitrary dictionary of key-value pairs that are useful for understanding the nature of the data.
diff --git a/examples/e2e/Makefile b/examples/e2e/Makefile
index ca4b34f8008..4b6ba4f5d0a 100644
--- a/examples/e2e/Makefile
+++ b/examples/e2e/Makefile
@@ -12,5 +12,4 @@ install-stack-local:
 	zenml stack register -a default -o default -r mlflow_local_$${stack_name} \
 	-d mlflow_local_$${stack_name} -e mlflow_local_$${stack_name} -dv \
 	evidently_$${stack_name} $${stack_name} && \
-	zenml stack set $${stack_name} && \
-	zenml stack up
+	zenml stack set $${stack_name}
diff --git a/src/zenml/__init__.py b/src/zenml/__init__.py
index fa5e0a85227..6dca64419a1 100644
--- a/src/zenml/__init__.py
+++ b/src/zenml/__init__.py
@@ -36,6 +36,7 @@
     log_artifact_metadata,
     save_artifact,
     load_artifact,
+    register_artifact,
 )
 from zenml.model.utils import (
     log_model_metadata,
@@ -62,6 +63,7 @@
     "link_artifact_to_model",
     "pipeline",
     "save_artifact",
+    "register_artifact",
     "show",
     "step",
     "entrypoint",
diff --git a/src/zenml/artifacts/load_directory_materializer.py b/src/zenml/artifacts/load_directory_materializer.py
new file mode 100644
index 00000000000..d618c20a515
--- /dev/null
+++ b/src/zenml/artifacts/load_directory_materializer.py
@@ -0,0 +1,91 @@
+# Copyright (c) ZenML GmbH 2023. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+#
+#       https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+"""Only-load materializer for directories."""
+
+import os
+import tempfile
+from pathlib import Path
+from typing import Any, ClassVar, Tuple, Type
+
+from zenml.enums import ArtifactType
+from zenml.io import fileio
+from zenml.materializers.base_materializer import BaseMaterializer
+
+
+class PreexistingDataMaterializer(BaseMaterializer):
+    """Materializer to load directories from the artifact store.
+
+    This materializer is special, since it does not implement any save
+    logic at all. Saving the data to some URI inside the artifact store
+    must happen outside of this materializer and is the user's
+    responsibility.
+
+    This materializer solely supports the `register_artifact` function.
+    """
+
+    ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (Path,)
+    ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA
+    SKIP_REGISTRATION: ClassVar[bool] = True
+
+    def load(self, data_type: Type[Any]) -> Any:
+        """Copy the artifact file(s) to a local temp directory.
+
+        Args:
+            data_type: Unused.
+
+        Returns:
+            Path to the local directory that contains the artifact files.
+ """ + directory = tempfile.mkdtemp(prefix="zenml-artifact") + if fileio.isdir(self.uri): + self._copy_directory(src=self.uri, dst=directory) + return Path(directory) + else: + dst = os.path.join(directory, os.path.split(self.uri)[-1]) + fileio.copy(src=self.uri, dst=dst) + return Path(dst) + + def save(self, data: Any) -> None: + """Store the directory in the artifact store. + + Args: + data: Path to a local directory to store. + + Raises: + NotImplementedError: Always + """ + raise NotImplementedError( + "`PreexistingDataMaterializer` can only be used in the " + "context of `register_artifact` function, " + "which expects the data to be already properly saved in " + "the Artifact Store, thus `save` logic makes no sense here." + ) + + @staticmethod + def _copy_directory(src: str, dst: str) -> None: + """Recursively copy a directory. + + Args: + src: The directory to copy. + dst: Where to copy the directory to. + """ + for src_dir, _, files in fileio.walk(src): + src_dir_ = str(src_dir) + dst_dir = str(os.path.join(dst, os.path.relpath(src_dir_, src))) + fileio.makedirs(dst_dir) + + for file in files: + file_ = str(file) + src_file = os.path.join(src_dir_, file_) + dst_file = os.path.join(dst_dir, file_) + fileio.copy(src_file, dst_file) diff --git a/src/zenml/artifacts/utils.py b/src/zenml/artifacts/utils.py index 4e2ade01336..33ccdd3d060 100644 --- a/src/zenml/artifacts/utils.py +++ b/src/zenml/artifacts/utils.py @@ -19,15 +19,29 @@ import time import zipfile from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Type, + Union, + cast, +) from uuid import UUID, uuid4 +from zenml.artifacts.load_directory_materializer import ( + PreexistingDataMaterializer, +) from zenml.client import Client from zenml.constants import ( MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION, MODEL_METADATA_YAML_FILE_NAME, ) from zenml.enums import ( + ArtifactType, ExecutionStatus, MetadataResourceTypes, StackComponentType, @@ -50,6 +64,7 @@ StepRunResponse, StepRunUpdate, ) +from zenml.models.v2.core.artifact import ArtifactResponse from zenml.stack import StackComponent from zenml.steps.step_context import get_step_context from zenml.utils import source_utils @@ -111,10 +126,6 @@ def save_artifact( Returns: The saved artifact response. - - Raises: - RuntimeError: If artifact URI already exists. - EntityExistsError: If artifact version already exists. 
""" from zenml.materializers.materializer_registry import ( materializer_registry, @@ -123,24 +134,11 @@ def save_artifact( client = Client() - # Get or create the artifact - try: - artifact = client.list_artifacts(name=name)[0] - if artifact.has_custom_name != has_custom_name: - client.update_artifact( - name_id_or_prefix=artifact.id, has_custom_name=has_custom_name - ) - except IndexError: - try: - artifact = client.zen_store.create_artifact( - ArtifactRequest( - name=name, - has_custom_name=has_custom_name, - tags=tags, - ) - ) - except EntityExistsError: - artifact = client.list_artifacts(name=name)[0] + artifact = _get_or_create_artifact( + name=name, + has_custom_name=has_custom_name, + tags=tags, + ) # Get the current artifact store artifact_store = client.active_stack.artifact_store @@ -151,16 +149,14 @@ def save_artifact( if not uri.startswith(artifact_store.path): uri = os.path.join(artifact_store.path, uri) - if manual_save and artifact_store.exists(uri): + if manual_save: # This check is only necessary for manual saves as we already check # it when creating the directory for step output artifacts - other_artifacts = client.list_artifact_versions(uri=uri, size=1) - if other_artifacts and (other_artifact := other_artifacts[0]): - raise RuntimeError( - f"Cannot save new artifact {name} version to URI " - f"{uri} because the URI is already used by artifact " - f"{other_artifact.name} (version {other_artifact.version})." - ) + _check_if_artifact_with_given_uri_already_registered( + artifact_store=artifact_store, + uri=uri, + name=name, + ) artifact_store.makedirs(uri) # Find and initialize the right materializer class @@ -211,7 +207,9 @@ def save_artifact( ) # Create the artifact version - def _create_version() -> Optional[ArtifactVersionResponse]: + def _create_version( + version: Union[int, str], + ) -> Optional[ArtifactVersionResponse]: artifact_version = ArtifactVersionRequest( artifact_id=artifact.id, version=version, @@ -233,37 +231,12 @@ def _create_version() -> Optional[ArtifactVersionResponse]: except EntityExistsError: return None - response = None - if not version: - retries_made = 0 - for i in range(MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION): - # Get new artifact version - version = _get_new_artifact_version(name) - if response := _create_version(): - break - # smoothed exponential back-off, it will go as 0.2, 0.3, - # 0.45, 0.68, 1.01, 1.52, 2.28, 3.42, 5.13, 7.69, ... - sleep = 0.2 * 1.5**i - logger.debug( - f"Failed to create artifact version `{version}` for " - f"artifact `{name}`. Retrying in {sleep}..." - ) - time.sleep(sleep) - retries_made += 1 - if not response: - raise EntityExistsError( - f"Failed to create new artifact version for artifact " - f"`{name}`. Retried {retries_made} times. " - "This could be driven by exceptionally high concurrency of " - "pipeline runs. Please, reach out to us on ZenML Slack for support." - ) - else: - response = _create_version() - if not response: - raise EntityExistsError( - f"Failed to create artifact version `{version}` for artifact " - f"`{name}`. Given version already exists." 
- ) + response = _create_artifact_version_with_retries( + name=name, + version=version, + create_version_fn=_create_version, + ) + if artifact_metadata: client.create_run_metadata( metadata=artifact_metadata, @@ -272,29 +245,112 @@ def _create_version() -> Optional[ArtifactVersionResponse]: ) if manual_save: + _link_artifact_version_to_the_step_and_model( + response=response, + is_model_artifact=is_model_artifact, + is_deployment_artifact=is_deployment_artifact, + ) + + return response + + +def register_artifact( + folder_or_file_uri: str, + name: str, + version: Optional[Union[int, str]] = None, + tags: Optional[List[str]] = None, + has_custom_name: bool = True, + is_model_artifact: bool = False, + is_deployment_artifact: bool = False, + artifact_metadata: Dict[str, "MetadataType"] = {}, +) -> "ArtifactVersionResponse": + """Register existing data stored in the artifact store as a ZenML Artifact. + + Args: + folder_or_file_uri: The full URI within the artifact store to the folder + or to the file. + name: The name of the artifact. + version: The version of the artifact. If not provided, a new + auto-incremented version will be used. + tags: Tags to associate with the artifact. + has_custom_name: If the artifact name is custom and should be listed in + the dashboard "Artifacts" tab. + is_model_artifact: If the artifact is a model artifact. + is_deployment_artifact: If the artifact is a deployment artifact. + artifact_metadata: Metadata dictionary to attach to the artifact version. + + Returns: + The saved artifact response. + + Raises: + FileNotFoundError: If the folder URI is outside of the artifact store + bounds. + """ + client = Client() + + # Get the current artifact store + artifact_store = client.active_stack.artifact_store + + if not folder_or_file_uri.startswith(artifact_store.path): + raise FileNotFoundError( + f"Folder `{folder_or_file_uri}` is outside of " + f"artifact store bounds `{artifact_store.path}`" + ) + + _check_if_artifact_with_given_uri_already_registered( + artifact_store=artifact_store, + uri=folder_or_file_uri, + name=name, + ) + + artifact = _get_or_create_artifact( + name=name, + has_custom_name=has_custom_name, + tags=tags, + ) + + # Create the artifact version + def _create_version( + version: Union[int, str], + ) -> Optional[ArtifactVersionResponse]: + artifact_version = ArtifactVersionRequest( + artifact_id=artifact.id, + version=version, + tags=tags, + type=ArtifactType.DATA, + uri=folder_or_file_uri, + materializer=source_utils.resolve(PreexistingDataMaterializer), + data_type=source_utils.resolve(Path), + user=Client().active_user.id, + workspace=Client().active_workspace.id, + artifact_store_id=artifact_store.id, + has_custom_name=has_custom_name, + ) try: - error_message = "step run" - step_context = get_step_context() - step_run = step_context.step_run - client.zen_store.update_run_step( - step_run_id=step_run.id, - step_run_update=StepRunUpdate( - saved_artifact_versions={name: response.id} - ), + return client.zen_store.create_artifact_version( + artifact_version=artifact_version ) - error_message = "model" - model = step_context.model - if model: - from zenml.model.utils import link_artifact_to_model - - link_artifact_to_model( - artifact_version_id=response.id, - model=model, - is_model_artifact=is_model_artifact, - is_deployment_artifact=is_deployment_artifact, - ) - except (RuntimeError, StepContextError): - logger.debug(f"Unable to link saved artifact to {error_message}.") + except EntityExistsError: + return None + + response = 
_create_artifact_version_with_retries( + name=name, + version=version, + create_version_fn=_create_version, + ) + + if artifact_metadata: + client.create_run_metadata( + metadata=artifact_metadata, + resource_id=response.id, + resource_type=MetadataResourceTypes.ARTIFACT_VERSION, + ) + + _link_artifact_version_to_the_step_and_model( + response=response, + is_model_artifact=is_model_artifact, + is_deployment_artifact=is_deployment_artifact, + ) return response @@ -577,6 +633,177 @@ def get_artifacts_versions_of_pipeline_run( # ------------------------- +def _check_if_artifact_with_given_uri_already_registered( + artifact_store: "BaseArtifactStore", + uri: str, + name: str, +) -> None: + """Check if the given artifact store already contains an artifact with the given URI. + + Args: + artifact_store: The artifact store to check. + uri: The uri of the artifact. + name: The name of the artifact. + + Raises: + RuntimeError: If the artifact store already contains an artifact with + the given URI. + """ + if artifact_store.exists(uri): + # This check is only necessary for manual saves as we already check + # it when creating the directory for step output artifacts + other_artifacts = Client().list_artifact_versions(uri=uri, size=1) + if other_artifacts and (other_artifact := other_artifacts[0]): + raise RuntimeError( + f"Cannot create new artifact {name} version with URI " + f"{uri} because the URI is already used by artifact " + f"{other_artifact.name} (version {other_artifact.version})." + ) + + +def _get_or_create_artifact( + name: str, has_custom_name: bool, tags: Optional[List[str]] = None +) -> ArtifactResponse: + """Get or create an artifact with the given name. + + Args: + name: The name of the artifact. + has_custom_name: If the artifact name is custom and should be listed in + the dashboard "Artifacts" tab. + tags: Tags to associate with the artifact. + + Returns: + The artifact. + """ + client = Client() + # Get or create the artifact + try: + artifact = client.list_artifacts(name=name)[0] + if artifact.has_custom_name != has_custom_name: + client.update_artifact( + name_id_or_prefix=artifact.id, has_custom_name=has_custom_name + ) + except IndexError: + try: + artifact = client.zen_store.create_artifact( + ArtifactRequest( + name=name, + has_custom_name=has_custom_name, + tags=tags, + ) + ) + except EntityExistsError: + artifact = client.list_artifacts(name=name)[0] + return artifact + + +def _create_artifact_version_with_retries( + name: str, + version: Optional[Union[int, str]], + create_version_fn: Callable[ + [ + Union[int, str], + ], + Optional[ArtifactVersionResponse], + ], +) -> ArtifactVersionResponse: + """Create an artifact version with some retries. + + This function will retry the creation of an artifact version up to + MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION times if it fails. + It can fail in high-concurrency environments. + + Args: + name: The name of the artifact. + version: The version of the artifact. If not provided, a new + auto-incremented version will be used. + create_version_fn: The function to create the artifact version. + + Returns: + The created artifact version. + + Raises: + EntityExistsError: If the artifact version could not be created + after MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION attempts due + to collisions. 
+
+    """
+    response = None
+    if not version:
+        retries_made = 0
+        for i in range(MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION):
+            # Get new artifact version
+            version = _get_new_artifact_version(name)
+            if response := create_version_fn(version):
+                break
+            # smoothed exponential back-off, it will go as 0.2, 0.3,
+            # 0.45, 0.68, 1.01, 1.52, 2.28, 3.42, 5.13, 7.69, ...
+            sleep = 0.2 * 1.5**i
+            logger.debug(
+                f"Failed to create artifact version `{version}` for "
+                f"artifact `{name}`. Retrying in {sleep}..."
+            )
+            time.sleep(sleep)
+            retries_made += 1
+        if not response:
+            raise EntityExistsError(
+                f"Failed to create new artifact version for artifact "
+                f"`{name}`. Retried {retries_made} times. "
+                "This could be driven by exceptionally high concurrency of "
+                "pipeline runs. Please, reach out to us on ZenML Slack for support."
+            )
+    else:
+        response = create_version_fn(version)
+        if not response:
+            raise EntityExistsError(
+                f"Failed to create artifact version `{version}` for artifact "
+                f"`{name}`. Given version already exists."
+            )
+    return response
+
+
+def _link_artifact_version_to_the_step_and_model(
+    response: ArtifactVersionResponse,
+    is_model_artifact: bool,
+    is_deployment_artifact: bool,
+) -> None:
+    """Link an artifact version to the step and its context model.
+
+    This function links the artifact version to:
+    - the step run
+    - the model version from the step context
+
+    Args:
+        response: The artifact version to link.
+        is_model_artifact: Whether the artifact is a model artifact.
+        is_deployment_artifact: Whether the artifact is a deployment artifact.
+    """
+    client = Client()
+    try:
+        error_message = "step run"
+        step_context = get_step_context()
+        step_run = step_context.step_run
+        client.zen_store.update_run_step(
+            step_run_id=step_run.id,
+            step_run_update=StepRunUpdate(
+                saved_artifact_versions={response.artifact.name: response.id}
+            ),
+        )
+        error_message = "model"
+        model = step_context.model
+        if model:
+            from zenml.model.utils import link_artifact_to_model
+
+            link_artifact_to_model(
+                artifact_version_id=response.id,
+                model=model,
+                is_model_artifact=is_model_artifact,
+                is_deployment_artifact=is_deployment_artifact,
+            )
+    except (RuntimeError, StepContextError):
+        logger.debug(f"Unable to link saved artifact to {error_message}.")
+
+
 def _load_artifact_from_uri(
     materializer: Union["Source", str],
     data_type: Union["Source", str],
diff --git a/tests/integration/functional/artifacts/test_utils.py b/tests/integration/functional/artifacts/test_utils.py
index 1606957d7b5..a084f935fbb 100644
--- a/tests/integration/functional/artifacts/test_utils.py
+++ b/tests/integration/functional/artifacts/test_utils.py
@@ -3,7 +3,9 @@
 import multiprocessing
 import os
 import shutil
+import tempfile
 import zipfile
+from pathlib import Path
 from typing import Optional, Tuple
 from unittest.mock import patch
 
@@ -17,6 +19,7 @@
     save_artifact,
     step,
 )
+from zenml.artifacts.utils import register_artifact
 from zenml.client import Client
 from zenml.models.v2.core.artifact import ArtifactResponse
 
@@ -369,3 +372,94 @@
     assert {av.version for av in avs} == {
         str(i) for i in range(1, process_count + 1)
     }
+
+
+def test_register_artifact(clean_client: Client):
+    """Tests that a folder can be linked as an artifact in a local setting."""
+
+    uri_prefix = os.path.join(
+        clean_client.active_stack.artifact_store.path, "test_folder"
+    )
+    os.makedirs(uri_prefix, exist_ok=True)
+    with open(os.path.join(uri_prefix, "test.txt"), "w") as f:
+        f.write("test")
+
+    register_artifact(folder_or_file_uri=uri_prefix, name="test_folder")
+
+    artifact = clean_client.get_artifact_version(
+        name_id_or_prefix="test_folder", version=1
+    )
+    assert artifact
+    assert artifact.uri == uri_prefix
+
+    loaded_dir = artifact.load()
+    assert isinstance(loaded_dir, Path)
+
+    with open(loaded_dir / "test.txt", "r") as f:
+        assert f.read() == "test"
+
+
+def test_register_artifact_out_of_bounds(clean_client: Client):
+    """Tests that a folder cannot be linked as an artifact if out of bounds."""
+
+    uri_prefix = tempfile.mkdtemp()
+    try:
+        with pytest.raises(FileNotFoundError):
+            register_artifact(
+                folder_or_file_uri=uri_prefix, name="test_folder"
+            )
+    finally:
+        os.rmdir(uri_prefix)
+
+
+@step(enable_cache=False)
+def register_artifact_step_1() -> None:
+    # find out where to save some data
+    uri_prefix = os.path.join(
+        Client().active_stack.artifact_store.path, "test_folder"
+    )
+    os.makedirs(uri_prefix, exist_ok=True)
+    # generate data to validate in register_artifact_step_2
+    test_file = os.path.join(uri_prefix, "test.txt")
+    with open(test_file, "w") as f:
+        f.write("test")
+
+    register_artifact(folder_or_file_uri=uri_prefix, name="test_folder")
+
+    register_artifact(folder_or_file_uri=test_file, name="test_file")
+
+
+@step(enable_cache=False)
+def register_artifact_step_2(
+    inp_folder: Path,
+) -> None:
+    # the step should receive a path pointing to the folder
+    # from register_artifact_step_1
+    with open(inp_folder / "test.txt", "r") as f:
+        assert f.read() == "test"
+    # at the same time, the input artifact is no longer inside the
+    # artifact store, but in a temporary folder of the local file system
+    assert not str(inp_folder.absolute()).startswith(
+        Client().active_stack.artifact_store.path
+    )
+
+    file_artifact_path = Client().get_artifact_version("test_file").load()
+
+    with open(file_artifact_path, "r") as f:
+        assert f.read() == "test"
+
+
+def test_register_artifact_between_steps(clean_client: Client):
+    """Tests that a folder can be linked as an artifact and used in pipelines."""
+
+    @pipeline(enable_cache=False)
+    def register_artifact_pipeline():
+        register_artifact_step_1()
+        register_artifact_step_2(
+            clean_client.get_artifact_version(
+                name_id_or_prefix="test_folder", version=1
+            ),
+            after=["register_artifact_step_1"],
+        )
+
+    register_artifact_pipeline()