Skip to content

Commit

Permalink
Integrate argument inference (#788)
Browse files Browse the repository at this point in the history
Follow-up to use the argument inference functionality added in #763
  • Loading branch information
RobbeSneyders authored Jan 18, 2024
1 parent 538aa63 commit 71a8f72
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 7 deletions.
4 changes: 4 additions & 0 deletions examples/sample_pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@
],
)
class CalculateChunkLength(PandasTransformComponent):
def __init__(self, arg_x: bool, **kwargs):
self.arg_x = arg_x

def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
dataframe["chunk_length"] = dataframe["text"].apply(len)
return dataframe
Expand All @@ -59,4 +62,5 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
ref=CalculateChunkLength,
consumes={"text": pa.string()},
produces={"chunk_length": pa.int32()},
arguments={"arg_x": "value_x"},
)
15 changes: 13 additions & 2 deletions src/fondant/core/component_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@ def kubeflow_type(self) -> str:
}
return lookup[self.type]

def to_spec(self):
return {
k: v
for k, v in {
"description": self.description,
"type": self.type.__name__,
"default": self.default,
}.items()
if v is not None
}


class ComponentSpec:
"""
Expand Down Expand Up @@ -252,9 +263,9 @@ def args(self) -> t.Mapping[str, Argument]:
{
name: Argument(
name=name,
description=arg_info["description"],
description=arg_info.get("description"),
type=pydoc.locate(arg_info["type"]), # type: ignore
default=arg_info["default"] if "default" in arg_info else None,
default=arg_info.get("default"),
optional=arg_info.get("default") == "None",
)
for name, arg_info in self._specification.get("args", {}).items()
Expand Down
8 changes: 4 additions & 4 deletions src/fondant/pipeline/lightweight_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass
from functools import wraps

from fondant.component import Component
from fondant.component import BaseComponent, Component


@dataclass
Expand All @@ -20,7 +20,7 @@ def __post_init__(self):
self.base_image = "fondant:latest"


class PythonComponent:
class PythonComponent(BaseComponent):
@classmethod
def image(cls) -> Image:
raise NotImplementedError
Expand All @@ -42,12 +42,12 @@ def wrapper(cls):

# updated=() is needed to prevent an attempt to update the class's __dict__
@wraps(cls, updated=())
class AppliedPythonComponent(cls, PythonComponent):
class PythonComponentOp(cls, PythonComponent):
@classmethod
def image(cls) -> Image:
return image

return AppliedPythonComponent
return PythonComponentOp

return wrapper

Expand Down
5 changes: 5 additions & 0 deletions src/fondant/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from fondant.core.manifest import Manifest
from fondant.core.schema import Field
from fondant.pipeline import Image, PythonComponent
from fondant.pipeline.argument_inference import infer_arguments

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -222,6 +223,10 @@ def from_ref(cls, ref: t.Any, **kwargs) -> "ComponentOp":
description=description,
consumes={"additionalProperties": True},
produces={"additionalProperties": True},
args={
name: arg.to_spec()
for name, arg in infer_arguments(ref).items()
},
)

operation = cls(
Expand Down
7 changes: 6 additions & 1 deletion tests/pipeline/test_python_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fondant.component import DaskLoadComponent, PandasTransformComponent
from fondant.core.exceptions import InvalidPythonComponent
from fondant.pipeline import Pipeline, lightweight_component
from fondant.pipeline.compiler import DockerCompiler


def test_build_python_script():
Expand Down Expand Up @@ -91,7 +92,7 @@ def load(self) -> dd.DataFrame:

@lightweight_component()
class AddN(PandasTransformComponent):
def __init__(self, n: int):
def __init__(self, n: int, **kwargs):
self.n = n

def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
Expand All @@ -102,6 +103,7 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
ref=AddN,
produces={"x": pa.int32(), "y": pa.int32()},
consumes={"x": pa.int32(), "y": pa.int32()},
arguments={"n": 1},
)
assert len(pipeline._graph.keys()) == 1 + 1
assert pipeline._graph["AddN"]["dependencies"] == ["CreateData"]
Expand All @@ -113,12 +115,15 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
"description": "python component",
"consumes": {"additionalProperties": True},
"produces": {"additionalProperties": True},
"args": {"n": {"type": "int"}},
},
"consumes": {"x": {"type": "int32"}, "y": {"type": "int32"}},
"produces": {"x": {"type": "int32"}, "y": {"type": "int32"}},
}
pipeline._validate_pipeline_definition(run_id="dummy-run-id")

DockerCompiler().compile(pipeline)


def test_lightweight_component_missing_decorator():
pipeline = Pipeline(
Expand Down

0 comments on commit 71a8f72

Please sign in to comment.