diff --git a/src/fondant/pipeline/lightweight_component.py b/src/fondant/pipeline/lightweight_component.py index f7a29197..bbced0b3 100644 --- a/src/fondant/pipeline/lightweight_component.py +++ b/src/fondant/pipeline/lightweight_component.py @@ -8,7 +8,7 @@ import pyarrow as pa from fondant.component import BaseComponent, Component -from fondant.core.schema import Field, Type +from fondant.core.schema import Field @dataclass @@ -32,7 +32,7 @@ def image(cls) -> Image: raise NotImplementedError @classmethod - def consumes(cls) -> t.Optional[t.Union[list, str]]: + def consumes(cls) -> t.Optional[list]: pass @classmethod @@ -48,7 +48,7 @@ def lightweight_component( *args, extra_requires: t.Optional[t.List[str]] = None, base_image: t.Optional[str] = None, - consumes: t.Optional[t.Union[list, str]] = None, + consumes: t.Optional[list] = None, ): """Decorator to enable a python component.""" @@ -126,6 +126,41 @@ def validate_abstract_methods_are_implemented(cls): msg, ) + def modify_consumes_spec(apply_consumes, consumes_spec): + """Modify fields based on the consumes argument in the 'apply' method.""" + if apply_consumes: + for k, v in apply_consumes.items(): + if isinstance(v, str): + consumes_spec[k] = consumes_spec.pop(v) + elif isinstance(v, pa.DataType): + pass + else: + msg = ( + f"Invalid data type for field `{k}` in the `apply_consumes` " + f"argument. Only string and pa.DataType are allowed." + ) + raise ValueError( + msg, + ) + return consumes_spec + + def filter_consumes_spec(python_component_consumes, consumes_spec): + """Filter for values that are not in the user defined consumes list.""" + if python_component_consumes: + for field_to_consume in python_component_consumes: + if field_to_consume not in consumes_spec.keys(): + msg = f"Field `{field_to_consume}` is not available in the dataset." + raise ValueError( + msg, + ) + + consumes_spec = { + k: v + for k, v in consumes_spec.items() + if k in python_component_consumes + } + return consumes_spec + validate_abstract_methods_are_implemented(cls) base_component_cls = get_base_cls(cls) validate_signatures(base_component_cls, cls) @@ -138,7 +173,7 @@ def image(cls) -> Image: return image @classmethod - def consumes(cls) -> t.Optional[t.Union[list, str]]: + def consumes(cls) -> t.Optional[list]: return consumes @classmethod @@ -147,42 +182,19 @@ def get_consumes_spec( dataset_fields: t.Mapping[str, Field], apply_consumes: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]], ): - consumes = cls.consumes() - - if consumes == "generic": - return {"additionalProperties": True} + python_component_consumes = cls.consumes() # Get consumes spec from the dataset consumes_spec = {k: v.type.to_dict() for k, v in dataset_fields.items()} # Modify naming based on the consumes argument in the 'apply' method - if apply_consumes: - for k, v in apply_consumes.items(): - if isinstance(v, str): - consumes_spec[k] = consumes_spec.pop(v) - elif isinstance(v, pa.DataType): - consumes_spec[k] = Type(v).to_dict() - else: - msg = ( - f"Invalid data type for field `{k}` in the `apply_consumes` " - f"argument. Only string and pa.DataType are allowed." - ) - raise ValueError( - msg, - ) + consumes_spec = modify_consumes_spec(apply_consumes, consumes_spec) # Filter for values that are not in the user defined consumes list - if consumes: - for field_to_consume in consumes: - if field_to_consume not in consumes_spec.keys(): - msg = f"Field `{field_to_consume}` is not available in the dataset." - raise ValueError( - msg, - ) - - consumes_spec = { - k: v for k, v in consumes_spec.items() if k in consumes - } + consumes_spec = filter_consumes_spec( + python_component_consumes, + consumes_spec, + ) return consumes_spec diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 8fba1090..1b49e76c 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -71,7 +71,7 @@ def test_component_op( def test_component_op_python_component(default_pipeline_args): - @lightweight_component(consumes="generic") + @lightweight_component() class Foo(DaskLoadComponent): def load(self) -> dd.DataFrame: df = pd.DataFrame( diff --git a/tests/pipeline/test_python_component.py b/tests/pipeline/test_python_component.py index ecbc20e6..362bb5b8 100644 --- a/tests/pipeline/test_python_component.py +++ b/tests/pipeline/test_python_component.py @@ -24,7 +24,6 @@ def load_pipeline(): @lightweight_component( base_image="python:3.8-slim-buster", extra_requires=["pandas", "dask"], - consumes="generic", ) class CreateData(DaskLoadComponent): def load(self) -> dd.DataFrame: @@ -100,7 +99,7 @@ def test_lightweight_component_sdk(load_pipeline): }, } - @lightweight_component(consumes="generic") + @lightweight_component class AddN(PandasTransformComponent): def __init__(self, n: int, **kwargs): self.n = n @@ -112,7 +111,7 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: _ = dataset.apply( ref=AddN, produces={"x": pa.int32(), "y": pa.int32(), "z": pa.int32()}, - consumes={"x": pa.int32(), "y": pa.int32(), "z": pa.int32()}, + consumes=None, arguments={"n": 1}, ) assert len(pipeline._graph.keys()) == 1 + 1 @@ -123,15 +122,15 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: "name": "AddN", "image": "fondant:latest", "description": "python component", - "consumes": {"additionalProperties": True}, + "consumes": { + "x": {"type": "int32"}, + "y": {"type": "int32"}, + "z": {"type": "int32"}, + }, "produces": {"additionalProperties": True}, "args": {"n": {"type": "int"}}, }, - "consumes": { - "x": {"type": "int32"}, - "y": {"type": "int32"}, - "z": {"type": "int32"}, - }, + "consumes": {}, "produces": { "x": {"type": "int32"}, "y": {"type": "int32"},