diff --git a/examples/sample_pipeline/pipeline.py b/examples/sample_pipeline/pipeline.py index f984c129..eff75db7 100644 --- a/examples/sample_pipeline/pipeline.py +++ b/examples/sample_pipeline/pipeline.py @@ -50,6 +50,9 @@ ], ) class CalculateChunkLength(PandasTransformComponent): + def __init__(self, arg_x: bool, **kwargs): + self.arg_x = arg_x + def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: dataframe["chunk_length"] = dataframe["text"].apply(len) return dataframe @@ -59,4 +62,5 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: ref=CalculateChunkLength, consumes={"text": pa.string()}, produces={"chunk_length": pa.int32()}, + arguments={"arg_x": "value_x"}, ) diff --git a/src/fondant/core/component_spec.py b/src/fondant/core/component_spec.py index 452d92d1..64507ec8 100644 --- a/src/fondant/core/component_spec.py +++ b/src/fondant/core/component_spec.py @@ -55,6 +55,17 @@ def kubeflow_type(self) -> str: } return lookup[self.type] + def to_spec(self): + return { + k: v + for k, v in { + "description": self.description, + "type": self.type.__name__, + "default": self.default, + }.items() + if v is not None + } + class ComponentSpec: """ @@ -252,9 +263,9 @@ def args(self) -> t.Mapping[str, Argument]: { name: Argument( name=name, - description=arg_info["description"], + description=arg_info.get("description"), type=pydoc.locate(arg_info["type"]), # type: ignore - default=arg_info["default"] if "default" in arg_info else None, + default=arg_info.get("default"), optional=arg_info.get("default") == "None", ) for name, arg_info in self._specification.get("args", {}).items() diff --git a/src/fondant/pipeline/lightweight_component.py b/src/fondant/pipeline/lightweight_component.py index d44dd9ec..2f548b51 100644 --- a/src/fondant/pipeline/lightweight_component.py +++ b/src/fondant/pipeline/lightweight_component.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from functools import wraps -from fondant.component import Component +from fondant.component import BaseComponent, Component @dataclass @@ -20,7 +20,7 @@ def __post_init__(self): self.base_image = "fondant:latest" -class PythonComponent: +class PythonComponent(BaseComponent): @classmethod def image(cls) -> Image: raise NotImplementedError @@ -42,12 +42,12 @@ def wrapper(cls): # updated=() is needed to prevent an attempt to update the class's __dict__ @wraps(cls, updated=()) - class AppliedPythonComponent(cls, PythonComponent): + class PythonComponentOp(cls, PythonComponent): @classmethod def image(cls) -> Image: return image - return AppliedPythonComponent + return PythonComponentOp return wrapper diff --git a/src/fondant/pipeline/pipeline.py b/src/fondant/pipeline/pipeline.py index 629e42a3..18d16a11 100644 --- a/src/fondant/pipeline/pipeline.py +++ b/src/fondant/pipeline/pipeline.py @@ -23,6 +23,7 @@ from fondant.core.manifest import Manifest from fondant.core.schema import Field from fondant.pipeline import Image, PythonComponent +from fondant.pipeline.argument_inference import infer_arguments logger = logging.getLogger(__name__) @@ -222,6 +223,10 @@ def from_ref(cls, ref: t.Any, **kwargs) -> "ComponentOp": description=description, consumes={"additionalProperties": True}, produces={"additionalProperties": True}, + args={ + name: arg.to_spec() + for name, arg in infer_arguments(ref).items() + }, ) operation = cls( diff --git a/tests/pipeline/test_python_component.py b/tests/pipeline/test_python_component.py index c06cdba0..8f1c57f8 100644 --- a/tests/pipeline/test_python_component.py +++ b/tests/pipeline/test_python_component.py @@ -8,6 +8,7 @@ from fondant.component import DaskLoadComponent, PandasTransformComponent from fondant.core.exceptions import InvalidPythonComponent from fondant.pipeline import Pipeline, lightweight_component +from fondant.pipeline.compiler import DockerCompiler def test_build_python_script(): @@ -91,7 +92,7 @@ def load(self) -> dd.DataFrame: @lightweight_component() class AddN(PandasTransformComponent): - def __init__(self, n: int): + def __init__(self, n: int, **kwargs): self.n = n def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: @@ -102,6 +103,7 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: ref=AddN, produces={"x": pa.int32(), "y": pa.int32()}, consumes={"x": pa.int32(), "y": pa.int32()}, + arguments={"n": 1}, ) assert len(pipeline._graph.keys()) == 1 + 1 assert pipeline._graph["AddN"]["dependencies"] == ["CreateData"] @@ -113,12 +115,15 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: "description": "python component", "consumes": {"additionalProperties": True}, "produces": {"additionalProperties": True}, + "args": {"n": {"type": "int"}}, }, "consumes": {"x": {"type": "int32"}, "y": {"type": "int32"}}, "produces": {"x": {"type": "int32"}, "y": {"type": "int32"}}, } pipeline._validate_pipeline_definition(run_id="dummy-run-id") + DockerCompiler().compile(pipeline) + def test_lightweight_component_missing_decorator(): pipeline = Pipeline(