diff --git a/examples/sample_pipeline/pipeline.py b/examples/sample_pipeline/pipeline.py
index b0acd9682..d9089b866 100644
--- a/examples/sample_pipeline/pipeline.py
+++ b/examples/sample_pipeline/pipeline.py
@@ -1,8 +1,10 @@
 # This file contains a sample pipeline. Loading data from a parquet file,
 # using the load_from_parquet component, chain a custom dummy component, and use
 # the reusable chunking component
-import pyarrow as pa
 from pathlib import Path
+
+import pyarrow as pa
+
 from fondant.pipeline import Pipeline
 
 BASE_PATH = Path("./.artifacts").resolve()
@@ -17,7 +19,7 @@
 }
 
 dataset = pipeline.read(
-    name_or_path="load_from_parquet",
+    "load_from_parquet",
     arguments={
         "dataset_uri": "/data/sample.parquet",
         "column_name_mapping": load_component_column_mapping,
@@ -27,11 +29,11 @@
 )
 
 dataset = dataset.apply(
-    name_or_path="./components/dummy_component",
+    "./components/dummy_component",
 )
 
 dataset.apply(
-    name_or_path="chunk_text",
+    "chunk_text",
     arguments={"chunk_size": 10, "chunk_overlap": 2},
     consumes={"text": "text_data"},
 )
diff --git a/src/fondant/build.py b/src/fondant/build.py
index 565209ab6..35c9a77b9 100644
--- a/src/fondant/build.py
+++ b/src/fondant/build.py
@@ -33,7 +33,7 @@ def build_component(  # ruff: noqa: PLR0912, PLR0915
             msg,
         )
 
-    component_op = ComponentOp(component_dir)
+    component_op = ComponentOp.from_component_yaml(component_dir)
     component_spec = component_op.component_spec
 
-    if component_op.dockerfile_path is None:
+    if component_op.dockerfile_path(component_dir) is None:
diff --git a/src/fondant/core/component_spec.py b/src/fondant/core/component_spec.py
index a22385551..452d92d1f 100644
--- a/src/fondant/core/component_spec.py
+++ b/src/fondant/core/component_spec.py
@@ -61,11 +61,62 @@ class ComponentSpec:
     Class representing a Fondant component specification.
 
     Args:
-        specification: The fondant component specification as a Python dict
+        name: The name of the component
+        image: The docker image URI to use for the component
+        description: The description of the component
+        consumes: A mapping containing the fields consumed by the operation. The keys are the
+            names of the fields to be received by the component, while the values are the
+            type of the field.
+
+        produces: A mapping containing the fields produced by the operation. The keys are the
+            names of the fields to be produced by the component, while the values are the
+            type of the field to be written.
+
+        previous_index: The name of the field holding the index of the previous component,
+            for components that change the index of the dataset.
+
+        args: A dictionary containing the argument names and values for the operation.
+
+        tags: An optional list of tags describing the component.
+    """
 
-    def __init__(self, specification: t.Dict[str, t.Any]) -> None:
-        self._specification = copy.deepcopy(specification)
+    def __init__(
+        self,
+        name: str,
+        image: str,
+        *,
+        description: t.Optional[str] = None,
+        consumes: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]] = None,
+        produces: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]] = None,
+        previous_index: t.Optional[str] = None,
+        args: t.Optional[t.Dict[str, t.Any]] = None,
+        tags: t.Optional[t.List[str]] = None,
+    ):
+        spec_dict: t.Dict[str, t.Any] = {
+            "name": name,
+            "image": image,
+        }
+
+        if description:
+            spec_dict["description"] = description
+
+        if tags:
+            spec_dict["tags"] = tags
+
+        if consumes:
+            spec_dict["consumes"] = consumes
+
+        if produces:
+            spec_dict["produces"] = produces
+
+        if previous_index:
+            spec_dict["previous_index"] = previous_index
+
+        if args:
+            spec_dict["args"] = args
+
+        self._specification = spec_dict
         self._validate_spec()
 
     def _validate_spec(self) -> None:
@@ -102,13 +149,22 @@ def from_file(cls, path: t.Union[str, Path]) -> "ComponentSpec":
         """Load the component spec from the file specified by the provided path."""
         with open(path, encoding="utf-8") as file_:
             specification = yaml.safe_load(file_)
-            return cls(specification)
+            return cls.from_dict(specification)
 
     def to_file(self, path) -> None:
         """Dump the component spec to the file specified by the provided path."""
         with open(path, "w", encoding="utf-8") as file_:
             yaml.dump(self._specification, file_)
 
+    @classmethod
+    def from_dict(cls, component_spec_dict: t.Dict[str, t.Any]) -> "ComponentSpec":
+        """Load the component spec from a dictionary."""
+        try:
+            return cls(**component_spec_dict)
+        except TypeError as e:
+            msg = f"Invalid component spec: {e}"
+            raise InvalidComponentSpec(msg) from e
+
     @property
     def name(self):
         return self._specification["name"]
@@ -334,7 +390,9 @@ def _parse_mapping(
             return json_mapping
 
         return cls(
-            component_spec=ComponentSpec(operation_spec_dict["specification"]),
+            component_spec=ComponentSpec.from_dict(
+                operation_spec_dict["specification"],
+            ),
             consumes=_parse_mapping(operation_spec_dict["consumes"]),
             produces=_parse_mapping(operation_spec_dict["produces"]),
         )
diff --git a/src/fondant/pipeline/__init__.py b/src/fondant/pipeline/__init__.py
index 2fd1aa6fc..0c3f60ac1 100644
--- a/src/fondant/pipeline/__init__.py
+++ b/src/fondant/pipeline/__init__.py
@@ -1,8 +1,9 @@
+from .lightweight_component import Image, PythonComponent, lightweight_component  # noqa
 from .pipeline import (  # noqa
+    VALID_ACCELERATOR_TYPES,
+    VALID_VERTEX_ACCELERATOR_TYPES,
     ComponentOp,
     Dataset,
     Pipeline,
     Resources,
-    VALID_ACCELERATOR_TYPES,
-    VALID_VERTEX_ACCELERATOR_TYPES,
 )
diff --git a/src/fondant/pipeline/lightweight_component.py b/src/fondant/pipeline/lightweight_component.py
new file mode 100644
index 000000000..cd4146347
--- /dev/null
+++ b/src/fondant/pipeline/lightweight_component.py
@@ -0,0 +1,42 @@
+import typing as t
+from dataclasses import dataclass
+from functools import wraps
+
+
+@dataclass
+class Image:
+    base_image: str = "fondant:latest"
+    extra_requires: t.Optional[t.List[str]] = None
+    script: t.Optional[str] = None
+
+
+class PythonComponent:
+    @classmethod
+    def image(cls) -> Image:
+        raise NotImplementedError
+
+
+def lightweight_component(
+    extra_requires: t.Optional[t.List[str]] = None,
+    base_image: t.Optional[str] = None,
+):
+    """Decorator to define a lightweight Python component from a class."""
+
+    def wrapper(cls):
+        kwargs = {}
+        if base_image:
+            kwargs["base_image"] = base_image
+        if extra_requires:
+            kwargs["extra_requires"] 
= extra_requires + image = Image(**kwargs) + + # updated=() is needed to prevent an attempt to update the class's __dict__ + @wraps(cls, updated=()) + class AppliedPythonComponent(cls, PythonComponent): + @classmethod + def image(cls) -> Image: + return image + + return AppliedPythonComponent + + return wrapper diff --git a/src/fondant/pipeline/pipeline.py b/src/fondant/pipeline/pipeline.py index 8c1027342..04b38fc94 100644 --- a/src/fondant/pipeline/pipeline.py +++ b/src/fondant/pipeline/pipeline.py @@ -1,6 +1,7 @@ """This module defines classes to represent a Fondant Pipeline.""" import datetime import hashlib +import inspect import json import logging import re @@ -20,6 +21,7 @@ from fondant.core.exceptions import InvalidPipelineDefinition from fondant.core.manifest import Manifest from fondant.core.schema import Field +from fondant.pipeline import Image, PythonComponent logger = logging.getLogger(__name__) @@ -131,7 +133,9 @@ class ComponentOp: def __init__( self, - name_or_path: t.Union[str, Path], + name: str, + image: Image, + component_spec: ComponentSpec, *, consumes: t.Optional[t.Dict[str, str]] = None, produces: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]] = None, @@ -141,20 +145,16 @@ def __init__( cluster_type: t.Optional[str] = "default", client_kwargs: t.Optional[dict] = None, resources: t.Optional[Resources] = None, + component_dir: t.Optional[Path] = None, ) -> None: - if self._is_custom_component(name_or_path): - self.component_dir = Path(name_or_path) - else: - self.component_dir = self._get_registry_path(str(name_or_path)) - + self.name = name + self.image = image + self.component_spec = component_spec self.input_partition_rows = input_partition_rows - self.component_spec = ComponentSpec.from_file( - self.component_dir / self.COMPONENT_SPEC_NAME, - ) - self.name = self.component_spec.component_folder_name self.cache = self._configure_caching_from_image_tag(cache) self.cluster_type = cluster_type self.client_kwargs = client_kwargs + self.component_dir = component_dir self.operation_spec = OperationSpec( self.component_spec, @@ -181,6 +181,28 @@ def __init__( self.resources = resources or Resources() + @classmethod + def from_component_yaml(cls, path, **kwargs): + if cls._is_custom_component(path): + component_dir = Path(path) + else: + component_dir = cls._get_registry_path(str(path)) + component_spec = ComponentSpec.from_file( + component_dir / cls.COMPONENT_SPEC_NAME, + ) + name = component_spec.component_folder_name + + image = Image( + base_image=component_spec.image, + ) + return cls( + name=name, + image=image, + component_spec=component_spec, + component_dir=component_dir, + **kwargs, + ) + def _configure_caching_from_image_tag( self, cache: t.Optional[bool], @@ -213,10 +235,13 @@ def _configure_caching_from_image_tag( return cache - @property - def dockerfile_path(self) -> t.Optional[Path]: - path = self.component_dir / "Dockerfile" - return path if path.exists() else None + def dockerfile_path(self, path: t.Union[str, Path]) -> t.Optional[Path]: + if self._is_custom_component(path): + component_dir = Path(path) + else: + component_dir = self._get_registry_path(str(path)) + docker_path = component_dir / "Dockerfile" + return docker_path if docker_path.exists() else None @staticmethod def _is_custom_component(path_or_name: t.Union[str, Path]) -> bool: @@ -325,7 +350,7 @@ def register_operation( def read( self, - name_or_path: t.Union[str, Path], + ref: t.Any, *, produces: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]] = None, arguments: 
t.Optional[t.Dict[str, t.Any]] = None, @@ -339,8 +364,8 @@ def read( Read data using the provided component. Args: - name_or_path: The name of a reusable component, or the path to the directory containing - a custom component. + ref: The name of a reusable component, or the path to the directory containing + a custom component, or a python component class. produces: A mapping to update the fields produced by the operation as defined in the component spec. The keys are the names of the fields to be received by the component, while the values are the type of the field, or the name of the field to @@ -362,16 +387,43 @@ def read( msg, ) - operation = ComponentOp( - name_or_path, - produces=produces, - arguments=arguments, - input_partition_rows=input_partition_rows, - resources=resources, - cache=cache, - cluster_type=cluster_type, - client_kwargs=client_kwargs, - ) + if inspect.isclass(ref) and issubclass(ref, PythonComponent): + name = ref.__name__ + image = ref.image() + description = ref.__doc__ or "python component" + + component_spec = ComponentSpec( + name, + image.base_image, # TODO: revisit + description=description, + consumes={"additionalProperties": True}, + produces={"additionalProperties": True}, + ) + + operation = ComponentOp( + name, + image, + component_spec, + produces=produces, + arguments=arguments, + input_partition_rows=input_partition_rows, + resources=resources, + cache=cache, + cluster_type=cluster_type, + client_kwargs=client_kwargs, + ) + + else: + operation = ComponentOp.from_component_yaml( + ref, + produces=produces, + arguments=arguments, + input_partition_rows=input_partition_rows, + resources=resources, + cache=cache, + cluster_type=cluster_type, + client_kwargs=client_kwargs, + ) manifest = Manifest.create( pipeline_name=self.name, @@ -546,7 +598,7 @@ def _apply(self, operation: ComponentOp) -> "Dataset": def apply( self, - name_or_path: t.Union[str, Path], + ref: t.Any, *, consumes: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]] = None, produces: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]] = None, @@ -561,8 +613,8 @@ def apply( Apply the provided component on the dataset. Args: - name_or_path: The name of a reusable component, or the path to the directory containing - a custom component. + ref: The name of a reusable component, or the path to the directory containing + a custom component, or a python component class. consumes: A mapping to update the fields consumed by the operation as defined in the component spec. The keys are the names of the fields to be received by the component, while the values are the type of the field, or the name of the field to @@ -646,22 +698,51 @@ def apply( Returns: An intermediate dataset. 
""" - operation = ComponentOp( - name_or_path, - consumes=consumes, - produces=produces, - arguments=arguments, - input_partition_rows=input_partition_rows, - resources=resources, - cache=cache, - cluster_type=cluster_type, - client_kwargs=client_kwargs, - ) + if inspect.isclass(ref) and issubclass(ref, PythonComponent): + name = ref.__name__ + image = ref.image() + description = ref.__doc__ or "python component" + + component_spec = ComponentSpec( + name, + image.base_image, # TODO: revisit + description=description, + consumes={"additionalProperties": True}, + produces={"additionalProperties": True}, + ) + + operation = ComponentOp( + name, + image, + component_spec, + consumes=consumes, + produces=produces, + arguments=arguments, + input_partition_rows=input_partition_rows, + resources=resources, + cache=cache, + cluster_type=cluster_type, + client_kwargs=client_kwargs, + ) + + else: + operation = ComponentOp.from_component_yaml( + ref, + consumes=consumes, + produces=produces, + arguments=arguments, + input_partition_rows=input_partition_rows, + resources=resources, + cache=cache, + cluster_type=cluster_type, + client_kwargs=client_kwargs, + ) + return self._apply(operation) def write( self, - name_or_path: t.Union[str, Path], + ref: t.Any, *, consumes: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]] = None, arguments: t.Optional[t.Dict[str, t.Any]] = None, @@ -675,8 +756,8 @@ def write( Write the dataset using the provided component. Args: - name_or_path: The name of a reusable component, or the path to the directory containing - a custom component. + ref: The name of a reusable component, or the path to the directory containing + a custom component, or a python component class. consumes: A mapping to update the fields consumed by the operation as defined in the component spec. The keys are the names of the fields to be received by the component, while the values are the type of the field, or the name of the field to @@ -692,14 +773,41 @@ def write( Returns: An intermediate dataset. """ - operation = ComponentOp( - name_or_path, - consumes=consumes, - arguments=arguments, - input_partition_rows=input_partition_rows, - resources=resources, - cache=cache, - cluster_type=cluster_type, - client_kwargs=client_kwargs, - ) + if inspect.isclass(ref) and issubclass(ref, PythonComponent): + name = ref.__name__ + image = ref.image() + description = ref.__doc__ or "python component" + + component_spec = ComponentSpec( + name, + image.base_image, # TODO: revisit + description=description, + consumes={"additionalProperties": True}, + produces={"additionalProperties": True}, + ) + + operation = ComponentOp( + name, + image, + component_spec, + consumes=consumes, + arguments=arguments, + input_partition_rows=input_partition_rows, + resources=resources, + cache=cache, + cluster_type=cluster_type, + client_kwargs=client_kwargs, + ) + + else: + operation = ComponentOp.from_component_yaml( + ref, + consumes=consumes, + arguments=arguments, + input_partition_rows=input_partition_rows, + resources=resources, + cache=cache, + cluster_type=cluster_type, + client_kwargs=client_kwargs, + ) self._apply(operation) diff --git a/tests/component/test_component.py b/tests/component/test_component.py index 191e6e329..6890a0136 100644 --- a/tests/component/test_component.py +++ b/tests/component/test_component.py @@ -468,7 +468,7 @@ def test_wrap_transform(): - Trimming columns not specified in `produces` - Ordering columns according to specification (so `map_partitions` does not fail). 
""" - spec = ComponentSpec( + spec = ComponentSpec.from_dict( { "name": "Test component", "description": "Component for testing", diff --git a/tests/core/examples/component_specs/valid_component.yaml b/tests/core/examples/component_specs/valid_component.yaml index 1215af1bd..cd36de031 100644 --- a/tests/core/examples/component_specs/valid_component.yaml +++ b/tests/core/examples/component_specs/valid_component.yaml @@ -1,6 +1,6 @@ name: Example component -description: This is an example component image: example_component:latest +description: This is an example component tags: - Data loading diff --git a/tests/core/test_component_specs.py b/tests/core/test_component_specs.py index fbe63649f..9feea80dc 100644 --- a/tests/core/test_component_specs.py +++ b/tests/core/test_component_specs.py @@ -64,14 +64,14 @@ def valid_fondant_schema_generic_consumes_produces() -> dict: def test_component_spec_pkgutil_error(mock_get_data): """Test that FileNotFoundError is raised when pkgutil.get_data returns None.""" with pytest.raises(FileNotFoundError): - ComponentSpec("example_component.yaml") + ComponentSpec.from_file("example_component.yaml") def test_component_spec_validation(valid_fondant_schema, invalid_fondant_schema): """Test that the component spec is validated correctly on instantiation.""" - ComponentSpec(valid_fondant_schema) + ComponentSpec.from_dict(valid_fondant_schema) with pytest.raises(InvalidComponentSpec): - ComponentSpec(invalid_fondant_schema) + ComponentSpec.from_dict(invalid_fondant_schema) def test_component_spec_load_from_file(valid_fondant_schema, invalid_fondant_schema): @@ -87,7 +87,7 @@ def test_attribute_access(valid_fondant_schema): - Fixed properties should be accessible as an attribute - Dynamic properties should be accessible by lookup. 
""" - fondant_component = ComponentSpec(valid_fondant_schema) + fondant_component = ComponentSpec.from_dict(valid_fondant_schema) assert fondant_component.name == "Example component" assert fondant_component.description == "This is an example component" @@ -99,14 +99,14 @@ def test_attribute_access(valid_fondant_schema): def test_kfp_component_creation(valid_fondant_schema, valid_kubeflow_schema): """Test that the created kubeflow component matches the expected kubeflow component.""" - fondant_component = ComponentSpec(valid_fondant_schema) + fondant_component = ComponentSpec.from_dict(valid_fondant_schema) kubeflow_component = fondant_component.kubeflow_specification assert kubeflow_component._specification == valid_kubeflow_schema def test_component_spec_no_args(valid_fondant_schema_no_args): """Test that a component spec without args is supported.""" - fondant_component = ComponentSpec(valid_fondant_schema_no_args) + fondant_component = ComponentSpec.from_dict(valid_fondant_schema_no_args) assert fondant_component.name == "Example component" assert fondant_component.description == "This is an example component" @@ -115,7 +115,7 @@ def test_component_spec_no_args(valid_fondant_schema_no_args): def test_component_spec_to_file(valid_fondant_schema): """Test that the ComponentSpec can be written to a file.""" - component_spec = ComponentSpec(valid_fondant_schema) + component_spec = ComponentSpec.from_dict(valid_fondant_schema) with tempfile.TemporaryDirectory() as temp_dir: file_path = os.path.join(temp_dir, "component_spec.yaml") @@ -145,7 +145,7 @@ def test_kubeflow_component_spec_to_file(valid_kubeflow_schema): def test_component_spec_repr(valid_fondant_schema): """Test that the __repr__ method of ComponentSpec returns the expected string.""" - fondant_component = ComponentSpec(valid_fondant_schema) + fondant_component = ComponentSpec.from_dict(valid_fondant_schema) expected_repr = f"ComponentSpec({valid_fondant_schema!r})" assert repr(fondant_component) == expected_repr @@ -159,21 +159,23 @@ def test_kubeflow_component_spec_repr(valid_kubeflow_schema): def test_component_spec_generic_consumes(valid_fondant_schema_generic_consumes): """Test that a component spec with generic consumes is detected.""" - component_spec = ComponentSpec(valid_fondant_schema_generic_consumes) + component_spec = ComponentSpec.from_dict(valid_fondant_schema_generic_consumes) assert component_spec.is_generic("consumes") is True assert component_spec.is_generic("produces") is False def test_component_spec_generic_produces(valid_fondant_schema_generic_produces): """Test that a component spec with generic produces is detected.""" - component_spec = ComponentSpec(valid_fondant_schema_generic_produces) + component_spec = ComponentSpec.from_dict(valid_fondant_schema_generic_produces) assert component_spec.is_generic("consumes") is False assert component_spec.is_generic("produces") is True def test_operation_spec_parsing(valid_fondant_schema_generic_consumes_produces): """Test that the operation spec is parsed correctly.""" - component_spec = ComponentSpec(valid_fondant_schema_generic_consumes_produces) + component_spec = ComponentSpec.from_dict( + valid_fondant_schema_generic_consumes_produces, + ) operation_spec = OperationSpec( component_spec=component_spec, consumes={ diff --git a/tests/core/test_manifest_evolution.py b/tests/core/test_manifest_evolution.py index 71f39d4f9..659d9110a 100644 --- a/tests/core/test_manifest_evolution.py +++ b/tests/core/test_manifest_evolution.py @@ -86,7 +86,7 @@ def 
examples(manifest_examples): def test_evolution(input_manifest, component_spec, output_manifest, test_conditions): run_id = "custom_run_id" manifest = Manifest(input_manifest) - component_spec = ComponentSpec(component_spec) + component_spec = ComponentSpec.from_dict(component_spec) for test_condition in test_conditions: produces = test_condition["produces"] operation_spec = OperationSpec(component_spec, produces=produces) @@ -109,7 +109,7 @@ def test_invalid_evolution_examples( ): run_id = "custom_run_id" manifest = Manifest(input_manifest) - component_spec = ComponentSpec(component_spec) + component_spec = ComponentSpec.from_dict(component_spec) for test_condition in test_conditions: produces = test_condition["produces"] with pytest.raises(InvalidPipelineDefinition): # noqa: PT012 @@ -128,7 +128,7 @@ def test_component_spec_location_update(): specification = yaml.safe_load(f) manifest = Manifest(input_manifest) - component_spec = ComponentSpec(specification) + component_spec = ComponentSpec.from_dict(specification) evolved_manifest = manifest.evolve( operation_spec=OperationSpec(component_spec), run_id="123", diff --git a/tests/pipeline/test_compiler.py b/tests/pipeline/test_compiler.py index c02110bab..34167f0b4 100644 --- a/tests/pipeline/test_compiler.py +++ b/tests/pipeline/test_compiler.py @@ -32,7 +32,7 @@ "example_1", [ { - "component_op": ComponentOp( + "component_op": ComponentOp.from_component_yaml( Path(COMPONENTS_PATH / "example_1" / "first_component"), arguments={"storage_args": "a dummy string arg"}, input_partition_rows=10, @@ -45,7 +45,7 @@ "cache_key": "1", }, { - "component_op": ComponentOp( + "component_op": ComponentOp.from_component_yaml( Path(COMPONENTS_PATH / "example_1" / "second_component"), arguments={"storage_args": "a dummy string arg"}, input_partition_rows=10, @@ -53,7 +53,7 @@ "cache_key": "2", }, { - "component_op": ComponentOp( + "component_op": ComponentOp.from_component_yaml( Path(COMPONENTS_PATH / "example_1" / "third_component"), arguments={ "storage_args": "a dummy string arg", @@ -67,7 +67,7 @@ "example_2", [ { - "component_op": ComponentOp( + "component_op": ComponentOp.from_component_yaml( Path(COMPONENTS_PATH / "example_1" / "first_component"), arguments={"storage_args": "a dummy string arg"}, produces={"images_data": pa.binary()}, @@ -75,7 +75,7 @@ "cache_key": "1", }, { - "component_op": ComponentOp( + "component_op": ComponentOp.from_component_yaml( "crop_images", arguments={"cropping_threshold": 0, "padding": 0}, ), diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 706040baf..a58073291 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -43,7 +43,7 @@ def test_component_op( example_dir, component_names = valid_pipeline_example components_path = Path(valid_pipeline_path / example_dir) - ComponentOp( + ComponentOp.from_component_yaml( Path(components_path / component_names[0]), arguments=component_args, ) @@ -83,17 +83,17 @@ def test_component_op_hash( example_dir, component_names = valid_pipeline_example components_path = Path(valid_pipeline_path / example_dir) - comp_0_op_spec_0 = ComponentOp( + comp_0_op_spec_0 = ComponentOp.from_component_yaml( Path(components_path / component_names[0]), arguments={"storage_args": "a dummy string arg"}, ) - comp_0_op_spec_1 = ComponentOp( + comp_0_op_spec_1 = ComponentOp.from_component_yaml( Path(components_path / component_names[0]), arguments={"storage_args": "a different string arg"}, ) - comp_1_op_spec_0 = ComponentOp( + 
comp_1_op_spec_0 = ComponentOp.from_component_yaml( Path(components_path / component_names[1]), arguments={"storage_args": "a dummy string arg"}, ) @@ -122,7 +122,7 @@ def test_component_op_caching_strategy(monkeypatch): "image", f"fndnt/test_component:{tag}", ) - comp_0_op_spec_0 = ComponentOp( + comp_0_op_spec_0 = ComponentOp.from_component_yaml( components_path, arguments={"storage_args": "a dummy string arg"}, cache=True, @@ -388,8 +388,8 @@ def test_invalid_pipeline_declaration( def test_reusable_component_op(): - laion_retrieval_op = ComponentOp( - name_or_path="retrieve_laion_by_prompt", + laion_retrieval_op = ComponentOp.from_component_yaml( + "retrieve_laion_by_prompt", arguments={"num_images": 2, "aesthetic_score": 9, "aesthetic_weight": 0.5}, ) assert laion_retrieval_op.component_spec, "component_spec_path could not be loaded" @@ -399,12 +399,12 @@ def test_reusable_component_op(): ValueError, match=f"No reusable component with name {component_name} " "found.", ): - ComponentOp(component_name) + ComponentOp.from_component_yaml(component_name) def test_defining_reusable_component_op_with_custom_spec(): - load_from_hub_default_op = ComponentOp( - name_or_path="load_from_hf_hub", + load_from_hub_default_op = ComponentOp.from_component_yaml( + "load_from_hf_hub", arguments={ "dataset_name": "test_dataset", "column_name_mapping": {"foo": "bar"}, @@ -412,8 +412,8 @@ def test_defining_reusable_component_op_with_custom_spec(): }, ) - load_from_hub_custom_op = ComponentOp( - name_or_path=load_from_hub_default_op.component_dir, + load_from_hub_custom_op = ComponentOp.from_component_yaml( + load_from_hub_default_op.component_dir, arguments={ "dataset_name": "test_dataset", "column_name_mapping": {"foo": "bar"}, diff --git a/tests/pipeline/test_python_component.py b/tests/pipeline/test_python_component.py new file mode 100644 index 000000000..8f7ae4552 --- /dev/null +++ b/tests/pipeline/test_python_component.py @@ -0,0 +1,48 @@ +import dask.dataframe as dd +import pandas as pd +import pyarrow as pa +from fondant.component import DaskLoadComponent, PandasTransformComponent +from fondant.pipeline import Pipeline, lightweight_component + + +def test_lightweight_component(): + pipeline = Pipeline(name="dummy-pipeline", base_path="./data") + + @lightweight_component( + base_image="python:3.8-slim-buster", + extra_requires=["pandas", "dask"], + ) + class CreateData(DaskLoadComponent): + def __init__(self, **kwargs): + pass + + def load(self) -> dd.DataFrame: + df = pd.DataFrame( + { + "x": [1, 2, 3], + "y": [4, 5, 6], + }, + index=pd.Index(["a", "b", "c"], name="id"), + ) + return dd.from_pandas(df, npartitions=1) + + dataset = pipeline.read( + ref=CreateData, + produces={"x": pa.int32(), "y": pa.int32()}, + ) + + @lightweight_component() + class AddN(PandasTransformComponent): + def __init__(self, n: int): + self.n = n + + def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: + dataframe["x"] = dataframe["x"].map(lambda x: x + self.n) + return dataframe + + _ = dataset.apply( + ref=AddN, + produces={"x": pa.int32(), "y": pa.int32()}, + consumes={"x": pa.int32(), "y": pa.int32()}, + arguments={"n": 1}, + )
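
A minimal sketch of the reworked ComponentSpec API, for anyone updating call sites. The component name, image tag, and description below are illustrative placeholders, not values from the codebase; the generic consumes/produces mapping mirrors the one this patch builds in pipeline.py for lightweight components.

    from fondant.core.component_spec import ComponentSpec

    # Keyword construction replaces ComponentSpec(spec_dict):
    spec = ComponentSpec(
        "my_component",
        "python:3.8-slim-buster",
        description="An example component",
        consumes={"additionalProperties": True},
        produces={"additionalProperties": True},
    )

    # Existing dict-based call sites migrate to from_dict, which wraps the
    # TypeError raised for unexpected keys in InvalidComponentSpec:
    spec = ComponentSpec.from_dict(
        {
            "name": "my_component",
            "image": "python:3.8-slim-buster",
            "description": "An example component",
            "consumes": {"additionalProperties": True},
            "produces": {"additionalProperties": True},
        },
    )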
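And a sketch of the lightweight-component flow end to end, distilled from tests/pipeline/test_python_component.py above; the pipeline name and base path are placeholders:

    import dask.dataframe as dd
    import pandas as pd
    import pyarrow as pa

    from fondant.component import DaskLoadComponent
    from fondant.pipeline import Pipeline, lightweight_component

    pipeline = Pipeline(name="demo-pipeline", base_path="./data")

    # The decorator mixes PythonComponent into the class and records the
    # image settings, so no fondant_component.yaml or Dockerfile is needed.
    @lightweight_component(
        base_image="python:3.8-slim-buster",
        extra_requires=["pandas", "dask"],
    )
    class CreateData(DaskLoadComponent):
        def __init__(self, **kwargs):
            pass

        def load(self) -> dd.DataFrame:
            df = pd.DataFrame(
                {"x": [1, 2, 3], "y": [4, 5, 6]},
                index=pd.Index(["a", "b", "c"], name="id"),
            )
            return dd.from_pandas(df, npartitions=1)

    # The class itself is passed as `ref`; read() generates the spec from it.
    dataset = pipeline.read(
        ref=CreateData,
        produces={"x": pa.int32(), "y": pa.int32()},
    )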