Skip to content

Commit

Permalink
Add writer component (#196)
Browse files Browse the repository at this point in the history
PR that adds a writer class as discussed in #138. 
This enables us to write the final dataset without having to write the
dataset and manifest since there is no modification made on the data.

Next steps: 
- Enable default and optional arguments in components. The optional
arguments are needed to make the Reader/Writer components generic (e.g.
Write to hub requires special hf metadata to be attached to the image
column in case there is any, user needs to pass an optional argument
specifying the columns name of the image)
- Re implement load/Write to hub component to make them more generic.
  • Loading branch information
PhilippeMoussalli authored Jun 13, 2023
1 parent 2ce1b0e commit 4d733ea
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 11 deletions.
58 changes: 56 additions & 2 deletions fondant/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _load_or_create_manifest(self) -> Manifest:
"""Abstract method that returns the dataset manifest."""

@abstractmethod
def _process_dataset(self, manifest: Manifest) -> dd.DataFrame:
def _process_dataset(self, manifest: Manifest) -> t.Union[None, dd.DataFrame]:
"""Abstract method that processes the manifest and
returns another dataframe.
"""
Expand Down Expand Up @@ -223,7 +223,7 @@ def transform(self, *args, **kwargs) -> dd.DataFrame:
kwargs: Arguments provided to the component are passed as keyword arguments
"""

def _process_dataset(self, manifest: Manifest) -> dd.DataFrame:
def _process_dataset(self, manifest: Manifest) -> t.Union[None, dd.DataFrame]:
"""
Creates a DataLoader using the provided manifest and loads the input dataframe using the
`load_dataframe` instance, and applies data transformations to it using the `transform`
Expand All @@ -237,3 +237,57 @@ def _process_dataset(self, manifest: Manifest) -> dd.DataFrame:
df = self.transform(dataframe=df, **self.user_arguments)

return df


class WriteComponent(Component):
"""Base class for a Fondant write component."""

@classmethod
def _add_and_parse_args(cls, spec: ComponentSpec):
parser = argparse.ArgumentParser()
component_arguments = cls._get_component_arguments(spec)

for arg in component_arguments.values():
parser.add_argument(
f"--{arg.name}",
type=kubeflow2python_type[arg.type], # type: ignore
required=True,
help=arg.description,
)

return parser.parse_args()

def _load_or_create_manifest(self) -> Manifest:
return Manifest.from_file(self.input_manifest_path)

@abstractmethod
def write(self, *args, **kwargs):
"""
Abstract method to write a dataframe to a final custom location.
Args:
args: The dataframe will be passed in as a positional argument
kwargs: Arguments provided to the component are passed as keyword arguments
"""

def _process_dataset(self, manifest: Manifest) -> t.Union[None, dd.DataFrame]:
"""
Creates a DataLoader using the provided manifest and loads the input dataframe using the
`load_dataframe` instance, and applies data transformations to it using the `transform`
method implemented by the derived class. Returns a single dataframe.
Returns:
A `dd.DataFrame` instance with updated data based on the applied data transformations.
"""
data_loader = DaskDataLoader(manifest=manifest, component_spec=self.spec)
df = data_loader.load_dataframe()
self.write(dataframe=df, **self.user_arguments)

return None

def _write_data(self, dataframe: dd.DataFrame, *, manifest: Manifest):
"""Create a data writer given a manifest and writes out the index and subsets."""
pass

def upload_manifest(self, manifest: Manifest, save_path: str):
pass
43 changes: 34 additions & 9 deletions tests/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest
import yaml

from fondant.component import LoadComponent, TransformComponent
from fondant.component import LoadComponent, TransformComponent, WriteComponent
from fondant.data_io import DaskDataLoader

components_path = Path(__file__).parent / "example_specs/components"
Expand Down Expand Up @@ -75,8 +75,10 @@ def test_component(mock_args):
}


def test_valid_transform_kwargs(monkeypatch):
"""Test that arguments are passed correctly to `Component.transform` method."""
def test_transform_component(monkeypatch):
"""Test that arguments are passed correctly to `Component.transform` method and that valid
errors are returned when required arguments are missing.
"""

class EarlyStopException(Exception):
"""Used to stop execution early instead of mocking all later functionality."""
Expand Down Expand Up @@ -120,12 +122,23 @@ def transform(self, dataframe, *, flag, value):

# Instantiate and run component
component = MyComponent.from_args()

with pytest.raises(EarlyStopException):
component.run()

# Remove component specs from arguments
component_spec_index = sys.argv.index("--component_spec")
del sys.argv[component_spec_index : component_spec_index + 2]

# Instantiate and run component
with pytest.raises(ValueError):
MyComponent.from_args()


def test_invalid_transform_kwargs(monkeypatch):
"""Test that arguments are passed correctly to `Component.transform` method."""
def test_write_component(tmp_path_factory, monkeypatch):
"""Test that arguments are passed correctly to `Component.write` method and that valid
errors are returned when required arguments are missing.
"""

class EarlyStopException(Exception):
"""Used to stop execution early instead of mocking all later functionality."""
Expand All @@ -141,14 +154,16 @@ def mocked_load_dataframe(self):
component_spec = arguments_dir / "component.yaml"
input_manifest = arguments_dir / "input_manifest.json"

yaml_file_to_json_string(component_spec)
component_spec_string = yaml_file_to_json_string(component_spec)

# Implemented Component class
class MyComponent(TransformComponent):
def transform(self, dataframe, *, flag, value):
class MyComponent(WriteComponent):
def write(self, dataframe, *, flag, value):
assert flag == "success"
assert value == 1
raise EarlyStopException()
# Mock write function that sinks final data to a local directory
with tmp_path_factory.mktemp("temp") as fn:
dataframe.to_parquet(fn)

# Mock CLI arguments
sys.argv = [
Expand All @@ -163,8 +178,18 @@ def transform(self, dataframe, *, flag, value):
"1",
"--output_manifest_path",
"",
"--component_spec",
f"{component_spec_string}",
]

# # Instantiate and run component
component = MyComponent.from_args()
component.run()

# Remove component specs from arguments
component_spec_index = sys.argv.index("--component_spec")
del sys.argv[component_spec_index : component_spec_index + 2]

# Instantiate and run component
with pytest.raises(ValueError):
MyComponent.from_args()

0 comments on commit 4d733ea

Please sign in to comment.