Start from dataset schema for lightweight python component consumes #789

Merged: 22 commits, Jan 30, 2024
Changes shown below are from 4 of the 22 commits.

Commits (22):
7a6b78a  Integrate argument inference (RobbeSneyders, Jan 17, 2024)
c007244  Add compilation to python component test (RobbeSneyders, Jan 17, 2024)
e52a5c4  Add argument inference to integration test (RobbeSneyders, Jan 17, 2024)
161f214  Start from dataset schema for python component consumes (RobbeSneyders, Jan 17, 2024)
66a9103  add option to define consumes in mapping (PhilippeMoussalli, Jan 22, 2024)
2e10af1  add option to define consumes and generic in mapping (PhilippeMoussalli, Jan 22, 2024)
e8d763f  Merge branch 'feature/python-consumes-mapping-3' into feature/python-… (PhilippeMoussalli, Jan 22, 2024)
6619b3a  small fixes (PhilippeMoussalli, Jan 22, 2024)
d898e4a  make lightweight consumes generic by default (PhilippeMoussalli, Jan 23, 2024)
2d80a77  Merge branch 'main' into feature/python-consumes-mapping (PhilippeMoussalli, Jan 23, 2024)
cef482a  revert to desired behaviour (PhilippeMoussalli, Jan 23, 2024)
8c9d154  update sample pipeline (PhilippeMoussalli, Jan 23, 2024)
4c97282  update based on feedback (PhilippeMoussalli, Jan 23, 2024)
3ab1bae  implement PR feedback (PhilippeMoussalli, Jan 25, 2024)
b59fb8c  add docstrings (PhilippeMoussalli, Jan 25, 2024)
de5a3c1  update consumes based on new proposal (PhilippeMoussalli, Jan 30, 2024)
3943c4b  Merge branch 'main' into feature/python-consumes-mapping (PhilippeMoussalli, Jan 30, 2024)
d8e5563  Update src/fondant/pipeline/lightweight_component.py (PhilippeMoussalli, Jan 30, 2024)
85f0994  enable default behavior of passing all dataset fields (PhilippeMoussalli, Jan 30, 2024)
5b69298  implement PR feedback (PhilippeMoussalli, Jan 30, 2024)
12c6f37  Merge branch 'main' into feature/python-consumes-mapping (GeorgesLorre, Jan 30, 2024)
60dc6f6  Merge branch 'main' into feature/python-consumes-mapping (PhilippeMoussalli, Jan 30, 2024)
examples/sample_pipeline/pipeline.py (6 additions, 2 deletions)
@@ -50,13 +50,17 @@
     ],
 )
 class CalculateChunkLength(PandasTransformComponent):
+    def __init__(self, feature_name: str, **kwargs):
+        self.feature_name = feature_name
+
     def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
-        dataframe["chunk_length"] = dataframe["text"].apply(len)
+        dataframe[self.feature_name] = dataframe["chunk"].apply(len)
         return dataframe


 _ = dataset.apply(
     ref=CalculateChunkLength,
-    consumes={"text": pa.string()},
+    consumes={"chunk": "text"},
     produces={"chunk_length": pa.int32()},
+    arguments={"feature_name": "chunk_length"},
 )
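
Note on the mapping direction (inferred from the diff above, not stated elsewhere in the PR): the keys of `consumes` are the column names the component works with, and the values are the dataset fields they are read from. A small standalone pandas sketch of that renaming, with illustrative data that is not part of this PR:

```python
import pandas as pd

# Dataset-side data: the pipeline exposes a field called "text".
dataset_df = pd.DataFrame({"text": ["hello", "fondant"]})

# consumes={"chunk": "text"}: the component wants a column named "chunk",
# fed from the dataset field "text".
consumes = {"chunk": "text"}
component_view = dataset_df.rename(columns={v: k for k, v in consumes.items()})

# The component can now run dataframe["chunk"].apply(len), as in the sample
# pipeline above, without knowing the dataset-side field name.
component_view["chunk_length"] = component_view["chunk"].apply(len)
print(component_view)
```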
src/fondant/core/component_spec.py (15 additions, 4 deletions)
@@ -55,6 +55,17 @@ def kubeflow_type(self) -> str:
         }
         return lookup[self.type]

+    def to_spec(self):
+        return {
+            k: v
+            for k, v in {
+                "description": self.description,
+                "type": self.type.__name__,
+                "default": self.default,
+            }.items()
+            if v is not None
+        }
+

 class ComponentSpec:
     """
@@ -83,8 +94,8 @@ def __init__(
         image: str,
         *,
         description: t.Optional[str] = None,
-        consumes: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]] = None,
-        produces: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]] = None,
+        consumes: t.Optional[t.Mapping[str, t.Union[str, pa.DataType]]] = None,
+        produces: t.Optional[t.Mapping[str, t.Union[str, pa.DataType]]] = None,
         previous_index: t.Optional[str] = None,
         args: t.Optional[t.Dict[str, t.Any]] = None,
         tags: t.Optional[t.List[str]] = None,
@@ -252,9 +263,9 @@ def args(self) -> t.Mapping[str, Argument]:
             {
                 name: Argument(
                     name=name,
-                    description=arg_info["description"],
+                    description=arg_info.get("description"),
                     type=pydoc.locate(arg_info["type"]),  # type: ignore
-                    default=arg_info["default"] if "default" in arg_info else None,
+                    default=arg_info.get("default"),
                     optional=arg_info.get("default") == "None",
                 )
                 for name, arg_info in self._specification.get("args", {}).items()
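
The new `to_spec()` together with the `.get()` lookups makes `description` and `default` optional when arguments are round-tripped through the component spec. A minimal standalone sketch of the pattern (a stripped-down stand-in, not fondant's actual `Argument` class):

```python
import typing as t
from dataclasses import dataclass


@dataclass
class Argument:
    """Stripped-down stand-in for the spec argument class."""

    name: str
    type: type
    description: t.Optional[str] = None
    default: t.Any = None

    def to_spec(self) -> dict:
        # Emit only the keys that are set, mirroring the `if v is not None`
        # filter added in this PR.
        return {
            k: v
            for k, v in {
                "description": self.description,
                "type": self.type.__name__,
                "default": self.default,
            }.items()
            if v is not None
        }


print(Argument(name="n", type=int).to_spec())
# {'type': 'int'}: no description/default keys, so the spec side can safely
# read them back with arg_info.get("description") / arg_info.get("default").
```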
src/fondant/pipeline/lightweight_component.py (4 additions, 4 deletions)
@@ -5,7 +5,7 @@
 from dataclasses import dataclass
 from functools import wraps

-from fondant.component import Component
+from fondant.component import BaseComponent, Component


 @dataclass
@@ -20,7 +20,7 @@ def __post_init__(self):
             self.base_image = "fondant:latest"


-class PythonComponent:
+class PythonComponent(BaseComponent):
     @classmethod
     def image(cls) -> Image:
         raise NotImplementedError
@@ -42,12 +42,12 @@ def wrapper(cls):

         # updated=() is needed to prevent an attempt to update the class's __dict__
         @wraps(cls, updated=())
-        class AppliedPythonComponent(cls, PythonComponent):
+        class PythonComponentOp(cls, PythonComponent):
             @classmethod
             def image(cls) -> Image:
                 return image

-        return AppliedPythonComponent
+        return PythonComponentOp

     return wrapper
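
The rename from `AppliedPythonComponent` to `PythonComponentOp` does not change the mechanics, so for context, here is a minimal standalone sketch of the wrapping pattern the decorator uses (simplified: no `extra_requires`, no `BaseComponent`; the `Image` stand-in is illustrative, not fondant's class):

```python
from functools import wraps


class Image:
    """Illustrative stand-in for fondant's Image dataclass."""

    def __init__(self, base_image: str):
        self.base_image = base_image


def lightweight_component(base_image: str):
    image = Image(base_image=base_image)

    def wrapper(cls):
        # updated=() stops wraps() from trying to merge cls.__dict__ into the
        # subclass's __dict__, which is a read-only mappingproxy on classes.
        @wraps(cls, updated=())
        class PythonComponentOp(cls):
            @classmethod
            def image(cls) -> Image:
                return image

        return PythonComponentOp

    return wrapper


@lightweight_component(base_image="python:3.10-slim")
class MyComponent:
    """A user-defined component body would go here."""


print(MyComponent.__name__)            # MyComponent, thanks to wraps()
print(MyComponent.image().base_image)  # python:3.10-slim
```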
src/fondant/pipeline/pipeline.py (16 additions, 3 deletions)
@@ -22,6 +22,7 @@
 from fondant.core.manifest import Manifest
 from fondant.core.schema import Field
 from fondant.pipeline import Image, PythonComponent
+from fondant.pipeline.argument_inference import infer_arguments

 logger = logging.getLogger(__name__)

@@ -395,8 +396,8 @@ def read(
             name,
             image.base_image,  # TODO: revisit
             description=description,
             consumes={"additionalProperties": True},
             produces={"additionalProperties": True},
+            args={k: v.to_spec() for k, v in infer_arguments(ref).items()},
         )

         operation = ComponentOp(
@@ -702,12 +703,18 @@ def apply(
         image = ref.image()
         description = ref.__doc__ or "python component"

+        consumes_spec = {k: v.type.to_json() for k, v in self.fields.items()}
+        if consumes:
+            for k, v in consumes.items():
+                consumes_spec[k] = consumes_spec[v]
+
         component_spec = ComponentSpec(
             name,
             image.base_image,  # TODO: revisit
             description=description,
-            consumes={"additionalProperties": True},
+            consumes=consumes_spec,
             produces={"additionalProperties": True},
+            args={k: v.to_spec() for k, v in infer_arguments(ref).items()},
         )

         operation = ComponentOp(

Review comment from the PR author on the consumes_spec block above: "This needs more checks and might have to be moved to a better place. Just wanted to get a PoC working."

operation = ComponentOp(
Expand Down Expand Up @@ -777,12 +784,18 @@ def write(
image = ref.image()
description = ref.__doc__ or "python component"

consumes_spec = {k: v.type.to_json() for k, v in self.fields.items()}
if consumes:
for k, v in consumes.items():
consumes_spec[k] = consumes_spec[v]

component_spec = ComponentSpec(
name,
image.base_image, # TODO: revisit
description=description,
consumes={"additionalProperties": True},
consumes=consumes_spec,
produces={"additionalProperties": True},
args={k: v.to_spec() for k, v in infer_arguments(ref).items()},
)

operation = ComponentOp(
Expand Down
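
Two behaviours come together in `apply` and `write`: the spec's `consumes` section now starts from the dataset schema (every dataset field, with the user-supplied `consumes` dict remapping component-side names onto dataset fields), and component arguments are inferred from the component class rather than declared by hand. The sketch below restates both with plain dicts and `inspect`; `infer_arguments` lives in `fondant.pipeline.argument_inference` and its implementation is not part of this diff, so the signature-based inference shown here is an assumption about its shape:

```python
import inspect
import typing as t


def build_consumes_spec(
    dataset_fields: t.Dict[str, str],
    consumes: t.Optional[t.Dict[str, str]] = None,
) -> t.Dict[str, str]:
    # Start from the dataset schema: every field is passed through by default.
    consumes_spec = dict(dataset_fields)
    if consumes:
        # Remap: key = name the component consumes, value = dataset field.
        for k, v in consumes.items():
            consumes_spec[k] = consumes_spec[v]
    return consumes_spec


def infer_arguments(cls) -> t.Dict[str, dict]:
    """Sketch of signature-based inference: every __init__ parameter except
    self and **kwargs becomes a spec argument."""
    args = {}
    for name, param in inspect.signature(cls.__init__).parameters.items():
        if name == "self" or param.kind is inspect.Parameter.VAR_KEYWORD:
            continue
        spec = {"type": param.annotation.__name__}
        if param.default is not inspect.Parameter.empty:
            spec["default"] = param.default
        args[name] = spec
    return args


class AddN:
    def __init__(self, n: int, **kwargs):
        self.n = n


# Dataset has fields x and y; the component reads "a" from dataset field "x".
print(build_consumes_spec({"x": "int32", "y": "int32"}, consumes={"a": "x"}))
# {'x': 'int32', 'y': 'int32', 'a': 'int32'}
print(infer_arguments(AddN))
# {'n': {'type': 'int'}}
```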
tests/pipeline/test_python_component.py (68 additions, 5 deletions)
@@ -5,6 +5,7 @@
 import pyarrow as pa
 from fondant.component import DaskLoadComponent, PandasTransformComponent
 from fondant.pipeline import Pipeline, lightweight_component
+from fondant.pipeline.compiler import DockerCompiler


 def test_build_python_script():
@@ -46,15 +47,17 @@ def load(self) -> dd.DataFrame:
     )


-def test_lightweight_component(tmp_path_factory):
+def test_compile_lightweight_component(tmp_path_factory):
     pipeline = Pipeline(
         name="dummy-pipeline",
         base_path="./data",
     )

     @lightweight_component(
-        base_image="python:3.8-slim-buster",
-        extra_requires=["pandas", "dask"],
+        base_image="python:3.8",
+        extra_requires=[
+            "fondant[component]@git+https://github.com/ml6team/fondant@main",
+        ],
     )
     class CreateData(DaskLoadComponent):
         def load(self) -> dd.DataFrame:
@@ -72,9 +75,14 @@ def load(self) -> dd.DataFrame:
         produces={"x": pa.int32(), "y": pa.int32()},
     )

-    @lightweight_component()
+    @lightweight_component(
+        base_image="python:3.8",
+        extra_requires=[
+            "fondant[component]@git+https://github.com/ml6team/fondant@main",
+        ],
+    )
     class AddN(PandasTransformComponent):
-        def __init__(self, n: int):
+        def __init__(self, n: int, **kwargs):
             self.n = n

         def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
@@ -85,4 +93,59 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
         ref=AddN,
         produces={"x": pa.int32(), "y": pa.int32()},
         consumes={"x": pa.int32(), "y": pa.int32()},
+        arguments={"n": 1},
     )
+
+    DockerCompiler().compile(pipeline)
+
+
+def test_consumes_mapping():
+    pipeline = Pipeline(
+        name="dummy-pipeline",
+        base_path="./data",
+    )
+
+    @lightweight_component(
+        base_image="python:3.8",
+        extra_requires=[
+            "fondant[component]@git+https://github.com/ml6team/fondant@main",
+        ],
+    )
+    class CreateData(DaskLoadComponent):
+        def load(self) -> dd.DataFrame:
+            df = pd.DataFrame(
+                {
+                    "x": [1, 2, 3],
+                    "y": [4, 5, 6],
+                },
+                index=pd.Index(["a", "b", "c"], name="id"),
+            )
+            return dd.from_pandas(df, npartitions=1)
+
+    dataset = pipeline.read(
+        ref=CreateData,
+        produces={"x": pa.int32(), "y": pa.int32()},
+    )
+
+    @lightweight_component(
+        base_image="python:3.8",
+        extra_requires=[
+            "fondant[component]@git+https://github.com/ml6team/fondant@main",
+        ],
+    )
+    class AddN(PandasTransformComponent):
+        def __init__(self, n: int, **kwargs):
+            self.n = n
+
+        def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
+            dataframe["a"] = dataframe["a"].map(lambda x: x + self.n)
+            return dataframe
+
+    _ = dataset.apply(
+        ref=AddN,
+        consumes={"a": "x"},
+        produces={"a": pa.int32()},
+        arguments={"n": 1},
+    )
+
+    DockerCompiler().compile(pipeline)