From e67715775802e6b98133a7bab4d7ba85010d59c5 Mon Sep 17 00:00:00 2001 From: Philippe Moussalli Date: Thu, 11 Jan 2024 10:41:48 +0100 Subject: [PATCH] Add teardown method (#767) PR that adds a teardown method to every component. Useful for shutting down database connections and clients instead of it happening abruptly after the container is shutdown. Related to https://github.com/ml6team/fondant-use-cases/issues/59 --- docs/components/component_spec.md | 35 +++++++++++++++++ src/fondant/component/component.py | 3 ++ src/fondant/component/executor.py | 4 +- tests/component/test_component.py | 63 ++++++++++++++++++++++++++++++ tests/examples/__init__.py | 0 5 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 tests/examples/__init__.py diff --git a/docs/components/component_spec.md b/docs/components/component_spec.md index 9299305a2..65e121c18 100644 --- a/docs/components/component_spec.md +++ b/docs/components/component_spec.md @@ -301,4 +301,39 @@ class ExampleComponent(PandasTransformComponent): Returns: A pandas dataframe containing the transformed data """ +``` + +Afterwards, we pass all keyword arguments to the `__init__()` method of the component. + + +You can also use the a `teardown()` method to perform any cleanup after the component has been executed. +This is a good place to close any open connections or files. + +```python +import pandas as pd +from fondant.component import PandasTransformComponent +from my_library import Client + + def __init__(self, *, client_url, **kwargs) -> None: + """ + Args: + x_argument: An argument passed to the component + """ + # Initialize your component here based on the arguments + self.client = Client(client_url) + + def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: + """Implement your custom logic in this single method + + Args: + dataframe: A Pandas dataframe containing the data + + Returns: + A pandas dataframe containing the transformed data + """ + + def teardown(self): + """Perform any cleanup after the component has been executed + """ + self.client.shutdown() ``` \ No newline at end of file diff --git a/src/fondant/component/component.py b/src/fondant/component/component.py index 82d539b84..7d9a86113 100644 --- a/src/fondant/component/component.py +++ b/src/fondant/component/component.py @@ -26,6 +26,9 @@ def __init__( ): pass + def teardown(self) -> None: + """Method called after the component has been executed.""" + class DaskLoadComponent(BaseComponent): """Component that loads data and returns a Dask DataFrame.""" diff --git a/src/fondant/component/executor.py b/src/fondant/component/executor.py index 3026eb625..db3140703 100644 --- a/src/fondant/component/executor.py +++ b/src/fondant/component/executor.py @@ -335,7 +335,7 @@ def _run_execution( input_manifest: Manifest, ) -> Manifest: logging.info("Executing component") - component = component_cls( + component: Component = component_cls( consumes=self.operation_spec.inner_consumes, produces=self.operation_spec.inner_produces, **self.user_arguments, @@ -350,6 +350,8 @@ def _run_execution( ) self._write_data(dataframe=output_df, manifest=output_manifest) + component.teardown() + return output_manifest def execute(self, component_cls: t.Type[Component]) -> None: diff --git a/tests/component/test_component.py b/tests/component/test_component.py index 397ab210e..191e6e329 100644 --- a/tests/component/test_component.py +++ b/tests/component/test_component.py @@ -298,6 +298,69 @@ def load(self): load.mock.assert_called_once() +@pytest.mark.usefixtures("_patched_data_writing") +def test_teardown_method(metadata): + # Mock CLI arguments load + operation_spec = OperationSpec( + ComponentSpec.from_file(components_path / "component.yaml"), + ) + + sys.argv = [ + "", + "--metadata", + metadata.to_json(), + "--flag", + "success", + "--value", + "1", + "--output_manifest_path", + str(components_path / "output_manifest.json"), + "--operation_spec", + operation_spec.to_json(), + "--cache", + "False", + "--produces", + "{}", + ] + + class MockClient: + def __init__(self): + self.is_connected = True + + def shutdown(self): + if self.is_connected: + self.is_connected = False + + client = MockClient() + + class MyLoadComponent(DaskLoadComponent): + def __init__(self, *, flag, value, **kwargs): + self.flag = flag + self.value = value + self.client = client + + def load(self): + data = { + "id": [0, 1], + "captions_data": ["hello world", "this is another caption"], + } + return dd.DataFrame.from_dict(data, npartitions=N_PARTITIONS) + + def teardown(self) -> None: + self.client.shutdown() + + executor_factory = ExecutorFactory(MyLoadComponent) + executor = executor_factory.get_executor() + assert executor.input_partition_rows is None + + teardown = patch_method_class(MyLoadComponent.teardown) + assert client.is_connected is True + with mock.patch.object(MyLoadComponent, "teardown", teardown): + executor.execute(MyLoadComponent) + teardown.mock.assert_called_once() + assert client.is_connected is False + + @pytest.mark.usefixtures("_patched_data_loading", "_patched_data_writing") def test_dask_transform_component(metadata): operation_spec = OperationSpec( diff --git a/tests/examples/__init__.py b/tests/examples/__init__.py new file mode 100644 index 000000000..e69de29bb