diff --git a/fondant/component.py b/fondant/component.py index 7594bcfaf..adb11ebcb 100644 --- a/fondant/component.py +++ b/fondant/component.py @@ -112,7 +112,7 @@ def _load_or_create_manifest(self) -> Manifest: """Abstract method that returns the dataset manifest.""" @abstractmethod - def _process_dataset(self, manifest: Manifest) -> dd.DataFrame: + def _process_dataset(self, manifest: Manifest) -> t.Union[None, dd.DataFrame]: """Abstract method that processes the manifest and returns another dataframe. """ @@ -223,7 +223,7 @@ def transform(self, *args, **kwargs) -> dd.DataFrame: kwargs: Arguments provided to the component are passed as keyword arguments """ - def _process_dataset(self, manifest: Manifest) -> dd.DataFrame: + def _process_dataset(self, manifest: Manifest) -> t.Union[None, dd.DataFrame]: """ Creates a DataLoader using the provided manifest and loads the input dataframe using the `load_dataframe` instance, and applies data transformations to it using the `transform` @@ -237,3 +237,57 @@ def _process_dataset(self, manifest: Manifest) -> dd.DataFrame: df = self.transform(dataframe=df, **self.user_arguments) return df + + +class WriteComponent(Component): + """Base class for a Fondant write component.""" + + @classmethod + def _add_and_parse_args(cls, spec: ComponentSpec): + parser = argparse.ArgumentParser() + component_arguments = cls._get_component_arguments(spec) + + for arg in component_arguments.values(): + parser.add_argument( + f"--{arg.name}", + type=kubeflow2python_type[arg.type], # type: ignore + required=True, + help=arg.description, + ) + + return parser.parse_args() + + def _load_or_create_manifest(self) -> Manifest: + return Manifest.from_file(self.input_manifest_path) + + @abstractmethod + def write(self, *args, **kwargs): + """ + Abstract method to write a dataframe to a final custom location. + + Args: + args: The dataframe will be passed in as a positional argument + kwargs: Arguments provided to the component are passed as keyword arguments + """ + + def _process_dataset(self, manifest: Manifest) -> t.Union[None, dd.DataFrame]: + """ + Creates a DataLoader using the provided manifest and loads the input dataframe using the + `load_dataframe` instance, and applies data transformations to it using the `transform` + method implemented by the derived class. Returns a single dataframe. + + Returns: + A `dd.DataFrame` instance with updated data based on the applied data transformations. + """ + data_loader = DaskDataLoader(manifest=manifest, component_spec=self.spec) + df = data_loader.load_dataframe() + self.write(dataframe=df, **self.user_arguments) + + return None + + def _write_data(self, dataframe: dd.DataFrame, *, manifest: Manifest): + """Create a data writer given a manifest and writes out the index and subsets.""" + pass + + def upload_manifest(self, manifest: Manifest, save_path: str): + pass diff --git a/tests/test_component.py b/tests/test_component.py index 4c8f267c0..d35ff8277 100644 --- a/tests/test_component.py +++ b/tests/test_component.py @@ -10,7 +10,7 @@ import pytest import yaml -from fondant.component import LoadComponent, TransformComponent +from fondant.component import LoadComponent, TransformComponent, WriteComponent from fondant.data_io import DaskDataLoader components_path = Path(__file__).parent / "example_specs/components" @@ -75,8 +75,10 @@ def test_component(mock_args): } -def test_valid_transform_kwargs(monkeypatch): - """Test that arguments are passed correctly to `Component.transform` method.""" +def test_transform_component(monkeypatch): + """Test that arguments are passed correctly to `Component.transform` method and that valid + errors are returned when required arguments are missing. + """ class EarlyStopException(Exception): """Used to stop execution early instead of mocking all later functionality.""" @@ -120,12 +122,23 @@ def transform(self, dataframe, *, flag, value): # Instantiate and run component component = MyComponent.from_args() + with pytest.raises(EarlyStopException): component.run() + # Remove component specs from arguments + component_spec_index = sys.argv.index("--component_spec") + del sys.argv[component_spec_index : component_spec_index + 2] + + # Instantiate and run component + with pytest.raises(ValueError): + MyComponent.from_args() + -def test_invalid_transform_kwargs(monkeypatch): - """Test that arguments are passed correctly to `Component.transform` method.""" +def test_write_component(tmp_path_factory, monkeypatch): + """Test that arguments are passed correctly to `Component.write` method and that valid + errors are returned when required arguments are missing. + """ class EarlyStopException(Exception): """Used to stop execution early instead of mocking all later functionality.""" @@ -141,14 +154,16 @@ def mocked_load_dataframe(self): component_spec = arguments_dir / "component.yaml" input_manifest = arguments_dir / "input_manifest.json" - yaml_file_to_json_string(component_spec) + component_spec_string = yaml_file_to_json_string(component_spec) # Implemented Component class - class MyComponent(TransformComponent): - def transform(self, dataframe, *, flag, value): + class MyComponent(WriteComponent): + def write(self, dataframe, *, flag, value): assert flag == "success" assert value == 1 - raise EarlyStopException() + # Mock write function that sinks final data to a local directory + with tmp_path_factory.mktemp("temp") as fn: + dataframe.to_parquet(fn) # Mock CLI arguments sys.argv = [ @@ -163,8 +178,18 @@ def transform(self, dataframe, *, flag, value): "1", "--output_manifest_path", "", + "--component_spec", + f"{component_spec_string}", ] + # # Instantiate and run component + component = MyComponent.from_args() + component.run() + + # Remove component specs from arguments + component_spec_index = sys.argv.index("--component_spec") + del sys.argv[component_spec_index : component_spec_index + 2] + # Instantiate and run component with pytest.raises(ValueError): MyComponent.from_args()