diff --git a/src/fondant/component/component.py b/src/fondant/component/component.py index 7d9a8611..08b9d1b9 100644 --- a/src/fondant/component/component.py +++ b/src/fondant/component/component.py @@ -1,6 +1,7 @@ """This module defines interfaces which components should implement to be executed by fondant.""" import typing as t +from abc import abstractmethod import dask.dataframe as dd import pandas as pd @@ -33,13 +34,15 @@ def teardown(self) -> None: class DaskLoadComponent(BaseComponent): """Component that loads data and returns a Dask DataFrame.""" + @abstractmethod def load(self) -> dd.DataFrame: - raise NotImplementedError + pass class DaskTransformComponent(BaseComponent): """Component that transforms an incoming Dask DataFrame.""" + @abstractmethod def transform(self, dataframe: dd.DataFrame) -> dd.DataFrame: """ Abstract method for applying data transformations to the input dataframe. @@ -48,14 +51,14 @@ def transform(self, dataframe: dd.DataFrame) -> dd.DataFrame: dataframe: A Dask dataframe containing the data specified in the `consumes` section of the component specification """ - raise NotImplementedError class DaskWriteComponent(BaseComponent): """Component that accepts a Dask DataFrame and writes its contents.""" + @abstractmethod def write(self, dataframe: dd.DataFrame) -> None: - raise NotImplementedError + pass class PandasTransformComponent(BaseComponent): @@ -63,6 +66,7 @@ class PandasTransformComponent(BaseComponent): DataFrame. """ + @abstractmethod def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: """ Abstract method for applying data transformations to the input dataframe. @@ -71,7 +75,6 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: Args: dataframe: A Pandas dataframe containing a partition of the data """ - raise NotImplementedError Component = t.TypeVar("Component", bound=BaseComponent) diff --git a/src/fondant/pipeline/lightweight_component.py b/src/fondant/pipeline/lightweight_component.py index d44dd9ec..cd0755c0 100644 --- a/src/fondant/pipeline/lightweight_component.py +++ b/src/fondant/pipeline/lightweight_component.py @@ -27,6 +27,7 @@ def image(cls) -> Image: def lightweight_component( + *args, extra_requires: t.Optional[t.List[str]] = None, base_image: t.Optional[str] = None, ): @@ -40,6 +41,76 @@ def wrapper(cls): script=script, ) + def get_base_cls(cls): + """ + Returns the BaseComponent. If the implementation inherits from several classes, + the Fondant base class is selected. If more than one Fondant base component + is implemented, an exception is raised. + """ + base_component_module = inspect.getmodule(Component).__name__ + base_component_cls_list = [ + base + for base in cls.__bases__ + if base.__module__ == base_component_module + ] + if len(base_component_cls_list) > 1: + msg = ( + f"Multiple base classes detected. Only one component should be inherited or" + f" implemented." + f"Found classes: {', '.join([cls.__name__ for cls in base_component_cls_list])}" + ) + raise ValueError( + msg, + ) + return base_component_cls_list[0] + + def validate_signatures(base_component_cls, cls_implementation): + """ + Compare the signature of overridden methods in a class with their counterparts + in the BaseComponent classes. + """ + for function_name in dir(cls_implementation): + if not function_name.startswith("__") and function_name in dir( + base_component_cls, + ): + type_cls_implementation = inspect.signature( + getattr(cls_implementation, function_name, None), + ) + type_base_cls = inspect.signature( + getattr(base_component_cls, function_name, None), + ) + if type_cls_implementation != type_base_cls: + msg = ( + f"Invalid function definition of function {function_name}. " + f"The expected function signature is {type_base_cls}" + ) + raise ValueError( + msg, + ) + + def validate_abstract_methods_are_implemented(cls): + """ + Function to validate that a class has overridden every required function marked as + abstract. + """ + abstract_methods = [ + name + for name, value in inspect.getmembers(cls) + if getattr(value, "__isabstractmethod__", False) + ] + if len(abstract_methods) >= 1: + msg = ( + f"Every required function must be overridden in the PythonComponent. " + f"Missing implementations for the following functions: {abstract_methods}" + ) + raise ValueError( + msg, + ) + + validate_abstract_methods_are_implemented(cls) + base_component_cls = get_base_cls(cls) + validate_signatures(base_component_cls, cls) + # updated=() is needed to prevent an attempt to update the class's __dict__ @wraps(cls, updated=()) class AppliedPythonComponent(cls, PythonComponent): @@ -49,6 +120,10 @@ def image(cls) -> Image: return AppliedPythonComponent + # Call wrapper with function (`args[0]`) when no additional arguments were passed + if args: + return wrapper(args[0]) + return wrapper diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 5a553557..1b49e76c 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -2,6 +2,8 @@ import copy from pathlib import Path +import dask.dataframe as dd +import pandas as pd import pyarrow as pa import pytest import yaml @@ -71,8 +73,15 @@ def test_component_op( def test_component_op_python_component(default_pipeline_args): @lightweight_component() class Foo(DaskLoadComponent): - def load(self) -> str: - return ["bar"] + def load(self) -> dd.DataFrame: + df = pd.DataFrame( + { + "x": [1, 2, 3], + "y": [4, 5, 6], + }, + index=pd.Index(["a", "b", "c"], name="id"), + ) + return dd.from_pandas(df, npartitions=1) component = ComponentOp.from_ref(Foo, produces={"bar": pa.string()}) assert component.component_spec._specification == { diff --git a/tests/pipeline/test_python_component.py b/tests/pipeline/test_python_component.py index c06cdba0..572f3b5d 100644 --- a/tests/pipeline/test_python_component.py +++ b/tests/pipeline/test_python_component.py @@ -1,4 +1,5 @@ import json +import re import textwrap import dask.dataframe as dd @@ -103,6 +104,7 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: produces={"x": pa.int32(), "y": pa.int32()}, consumes={"x": pa.int32(), "y": pa.int32()}, ) + assert len(pipeline._graph.keys()) == 1 + 1 assert pipeline._graph["AddN"]["dependencies"] == ["CreateData"] operation_spec = pipeline._graph["AddN"]["operation"].operation_spec.to_json() @@ -135,3 +137,108 @@ def load(self) -> str: ref=Foo, produces={"x": pa.int32(), "y": pa.int32()}, ) + + +def test_valid_load_component(): + @lightweight_component( + base_image="python:3.8-slim-buster", + ) + class CreateData(DaskLoadComponent): + def load(self) -> dd.DataFrame: + df = pd.DataFrame( + { + "x": [1, 2, 3], + "y": [4, 5, 6], + }, + index=pd.Index(["a", "b", "c"], name="id"), + ) + return dd.from_pandas(df, npartitions=1) + + CreateData(produces={}, consumes={}) + + +def test_invalid_load_component(): + with pytest.raises( # noqa: PT012 + ValueError, + match="Every required function must be overridden in the PythonComponent. " + "Missing implementations for the following functions: \\['load'\\]", + ): + + @lightweight_component( + base_image="python:3.8-slim-buster", + ) + class CreateData(DaskLoadComponent): + def custom_load(self) -> int: + return 1 + + CreateData(produces={}, consumes={}) + + +def test_invalid_load_transform_component(): + with pytest.raises( # noqa: PT012 + ValueError, + match="Multiple base classes detected. Only one component should be inherited " + "or implemented.Found classes: DaskLoadComponent, PandasTransformComponent", + ): + + @lightweight_component( + base_image="python:3.8-slim-buster", + ) + class CreateData(DaskLoadComponent, PandasTransformComponent): + def load(self) -> dd.DataFrame: + pass + + def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: + pass + + CreateData(produces={}, consumes={}) + + +def test_invalid_load_component_wrong_return_type(): + with pytest.raises( # noqa: PT012 + ValueError, + match=re.escape( + "Invalid function definition of function load. " + "The expected function signature " + "is (self) -> dask.dataframe.core.DataFrame", + ), + ): + + @lightweight_component( + base_image="python:3.8-slim-buster", + ) + class CreateData(DaskLoadComponent): + def load(self) -> int: + return 1 + + CreateData(produces={}, consumes={}) + + +def test_lightweight_component_decorator_without_parentheses(): + @lightweight_component + class CreateData(DaskLoadComponent): + def load(self) -> dd.DataFrame: + return None + + pipeline = Pipeline( + name="dummy-pipeline", + base_path="./data", + ) + + pipeline.read( + ref=CreateData, + ) + + assert len(pipeline._graph.keys()) == 1 + operation_spec = pipeline._graph["CreateData"]["operation"].operation_spec.to_json() + assert json.loads(operation_spec) == { + "specification": { + "name": "CreateData", + "image": "fondant:latest", + "description": "python component", + "consumes": {"additionalProperties": True}, + "produces": {"additionalProperties": True}, + }, + "consumes": {}, + "produces": {}, + }