Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validation for lightweight components #793

Merged
merged 9 commits into from
Jan 22, 2024
18 changes: 14 additions & 4 deletions src/fondant/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ class DaskLoadComponent(BaseComponent):
"""Component that loads data and returns a Dask DataFrame."""

def load(self) -> dd.DataFrame:
raise NotImplementedError
msg = "Implementation of the `load` method is required in the DaskLoadComponent class."
raise NotImplementedError(msg)


class DaskTransformComponent(BaseComponent):
Expand All @@ -48,14 +49,19 @@ def transform(self, dataframe: dd.DataFrame) -> dd.DataFrame:
dataframe: A Dask dataframe containing the data specified in the `consumes` section
of the component specification
"""
raise NotImplementedError
msg = (
"Implementation of the `transform` method is required in the DaskTransformComponent "
"class."
)
raise NotImplementedError(msg)


class DaskWriteComponent(BaseComponent):
"""Component that accepts a Dask DataFrame and writes its contents."""

def write(self, dataframe: dd.DataFrame) -> None:
raise NotImplementedError
msg = "Implementation of the `write` method is required in the DaskWriteComponent class."
raise NotImplementedError(msg)


class PandasTransformComponent(BaseComponent):
Expand All @@ -71,7 +77,11 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
Args:
dataframe: A Pandas dataframe containing a partition of the data
"""
raise NotImplementedError
msg = (
"Implementation of the `transform` method is required in "
"the PandasTransformComponent class."
)
raise NotImplementedError(msg)


Component = t.TypeVar("Component", bound=BaseComponent)
Expand Down
21 changes: 21 additions & 0 deletions src/fondant/pipeline/lightweight_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,27 @@ def wrapper(cls):
script=script,
)

base_component_cls = cls.__bases__[0]
mrchtr marked this conversation as resolved.
Show resolved Hide resolved

for function_name in dir(cls):
if not function_name.startswith("__") and function_name in dir(
base_component_cls,
):
type_cls_implementation = inspect.signature(
getattr(cls, function_name, None),
).return_annotation
type_base_cls = inspect.signature(
getattr(base_component_cls, function_name, None),
).return_annotation
if type_cls_implementation != type_base_cls:
msg = (
f"Invalid function definition of function {function_name}. The return type "
f"has to be {type_base_cls}"
)
raise ValueError(
msg,
)

mrchtr marked this conversation as resolved.
Show resolved Hide resolved
# updated=() is needed to prevent an attempt to update the class's __dict__
@wraps(cls, updated=())
class AppliedPythonComponent(cls, PythonComponent):
Expand Down
40 changes: 39 additions & 1 deletion tests/pipeline/test_python_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import dask.dataframe as dd
import pandas as pd
import pyarrow as pa
from fondant.component import DaskLoadComponent, PandasTransformComponent
import pytest
from fondant.component import (
DaskLoadComponent,
PandasTransformComponent,
)
from fondant.pipeline import Pipeline, lightweight_component


Expand Down Expand Up @@ -86,3 +90,37 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
produces={"x": pa.int32(), "y": pa.int32()},
consumes={"x": pa.int32(), "y": pa.int32()},
)


def test_valid_load_component():
    """A load component whose `load` is annotated `-> dd.DataFrame` can be decorated and built."""

    @lightweight_component(base_image="python:3.8-slim-buster")
    class CreateData(DaskLoadComponent):
        def load(self) -> dd.DataFrame:
            # Small fixed frame keyed by an "id" index.
            row_index = pd.Index(["a", "b", "c"], name="id")
            columns = {
                "x": [1, 2, 3],
                "y": [4, 5, 6],
            }
            frame = pd.DataFrame(columns, index=row_index)
            return dd.from_pandas(frame, npartitions=1)

    # Decorating and instantiating must both complete without raising.
    CreateData()


def test_invalid_load_component():
    """A `load` override whose return annotation differs from the base class is rejected.

    The expected error is a `ValueError` naming the offending function and the
    return type required by the base component class.
    """
    # The return-type check appears to run when the decorator is applied, so the
    # decorated class definition has to live inside the `pytest.raises` block;
    # keeping the instantiation inside as well covers the case where the check
    # is deferred to construction.
    with pytest.raises(
        ValueError,
        # `match` is a regular expression: keep the space between the two
        # sentences of the message and escape the literal dots.
        match=(
            r"Invalid function definition of function load\. "
            r"The return type has to be <class 'dask\.dataframe\.core\.DataFrame'>"
        ),
    ):

        @lightweight_component(
            base_image="python:3.8-slim-buster",
        )
        class CreateData(DaskLoadComponent):
            def load(self) -> int:
                return 1

        CreateData()
Loading