Skip to content

Commit

Permalink
Add teardown method (#767)
Browse files Browse the repository at this point in the history
PR that adds a teardown method to every component. This is useful for shutting
down database connections and clients cleanly, instead of having that happen
abruptly after the container is shut down.

Related to ml6team/fondant-internal#59
  • Loading branch information
PhilippeMoussalli authored Jan 11, 2024
1 parent 0861022 commit e677157
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 1 deletion.
35 changes: 35 additions & 0 deletions docs/components/component_spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -301,4 +301,39 @@ class ExampleComponent(PandasTransformComponent):
Returns:
A pandas dataframe containing the transformed data
"""
```

Afterwards, we pass all keyword arguments to the `__init__()` method of the component.


You can also use a `teardown()` method to perform any cleanup after the component has been executed.
This is a good place to close any open connections or files.

```python
import pandas as pd
from fondant.component import PandasTransformComponent
from my_library import Client
def __init__(self, *, client_url, **kwargs) -> None:
    """
    Args:
        client_url: Value passed to `Client` to create the client used by
            this component
        kwargs: Any additional keyword arguments (not used in this example)
    """
    # Initialize your component here based on the arguments
    # The client is kept on the instance so `teardown()` can shut it down
    self.client = Client(client_url)
def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
    """Implement your custom logic in this single method

    Args:
        dataframe: A Pandas dataframe containing the data

    Returns:
        A pandas dataframe containing the transformed data
    """
def teardown(self) -> None:
    """Perform any cleanup after the component has been executed
    """
    # Shut down the client that was created in `__init__()`
    self.client.shutdown()
```
3 changes: 3 additions & 0 deletions src/fondant/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def __init__(
):
pass

def teardown(self) -> None:
    """Method called after the component has been executed.

    This base implementation contains no statements, so it does nothing.
    Subclasses can override it to release resources they hold, for example
    to shut down a client or close a connection.
    """


class DaskLoadComponent(BaseComponent):
"""Component that loads data and returns a Dask DataFrame."""
Expand Down
4 changes: 3 additions & 1 deletion src/fondant/component/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def _run_execution(
input_manifest: Manifest,
) -> Manifest:
logging.info("Executing component")
component = component_cls(
component: Component = component_cls(
consumes=self.operation_spec.inner_consumes,
produces=self.operation_spec.inner_produces,
**self.user_arguments,
Expand All @@ -350,6 +350,8 @@ def _run_execution(
)
self._write_data(dataframe=output_df, manifest=output_manifest)

component.teardown()

return output_manifest

def execute(self, component_cls: t.Type[Component]) -> None:
Expand Down
63 changes: 63 additions & 0 deletions tests/component/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,69 @@ def load(self):
load.mock.assert_called_once()


@pytest.mark.usefixtures("_patched_data_writing")
def test_teardown_method(metadata):
    """Check that executing a component calls its `teardown()` method.

    A load component holding a mock client is executed; afterwards the test
    asserts that `teardown()` was called exactly once and that the client was
    shut down by it.
    """
    # Mock CLI arguments load
    operation_spec = OperationSpec(
        ComponentSpec.from_file(components_path / "component.yaml"),
    )

    # NOTE(review): sys.argv is overwritten and not restored in this test
    sys.argv = [
        "",
        "--metadata",
        metadata.to_json(),
        "--flag",
        "success",
        "--value",
        "1",
        "--output_manifest_path",
        str(components_path / "output_manifest.json"),
        "--operation_spec",
        operation_spec.to_json(),
        "--cache",
        "False",
        "--produces",
        "{}",
    ]

    class MockClient:
        # Minimal stand-in for a client: starts connected, `shutdown()`
        # flips the flag to disconnected.
        def __init__(self):
            self.is_connected = True

        def shutdown(self):
            if self.is_connected:
                self.is_connected = False

    # Created outside the component so the test can inspect it after execution
    client = MockClient()

    class MyLoadComponent(DaskLoadComponent):
        def __init__(self, *, flag, value, **kwargs):
            self.flag = flag
            self.value = value
            # Reuse the client from the enclosing test scope
            self.client = client

        def load(self):
            data = {
                "id": [0, 1],
                "captions_data": ["hello world", "this is another caption"],
            }
            return dd.DataFrame.from_dict(data, npartitions=N_PARTITIONS)

        def teardown(self) -> None:
            self.client.shutdown()

    executor_factory = ExecutorFactory(MyLoadComponent)
    executor = executor_factory.get_executor()
    assert executor.input_partition_rows is None

    # presumably wraps the method so that calls are recorded on `.mock` while
    # the real `teardown()` still runs -- verify against `patch_method_class`
    teardown = patch_method_class(MyLoadComponent.teardown)
    # The client must still be connected before the component is executed
    assert client.is_connected is True
    with mock.patch.object(MyLoadComponent, "teardown", teardown):
        executor.execute(MyLoadComponent)
        teardown.mock.assert_called_once()
    # The client's `shutdown()` must have run during execution
    assert client.is_connected is False


@pytest.mark.usefixtures("_patched_data_loading", "_patched_data_writing")
def test_dask_transform_component(metadata):
operation_spec = OperationSpec(
Expand Down
Empty file added tests/examples/__init__.py
Empty file.

0 comments on commit e677157

Please sign in to comment.