Skip to content

Commit

Permalink
Change dask_client to general setup method (#861)
Browse files Browse the repository at this point in the history
This PR changes the `dask_client` method introduced in #852 into a
general `setup` method.

This method differs from `__init__` in that it allows users to return a
state, which is passed into the `teardown` method by Fondant. This is
necessary since the Dask client is not pickleable, and setting it as an
instance attribute leads to issues when executing the `transform` method
across processes (as is the case in the `PandasTransformComponent`).
  • Loading branch information
RobbeSneyders authored Feb 20, 2024
1 parent 0de3f5c commit ae3aead
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 56 deletions.
62 changes: 40 additions & 22 deletions docs/components/component_spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -305,35 +305,53 @@ class ExampleComponent(PandasTransformComponent):

Afterwards, we pass all keyword arguments to the `__init__()` method of the component.

You can also use the `setup()` and `teardown()` methods to do setup and cleanup of component
configuration.

You can also use the a `teardown()` method to perform any cleanup after the component has been executed.
This is a good place to close any open connections or files.
The `setup()` method is useful to set up any configuration that is not directly used by your
component, but by some of the underlying dependencies such as `Dask`. The advantage compared to
`__init__()` is that you can return a state which will be injected into the `teardown()` method,
so you don't need to store everything as an instance attribute, which can be a problem for
unpickleable objects when running in parallel across processes.

You can use the `teardown()` method to clean up both instance variables from `__init__()` and
state from `setup()`. Eg. closing open connections or files.

```python
import typing as t
import pandas as pd
from dask.distributed import Client
from dask_cuda import LocalCUDACluster
from fondant.component import PandasTransformComponent
from my_library import Client
def __init__(self, *, client_url) -> None:
"""
Args:
x_argument: An argument passed to the component
"""
# Initialize your component here based on the arguments
self.client = Client(client_url)
from my_library import HTTPClient
class MyComponent(PandasTransformComponent):
def __init__(self, *, client_url) -> None:
"""
Args:
client_url: An argument passed to the component
"""
# Initialize your component here based on the arguments
self.http_client = HTTPClient(client_url)
def setup(self) -> t.Any:
return Client(LocalCUDACluster)
def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
"""Implement your custom logic in this single method
def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
"""Implement your custom logic in this single method
Args:
dataframe: A Pandas dataframe containing the data
Args:
dataframe: A Pandas dataframe containing the data
Returns:
A pandas dataframe containing the transformed data
"""
Returns:
A pandas dataframe containing the transformed data
"""
def teardown(self):
"""Perform any cleanup after the component has been executed
"""
self.client.shutdown()
def teardown(self, dask_client):
"""Perform any cleanup after the component has been executed
"""
self.http_client.shutdown()
dask_client.shutdown()
```
28 changes: 19 additions & 9 deletions src/fondant/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,21 @@ def __init__(self):
self.consumes = None
self.produces = None

def teardown(self) -> None:
"""Method called after the component has been executed."""
def setup(self) -> t.Any:
"""Method to do additional component setup. This method can return a state (any object),
which is passed into the `teardown` method.
There's two reasons to separate this from `__init__`:
- It can be overwritten separately
- The Fondant executor handles the state, which is a good alternative for instance
attributes if the state is only needed in `setup` / `teardown`, since instance
attributes need to be pickleable when executing a component method across processes.
"""

def teardown(self, state: t.Any) -> None:
"""Method called after the component has been executed. The Fondant executor injects the
state returned by the `setup` method.
"""


class DaskComponent(BaseComponent):
Expand All @@ -32,23 +45,20 @@ class DaskComponent(BaseComponent):
def __init__(self, **kwargs):
super().__init__()

# don't assume every object is a string
dask.config.set({"dataframe.convert-string": False})
def setup(self) -> t.Any:
# worker.daemon is set to false because creating a worker process in daemon
# mode is not possible in our docker container setup.
dask.config.set({"distributed.worker.daemon": False})

self.dask_client()

def dask_client(self) -> Client:
"""Initialize the dask client to use for this component."""
cluster = LocalCluster(
processes=True,
n_workers=os.cpu_count(),
threads_per_worker=1,
)
return Client(cluster)

def teardown(self, client: t.Any) -> None:
return client.shutdown()


class DaskLoadComponent(DaskComponent):
"""Component that loads data and returns a Dask DataFrame."""
Expand Down
5 changes: 5 additions & 0 deletions src/fondant/component/data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import typing as t
from collections import defaultdict

import dask
import dask.dataframe as dd
from dask.diagnostics import ProgressBar
from dask.distributed import Client
Expand Down Expand Up @@ -30,6 +31,10 @@ def __init__(
input_partition_rows: t.Optional[int] = None,
):
super().__init__(manifest=manifest, operation_spec=operation_spec)
# Don't assume every object is a string
# https://docs.dask.org/en/stable/changelog.html#v2023-7-1
dask.config.set({"dataframe.convert-string": False})

self.input_partition_rows = input_partition_rows

def partition_loaded_dataframe(self, dataframe: dd.DataFrame) -> dd.DataFrame:
Expand Down
4 changes: 3 additions & 1 deletion src/fondant/component/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ def _run_execution(
component.consumes = self.operation_spec.inner_consumes
component.produces = self.operation_spec.inner_produces

state = component.setup()

output_df = self._execute_component(
component,
manifest=input_manifest,
Expand All @@ -298,7 +300,7 @@ def _run_execution(
)
self._write_data(dataframe=output_df, manifest=output_manifest)

component.teardown()
component.teardown(state)

return output_manifest

Expand Down
4 changes: 2 additions & 2 deletions src/fondant/components/embed_images/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ def __init__(

super().__init__()

def dask_client(self) -> Client:
def setup(self) -> Client:
if self.device == "cuda":
cluster = LocalCUDACluster()
return Client(cluster)

return super().dask_client()
return super().setup()

def process_image_batch(self, images: np.ndarray) -> t.List[torch.Tensor]:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/fondant/components/index_aws_opensearch/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
)
self.create_index(index_body)

def teardown(self) -> None:
def teardown(self, _) -> None:
self.client.close()

def create_index(self, index_body: Dict[str, Any]):
Expand Down
2 changes: 1 addition & 1 deletion src/fondant/components/index_qdrant/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
self.batch_size = batch_size
self.parallelism = parallelism

def teardown(self) -> None:
def teardown(self, _) -> None:
self.client.close()

def write(self, dataframe: dd.DataFrame) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/fondant/components/index_weaviate/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def create_class_schema(self) -> t.Dict[str, t.Any]:

return class_schema

def teardown(self) -> None:
def teardown(self, _) -> None:
del self.client

def write(self, dataframe: dd.DataFrame) -> None:
Expand Down
39 changes: 20 additions & 19 deletions tests/component/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def load(self):


@pytest.mark.usefixtures("_patched_data_writing")
def test_teardown_method(metadata):
def test_setup_teardown_methods(metadata):
# Mock CLI arguments load
operation_spec = OperationSpec(
ComponentSpec.from_file(components_path / "component.yaml"),
Expand Down Expand Up @@ -335,13 +335,13 @@ def shutdown(self):
if self.is_connected:
self.is_connected = False

client = MockClient()

class MyLoadComponent(DaskLoadComponent):
def __init__(self, *, flag, value):
self.flag = flag
self.value = value
self.client = client

def setup(self):
return MockClient()

def load(self):
data = {
Expand All @@ -350,19 +350,23 @@ def load(self):
}
return dd.DataFrame.from_dict(data, npartitions=N_PARTITIONS)

def teardown(self) -> None:
self.client.shutdown()
def teardown(self, client) -> None:
client.shutdown()

executor_factory = ExecutorFactory(MyLoadComponent)
executor = executor_factory.get_executor()
assert executor.input_partition_rows is None

setup = patch_method_class(MyLoadComponent.setup)
teardown = patch_method_class(MyLoadComponent.teardown)
assert client.is_connected is True
with mock.patch.object(MyLoadComponent, "teardown", teardown):
with mock.patch.object(MyLoadComponent, "setup", setup), mock.patch.object(
MyLoadComponent,
"teardown",
teardown,
):
executor.execute(MyLoadComponent)
setup.mock.assert_called_once()
teardown.mock.assert_called_once()
assert client.is_connected is False


@pytest.mark.usefixtures("_patched_data_loading", "_patched_data_writing")
Expand Down Expand Up @@ -394,6 +398,7 @@ def test_dask_transform_component(metadata):

class MyDaskComponent(DaskTransformComponent):
def __init__(self, *, flag, value):
super().__init__()
self.flag = flag
self.value = value

Expand Down Expand Up @@ -442,27 +447,23 @@ def test_pandas_transform_component(metadata):
"False",
]

init_called = 0

class MyPandasComponent(PandasTransformComponent):
def __init__(self, *, flag, value):
assert flag == "success"
assert value == 1
nonlocal init_called
init_called += 1

def transform(self, dataframe):
assert isinstance(dataframe, pd.DataFrame)
return dataframe.rename(columns={"images": "embeddings"})

init = patch_method_class(MyPandasComponent.__init__)
transform = patch_method_class(MyPandasComponent.transform)
executor_factory = ExecutorFactory(MyPandasComponent)
executor = executor_factory.get_executor()
with mock.patch.object(MyPandasComponent, "__init__", init), mock.patch.object(
MyPandasComponent,
"transform",
transform,
):
executor.execute(MyPandasComponent)
init.mock.assert_called_once()
assert transform.mock.call_count == N_PARTITIONS
executor.execute(MyPandasComponent)
assert init_called == 1


def test_wrap_transform():
Expand Down

0 comments on commit ae3aead

Please sign in to comment.