diff --git a/docs/component_spec.md b/docs/component_spec.md index 6d11e6184..00022103c 100644 --- a/docs/component_spec.md +++ b/docs/component_spec.md @@ -121,41 +121,49 @@ Please check the [examples](#examples) below to build a better understanding. ### Args The `args` section describes which arguments the component takes. Each argument is defined by a -`description` and a `type`, which should be one of the builtin Python types. +`description` and a `type`, which should be one of the builtin Python types. Additionally, you can +set an optional `default` value for each argument. +_Note:_ default iterable arguments such as `dict` and `list` have to be passed as a string +(e.g. `'{"foo":1, "bar":2}`, `'["foo","bar]'`) ```yaml args: custom_argument: description: A custom argument type: str + default_argument: + description: A default argument + type: str + default: bar ``` -These arguments are passed in when the component is instantiated: - +These arguments are passed in when the component is instantiated. +If an argument is not explicitly provided, the default value will be used instead if available.``` ```python from fondant.pipeline import ComponentOp custom_op = ComponentOp( component_spec_path="components/custom_component/fondant_component.yaml", arguments={ - "custom_argument": "foobar" + "custom_argument": "foo" }, ) ``` -And passed as a keyword argument to the `transform()` method of the component. +Afterwards, we pass all keyword arguments to the `transform()` method of the component. ```python from fondant.component import TransformComponent class ExampleComponent(TransformComponent): - def transform(self, dataframe, *, custom_argument): + def transform(self, dataframe, *, custom_argument, default_argument): """Implement your custom logic in this single method Args: dataframe: A Dask dataframe containing the data custom_argument: An argument passed to the component + default_argument: A default argument passed to the components """ ``` diff --git a/fondant/component.py b/fondant/component.py index adb11ebcb..f61c09585 100644 --- a/fondant/component.py +++ b/fondant/component.py @@ -77,7 +77,6 @@ def from_spec(cls, component_spec: ComponentSpec) -> "Component": metadata = args_dict.pop("metadata") metadata = json.loads(metadata) if metadata else {} - return cls( component_spec, input_manifest_path=input_manifest_path, @@ -86,6 +85,38 @@ def from_spec(cls, component_spec: ComponentSpec) -> "Component": user_arguments=args_dict, ) + @classmethod + def _add_and_parse_args(cls, spec: ComponentSpec): + parser = argparse.ArgumentParser() + component_arguments = cls._get_component_arguments(spec) + + for arg in component_arguments.values(): + # Input manifest is not required for loading component + if arg.name in cls.optional_fondant_arguments(): + input_required = False + default = None + elif arg.default: + input_required = False + default = arg.default + else: + input_required = True + default = None + + parser.add_argument( + f"--{arg.name}", + type=kubeflow2python_type[arg.type], # type: ignore + required=input_required, + default=default, + help=arg.description, + ) + + return parser.parse_args() + + @staticmethod + @abstractmethod + def optional_fondant_arguments() -> t.List[str]: + pass + @staticmethod def _get_component_arguments(spec: ComponentSpec) -> t.Dict[str, Argument]: """ @@ -102,11 +133,6 @@ def _get_component_arguments(spec: ComponentSpec) -> t.Dict[str, Argument]: component_arguments.update(kubeflow_component_spec.output_arguments) return component_arguments - @classmethod - @abstractmethod - def _add_and_parse_args(cls, spec: ComponentSpec) -> argparse.Namespace: - """Abstract method to add and parse the component arguments.""" - @abstractmethod def _load_or_create_manifest(self) -> Manifest: """Abstract method that returns the dataset manifest.""" @@ -142,26 +168,9 @@ def upload_manifest(self, manifest: Manifest, save_path: str): class LoadComponent(Component): """Base class for a Fondant load component.""" - @classmethod - def _add_and_parse_args(cls, spec: ComponentSpec): - parser = argparse.ArgumentParser() - component_arguments = cls._get_component_arguments(spec) - - for arg in component_arguments.values(): - # Input manifest is not required for loading component - if arg.name == "input_manifest_path": - input_required = False - else: - input_required = True - - parser.add_argument( - f"--{arg.name}", - type=kubeflow2python_type[arg.type], # type: ignore - required=input_required, - help=arg.description, - ) - - return parser.parse_args() + @staticmethod + def optional_fondant_arguments() -> t.List[str]: + return ["input_manifest_path"] def _load_or_create_manifest(self) -> Manifest: # create initial manifest @@ -195,20 +204,9 @@ def _process_dataset(self, manifest: Manifest) -> dd.DataFrame: class TransformComponent(Component): """Base class for a Fondant transform component.""" - @classmethod - def _add_and_parse_args(cls, spec: ComponentSpec): - parser = argparse.ArgumentParser() - component_arguments = cls._get_component_arguments(spec) - - for arg in component_arguments.values(): - parser.add_argument( - f"--{arg.name}", - type=kubeflow2python_type[arg.type], # type: ignore - required=True, - help=arg.description, - ) - - return parser.parse_args() + @staticmethod + def optional_fondant_arguments() -> t.List[str]: + return [] def _load_or_create_manifest(self) -> Manifest: return Manifest.from_file(self.input_manifest_path) @@ -242,20 +240,9 @@ def _process_dataset(self, manifest: Manifest) -> t.Union[None, dd.DataFrame]: class WriteComponent(Component): """Base class for a Fondant write component.""" - @classmethod - def _add_and_parse_args(cls, spec: ComponentSpec): - parser = argparse.ArgumentParser() - component_arguments = cls._get_component_arguments(spec) - - for arg in component_arguments.values(): - parser.add_argument( - f"--{arg.name}", - type=kubeflow2python_type[arg.type], # type: ignore - required=True, - help=arg.description, - ) - - return parser.parse_args() + @staticmethod + def optional_fondant_arguments() -> t.List[str]: + return ["output_manifest_path"] def _load_or_create_manifest(self) -> Manifest: return Manifest.from_file(self.input_manifest_path) diff --git a/fondant/component_spec.py b/fondant/component_spec.py index 40f88e4dc..b9dbf6261 100644 --- a/fondant/component_spec.py +++ b/fondant/component_spec.py @@ -46,11 +46,14 @@ class Argument: name: name of the argument description: argument description type: the python argument type (str, int, ...) + default: default value of the argument (defaults to None) + """ name: str description: str type: str + default: t.Optional[str] = None class ComponentSubset: @@ -180,6 +183,7 @@ def args(self) -> t.Dict[str, Argument]: name=name, description=arg_info["description"], type=arg_info["type"], + default=arg_info["default"] if "default" in arg_info else None, ) for name, arg_info in self._specification.get("args", {}).items() } @@ -236,6 +240,7 @@ def from_fondant_component_spec( "name": arg.name, "description": arg.description, "type": python2kubeflow_type[arg.type], + **({"default": arg.default} if arg.default is not None else {}), } for arg in fondant_component.args.values() ), @@ -305,6 +310,7 @@ def input_arguments(self) -> t.Mapping[str, Argument]: name=info["name"], description=info["description"], type=info["type"], + default=info["default"] if "default" in info else None, ) for info in self._specification["inputs"] } diff --git a/fondant/schemas/component_spec.json b/fondant/schemas/component_spec.json index d196c3db3..aee436c00 100644 --- a/fondant/schemas/component_spec.json +++ b/fondant/schemas/component_spec.json @@ -80,6 +80,16 @@ }, "description": { "type": "string" + }, + "default": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "number" + } + ] } }, "required": [ diff --git a/tests/example_specs/components/arguments/component_default_args.yaml b/tests/example_specs/components/arguments/component_default_args.yaml new file mode 100644 index 000000000..7a1aabb01 --- /dev/null +++ b/tests/example_specs/components/arguments/component_default_args.yaml @@ -0,0 +1,33 @@ +name: Example component +description: This is an example component +image: example_component:latest + +args: + string_default_arg: + description: default string argument + type: str + default: foo + integer_default_arg: + description: default integer argument + type: int + default: 1 + float_default_arg: + description: default float argument + type: float + default: 3.14 + bool_default_arg: + description: default bool argument + type: bool + default: 'False' + list_default_arg: + description: default list argument + type: list + default: '["foo", "bar"]' + dict_default_arg: + description: default dict argument + type: dict + default: '{"foo":1, "bar":2}' + override_default_string_arg: + description: default argument that can be overriden + type: str + default: foo diff --git a/tests/test_component.py b/tests/test_component.py index d35ff8277..537899032 100644 --- a/tests/test_component.py +++ b/tests/test_component.py @@ -140,9 +140,6 @@ def test_write_component(tmp_path_factory, monkeypatch): errors are returned when required arguments are missing. """ - class EarlyStopException(Exception): - """Used to stop execution early instead of mocking all later functionality.""" - # Mock `Dataset.load_dataframe` so no actual data is loaded def mocked_load_dataframe(self): return dd.from_dict({"a": [1, 2, 3]}, npartitions=1) @@ -193,3 +190,65 @@ def write(self, dataframe, *, flag, value): # Instantiate and run component with pytest.raises(ValueError): MyComponent.from_args() + + +def test_default_args_component(tmp_path_factory, monkeypatch): + """Test that default arguments defined in the fondant spec are passed correctly and have the + proper data type. + """ + + # Mock `Dataset.load_dataframe` so no actual data is loaded + def mocked_load_dataframe(self): + return dd.from_dict({"a": [1, 2, 3]}, npartitions=1) + + monkeypatch.setattr(DaskDataLoader, "load_dataframe", mocked_load_dataframe) + + # Define paths to specs to instantiate component + arguments_dir = components_path / "arguments" + component_spec = arguments_dir / "component_default_args.yaml" + input_manifest = arguments_dir / "input_manifest.json" + + component_spec_string = yaml_file_to_json_string(component_spec) + + # Implemented Component class + class MyComponent(WriteComponent): + def write( + self, + dataframe, + *, + string_default_arg, + integer_default_arg, + float_default_arg, + bool_default_arg, + list_default_arg, + dict_default_arg, + override_default_string_arg, + ): + float_const = 3.14 + # Mock write function that sinks final data to a local directory + assert string_default_arg == "foo" + assert integer_default_arg == 1 + assert float_default_arg == float_const + assert bool_default_arg is False + assert list_default_arg == ["foo", "bar"] + assert dict_default_arg == {"foo": 1, "bar": 2} + assert override_default_string_arg == "bar" + + # Mock CLI arguments + sys.argv = [ + "", + "--input_manifest_path", + str(input_manifest), + "--metadata", + "", + "--output_manifest_path", + "", + "--component_spec", + f"{component_spec_string}", + "--override_default_string_arg", + "bar", + ] + + # # Instantiate and run component + component = MyComponent.from_args() + component.run()