From ae3aead8c5704100bc4579cddf0dfd6f1eba4293 Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Tue, 20 Feb 2024 09:44:48 +0100 Subject: [PATCH] Change dask_client to general setup method (#861) This PR changes the `dask_client` method introduced in #852 into a general `setup` method. This method differs from `__init__` in that it allows users to return a state, which is passed into the `teardown` method by Fondant. This is necessary since the Dask client is not pickleable, and setting it as an instance attribute leads to issues when executing the `transform` method across processes (as is the case in the `PandasTransformComponent`). --- docs/components/component_spec.md | 62 ++++++++++++------- src/fondant/component/component.py | 28 ++++++--- src/fondant/component/data_io.py | 5 ++ src/fondant/component/executor.py | 4 +- .../components/embed_images/src/main.py | 4 +- .../index_aws_opensearch/src/main.py | 2 +- .../components/index_qdrant/src/main.py | 2 +- .../components/index_weaviate/src/main.py | 2 +- tests/component/test_component.py | 39 ++++++------ 9 files changed, 92 insertions(+), 56 deletions(-) diff --git a/docs/components/component_spec.md b/docs/components/component_spec.md index 02ad462e..1a7bd811 100644 --- a/docs/components/component_spec.md +++ b/docs/components/component_spec.md @@ -305,35 +305,53 @@ class ExampleComponent(PandasTransformComponent): Afterwards, we pass all keyword arguments to the `__init__()` method of the component. +You can also use the `setup()` and `teardown()` methods to do setup and cleanup of component +configuration. -You can also use the a `teardown()` method to perform any cleanup after the component has been executed. -This is a good place to close any open connections or files. +The `setup()` method is useful to set up any configuration that is not directly used by your +component, but by some of the underlying dependencies such as `Dask`. The advantage compared to +`__init__()` is that you can return a state which will be injected into the `teardown()` method, +so you don't need to store everything as an instance attribute, which can be a problem for +unpickleable objects when running in parallel across processes. + +You can use the `teardown()` method to clean up both instance variables from `__init__()` and +state from `setup()`. Eg. closing open connections or files. ```python +import typing as t + import pandas as pd +from dask.distributed import Client +from dask_cuda import LocalCUDACluster from fondant.component import PandasTransformComponent -from my_library import Client - - def __init__(self, *, client_url) -> None: - """ - Args: - x_argument: An argument passed to the component - """ - # Initialize your component here based on the arguments - self.client = Client(client_url) +from my_library import HTTPClient + +class MyComponent(PandasTransformComponent): + + def __init__(self, *, client_url) -> None: + """ + Args: + client_url: An argument passed to the component + """ + # Initialize your component here based on the arguments + self.http_client = HTTPClient(client_url) + + def setup(self) -> t.Any: + return Client(LocalCUDACluster) - def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: - """Implement your custom logic in this single method + def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: + """Implement your custom logic in this single method - Args: - dataframe: A Pandas dataframe containing the data + Args: + dataframe: A Pandas dataframe containing the data - Returns: - A pandas dataframe containing the transformed data - """ + Returns: + A pandas dataframe containing the transformed data + """ - def teardown(self): - """Perform any cleanup after the component has been executed - """ - self.client.shutdown() + def teardown(self, dask_client): + """Perform any cleanup after the component has been executed + """ + self.http_client.shutdown() + dask_client.shutdown() ``` \ No newline at end of file diff --git a/src/fondant/component/component.py b/src/fondant/component/component.py index a1036d8a..9da35dea 100644 --- a/src/fondant/component/component.py +++ b/src/fondant/component/component.py @@ -22,8 +22,21 @@ def __init__(self): self.consumes = None self.produces = None - def teardown(self) -> None: - """Method called after the component has been executed.""" + def setup(self) -> t.Any: + """Method to do additional component setup. This method can return a state (any object), + which is passed into the `teardown` method. + + There's two reasons to separate this from `__init__`: + - It can be overwritten separately + - The Fondant executor handles the state, which is a good alternative for instance + attributes if the state is only needed in `setup` / `teardown`, since instance + attributes need to be pickleable when executing a component method across processes. + """ + + def teardown(self, state: t.Any) -> None: + """Method called after the component has been executed. The Fondant executor injects the + state returned by the `setup` method. + """ class DaskComponent(BaseComponent): @@ -32,16 +45,10 @@ class DaskComponent(BaseComponent): def __init__(self, **kwargs): super().__init__() - # don't assume every object is a string - dask.config.set({"dataframe.convert-string": False}) + def setup(self) -> t.Any: # worker.daemon is set to false because creating a worker process in daemon # mode is not possible in our docker container setup. dask.config.set({"distributed.worker.daemon": False}) - - self.dask_client() - - def dask_client(self) -> Client: - """Initialize the dask client to use for this component.""" cluster = LocalCluster( processes=True, n_workers=os.cpu_count(), @@ -49,6 +56,9 @@ def dask_client(self) -> Client: ) return Client(cluster) + def teardown(self, client: t.Any) -> None: + return client.shutdown() + class DaskLoadComponent(DaskComponent): """Component that loads data and returns a Dask DataFrame.""" diff --git a/src/fondant/component/data_io.py b/src/fondant/component/data_io.py index c3b7bfd0..da4419b1 100644 --- a/src/fondant/component/data_io.py +++ b/src/fondant/component/data_io.py @@ -3,6 +3,7 @@ import typing as t from collections import defaultdict +import dask import dask.dataframe as dd from dask.diagnostics import ProgressBar from dask.distributed import Client @@ -30,6 +31,10 @@ def __init__( input_partition_rows: t.Optional[int] = None, ): super().__init__(manifest=manifest, operation_spec=operation_spec) + # Don't assume every object is a string + # https://docs.dask.org/en/stable/changelog.html#v2023-7-1 + dask.config.set({"dataframe.convert-string": False}) + self.input_partition_rows = input_partition_rows def partition_loaded_dataframe(self, dataframe: dd.DataFrame) -> dd.DataFrame: diff --git a/src/fondant/component/executor.py b/src/fondant/component/executor.py index 9fee0990..9db863ce 100644 --- a/src/fondant/component/executor.py +++ b/src/fondant/component/executor.py @@ -287,6 +287,8 @@ def _run_execution( component.consumes = self.operation_spec.inner_consumes component.produces = self.operation_spec.inner_produces + state = component.setup() + output_df = self._execute_component( component, manifest=input_manifest, @@ -298,7 +300,7 @@ def _run_execution( ) self._write_data(dataframe=output_df, manifest=output_manifest) - component.teardown() + component.teardown(state) return output_manifest diff --git a/src/fondant/components/embed_images/src/main.py b/src/fondant/components/embed_images/src/main.py index 1fd3f8e4..e2afb70c 100644 --- a/src/fondant/components/embed_images/src/main.py +++ b/src/fondant/components/embed_images/src/main.py @@ -40,12 +40,12 @@ def __init__( super().__init__() - def dask_client(self) -> Client: + def setup(self) -> Client: if self.device == "cuda": cluster = LocalCUDACluster() return Client(cluster) - return super().dask_client() + return super().setup() def process_image_batch(self, images: np.ndarray) -> t.List[torch.Tensor]: """ diff --git a/src/fondant/components/index_aws_opensearch/src/main.py b/src/fondant/components/index_aws_opensearch/src/main.py index 9650121e..33f64033 100644 --- a/src/fondant/components/index_aws_opensearch/src/main.py +++ b/src/fondant/components/index_aws_opensearch/src/main.py @@ -38,7 +38,7 @@ def __init__( ) self.create_index(index_body) - def teardown(self) -> None: + def teardown(self, _) -> None: self.client.close() def create_index(self, index_body: Dict[str, Any]): diff --git a/src/fondant/components/index_qdrant/src/main.py b/src/fondant/components/index_qdrant/src/main.py index 227e0dae..cb583bbd 100644 --- a/src/fondant/components/index_qdrant/src/main.py +++ b/src/fondant/components/index_qdrant/src/main.py @@ -47,7 +47,7 @@ def __init__( self.batch_size = batch_size self.parallelism = parallelism - def teardown(self) -> None: + def teardown(self, _) -> None: self.client.close() def write(self, dataframe: dd.DataFrame) -> None: diff --git a/src/fondant/components/index_weaviate/src/main.py b/src/fondant/components/index_weaviate/src/main.py index c30dc7bc..2f4ef96e 100644 --- a/src/fondant/components/index_weaviate/src/main.py +++ b/src/fondant/components/index_weaviate/src/main.py @@ -85,7 +85,7 @@ def create_class_schema(self) -> t.Dict[str, t.Any]: return class_schema - def teardown(self) -> None: + def teardown(self, _) -> None: del self.client def write(self, dataframe: dd.DataFrame) -> None: diff --git a/tests/component/test_component.py b/tests/component/test_component.py index 18917a4d..75d4dbe6 100644 --- a/tests/component/test_component.py +++ b/tests/component/test_component.py @@ -303,7 +303,7 @@ def load(self): @pytest.mark.usefixtures("_patched_data_writing") -def test_teardown_method(metadata): +def test_setup_teardown_methods(metadata): # Mock CLI arguments load operation_spec = OperationSpec( ComponentSpec.from_file(components_path / "component.yaml"), @@ -335,13 +335,13 @@ def shutdown(self): if self.is_connected: self.is_connected = False - client = MockClient() - class MyLoadComponent(DaskLoadComponent): def __init__(self, *, flag, value): self.flag = flag self.value = value - self.client = client + + def setup(self): + return MockClient() def load(self): data = { @@ -350,19 +350,23 @@ def load(self): } return dd.DataFrame.from_dict(data, npartitions=N_PARTITIONS) - def teardown(self) -> None: - self.client.shutdown() + def teardown(self, client) -> None: + client.shutdown() executor_factory = ExecutorFactory(MyLoadComponent) executor = executor_factory.get_executor() assert executor.input_partition_rows is None + setup = patch_method_class(MyLoadComponent.setup) teardown = patch_method_class(MyLoadComponent.teardown) - assert client.is_connected is True - with mock.patch.object(MyLoadComponent, "teardown", teardown): + with mock.patch.object(MyLoadComponent, "setup", setup), mock.patch.object( + MyLoadComponent, + "teardown", + teardown, + ): executor.execute(MyLoadComponent) + setup.mock.assert_called_once() teardown.mock.assert_called_once() - assert client.is_connected is False @pytest.mark.usefixtures("_patched_data_loading", "_patched_data_writing") @@ -394,6 +398,7 @@ def test_dask_transform_component(metadata): class MyDaskComponent(DaskTransformComponent): def __init__(self, *, flag, value): + super().__init__() self.flag = flag self.value = value @@ -442,27 +447,23 @@ def test_pandas_transform_component(metadata): "False", ] + init_called = 0 + class MyPandasComponent(PandasTransformComponent): def __init__(self, *, flag, value): assert flag == "success" assert value == 1 + nonlocal init_called + init_called += 1 def transform(self, dataframe): assert isinstance(dataframe, pd.DataFrame) return dataframe.rename(columns={"images": "embeddings"}) - init = patch_method_class(MyPandasComponent.__init__) - transform = patch_method_class(MyPandasComponent.transform) executor_factory = ExecutorFactory(MyPandasComponent) executor = executor_factory.get_executor() - with mock.patch.object(MyPandasComponent, "__init__", init), mock.patch.object( - MyPandasComponent, - "transform", - transform, - ): - executor.execute(MyPandasComponent) - init.mock.assert_called_once() - assert transform.mock.call_count == N_PARTITIONS + executor.execute(MyPandasComponent) + assert init_called == 1 def test_wrap_transform():