Skip to content

Commit

Permalink
Enable write components to cache (#814)
Browse files: browse the repository at this point in the history.
Fixes #813
  • Loading branch information
PhilippeMoussalli authored Jan 30, 2024
1 parent ea6462e commit 4400963
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 33 deletions.
79 changes: 50 additions & 29 deletions src/fondant/component/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,15 +264,13 @@ def _write_data(

data_writer.write_dataframe(dataframe, self.client)

def _get_cached_manifest(self) -> t.Union[Manifest, None]:
def _get_cache_reference_content(self) -> t.Union[str, None]:
"""
Find and return the matching execution's Manifest for the component, if it exists.
This function searches for previous execution manifests that match the component's metadata.
Get the content of the cache reference file. This file contains the path to the cached
manifest or empty string if the component is cached without producing any manifest.
Returns:
The Manifest object representing the most recent matching execution,
or None if no matching execution is found.
The content of the cache reference file.
"""
manifest_reference_path = (
f"{self.metadata.base_path}/{self.metadata.pipeline_name}/cache/"
Expand All @@ -285,13 +283,7 @@ def _get_cached_manifest(self) -> t.Union[Manifest, None]:
mode="rt",
encoding="utf-8",
) as file_:
cached_manifest_path = file_.read()
manifest = Manifest.from_file(cached_manifest_path)
logger.info(
f"Matching execution detected for component. The last execution of the"
f" component originated from `{manifest.run_id}`.",
)
return manifest
return file_.read()

except FileNotFoundError:
logger.info("No matching execution for component detected")
Expand Down Expand Up @@ -345,6 +337,7 @@ def _run_execution(
component,
manifest=input_manifest,
)

output_manifest = input_manifest.evolve(
operation_spec=self.operation_spec,
run_id=self.metadata.run_id,
Expand All @@ -363,50 +356,75 @@ def execute(self, component_cls: t.Type[Component]) -> None:
component_cls: The class of the component to execute
"""
input_manifest = self._load_or_create_manifest()
base_path = input_manifest.base_path
pipeline_name = input_manifest.pipeline_name

if self.cache and self._is_previous_cached(input_manifest):
output_manifest = self._get_cached_manifest()
if output_manifest is not None:
cache_reference_content = self._get_cache_reference_content()

if cache_reference_content is not None:
logger.info("Skipping component execution")

if cache_reference_content:
output_manifest = Manifest.from_file(cache_reference_content)

logger.info(
f"Matching execution detected for component. The last execution of the"
f" component originated from `{output_manifest.run_id}`.",
)
else:
logger.info("Component is cached without producing a manifest")
output_manifest = None
else:
output_manifest = self._run_execution(component_cls, input_manifest)

else:
logger.info("Caching disabled for the component")
output_manifest = self._run_execution(component_cls, input_manifest)

self.upload_manifest(output_manifest, save_path=self.output_manifest_path)
if output_manifest:
self.upload_manifest(output_manifest, save_path=self.output_manifest_path)

self._upload_cache_reference_content(
base_path=base_path,
pipeline_name=pipeline_name,
)

def _upload_cache_key(
def _upload_cache_reference_content(
self,
manifest: Manifest,
manifest_save_path: t.Union[str, Path],
base_path: str,
pipeline_name: str,
):
"""
Write the cache key containing the reference to the location of the written manifest..
Write the cache key containing the reference to the location of the written manifest.
This function creates a file with the format "<cache_key>.txt" at the specified
'manifest_save_path' to store the manifest location for future retrieval of
cached component executions.
Args:
manifest: The reference manifest.
manifest_save_path (str): The path where the manifest is saved.
base_path: The base path of the pipeline.
pipeline_name: The name of the pipeline.
"""
manifest_reference_path = (
f"{manifest.base_path}/{manifest.pipeline_name}/cache/"
f"{self.metadata.cache_key}.txt"
cache_reference_path = (
f"{base_path}/{pipeline_name}/cache/{self.metadata.cache_key}.txt"
)

logger.info(f"Writing cache key to {manifest_reference_path}")
logger.info(
f"Writing cache key with manifest reference to {cache_reference_path}",
)

with fs_open(
manifest_reference_path,
cache_reference_path,
mode="wt",
encoding="utf-8",
auto_mkdir=True,
) as file_:
file_.write(str(manifest_save_path))
file_.write(self.cache_reference_content)

@property
def cache_reference_content(self) -> str:
return str(self.output_manifest_path)

def upload_manifest(self, manifest: Manifest, save_path: t.Union[str, Path]):
"""
Expand All @@ -420,7 +438,6 @@ def upload_manifest(self, manifest: Manifest, save_path: t.Union[str, Path]):
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
manifest.to_file(save_path)
logger.info(f"Saving output manifest to {save_path}")
self._upload_cache_key(manifest=manifest, manifest_save_path=save_path)


class DaskLoadExecutor(Executor[DaskLoadComponent]):
Expand Down Expand Up @@ -602,6 +619,10 @@ def _write_data(self, dataframe: dd.DataFrame, *, manifest: Manifest):
def upload_manifest(self, manifest: Manifest, save_path: t.Union[str, Path]):
pass

@property
def cache_reference_content(self) -> str:
return ""


class ExecutorFactory:
def __init__(self, component: t.Type[Component]):
Expand Down
13 changes: 9 additions & 4 deletions tests/component/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ def mocked_write_dataframe(self, dataframe, dask_client=None):
"upload_manifest",
lambda self, manifest, save_path: None,
)
monkeypatch.setattr(
Executor,
"_upload_cache_reference_content",
lambda self, base_path, pipeline_name: None,
)


def patch_method_class(method):
Expand Down Expand Up @@ -195,9 +200,9 @@ def _process_dataset(self, manifest: Manifest) -> t.Union[None, dd.DataFrame]:
pass

executor = MyExecutor.from_args()
matching_execution_manifest = executor._get_cached_manifest()
cache_reference_content = executor._get_cache_reference_content()
# Check that the latest manifest is returned
assert matching_execution_manifest.run_id == "example_pipeline_2023"
assert Manifest.from_file(cache_reference_content).run_id == "example_pipeline_2023"
# Check that the previous component is cached due to similar run IDs
assert executor._is_previous_cached(Manifest.from_file(input_manifest_path)) is True

Expand Down Expand Up @@ -244,7 +249,7 @@ def _process_dataset(self, manifest: Manifest) -> t.Union[None, dd.DataFrame]:
pass

executor = MyExecutor.from_args()
matching_execution_manifest = executor._get_cached_manifest()
matching_execution_manifest = executor._get_cache_reference_content()
# Check that the latest manifest is returned
assert matching_execution_manifest is None

Expand Down Expand Up @@ -528,7 +533,7 @@ def transform(dataframe: pd.DataFrame) -> pd.DataFrame:
assert output_df.columns.tolist() == ["caption_text", "image_height"]


@pytest.mark.usefixtures("_patched_data_loading")
@pytest.mark.usefixtures("_patched_data_loading", "_patched_data_writing")
def test_write_component(metadata):
operation_spec = OperationSpec(
ComponentSpec.from_file(components_path / "component.yaml"),
Expand Down

0 comments on commit 4400963

Please sign in to comment.