Skip to content

Commit

Permalink
Validation for lightweight components (#793)
Browse files Browse the repository at this point in the history
Adding validation to the PythonComponent. 
- change BaseComponents methods to abstractmethods
- validate that all abstractmethods were implemented in the
PythonComponent
- validate that the function interface is the same as in the
BaseComponents

Fix #783
  • Loading branch information
mrchtr authored Jan 22, 2024
1 parent 9c4e3b1 commit f07ea5b
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 6 deletions.
11 changes: 7 additions & 4 deletions src/fondant/component/component.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""This module defines interfaces which components should implement to be executed by fondant."""

import typing as t
from abc import abstractmethod

import dask.dataframe as dd
import pandas as pd
Expand Down Expand Up @@ -33,13 +34,15 @@ def teardown(self) -> None:
class DaskLoadComponent(BaseComponent):
"""Component that loads data and returns a Dask DataFrame."""

@abstractmethod
def load(self) -> dd.DataFrame:
raise NotImplementedError
pass


class DaskTransformComponent(BaseComponent):
"""Component that transforms an incoming Dask DataFrame."""

@abstractmethod
def transform(self, dataframe: dd.DataFrame) -> dd.DataFrame:
"""
Abstract method for applying data transformations to the input dataframe.
Expand All @@ -48,21 +51,22 @@ def transform(self, dataframe: dd.DataFrame) -> dd.DataFrame:
dataframe: A Dask dataframe containing the data specified in the `consumes` section
of the component specification
"""
raise NotImplementedError


class DaskWriteComponent(BaseComponent):
"""Component that accepts a Dask DataFrame and writes its contents."""

@abstractmethod
def write(self, dataframe: dd.DataFrame) -> None:
raise NotImplementedError
pass


class PandasTransformComponent(BaseComponent):
"""Component that transforms the incoming dataset partition per partition as a pandas
DataFrame.
"""

@abstractmethod
def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
"""
Abstract method for applying data transformations to the input dataframe.
Expand All @@ -71,7 +75,6 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
Args:
dataframe: A Pandas dataframe containing a partition of the data
"""
raise NotImplementedError


Component = t.TypeVar("Component", bound=BaseComponent)
Expand Down
75 changes: 75 additions & 0 deletions src/fondant/pipeline/lightweight_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def image(cls) -> Image:


def lightweight_component(
*args,
extra_requires: t.Optional[t.List[str]] = None,
base_image: t.Optional[str] = None,
):
Expand All @@ -40,6 +41,76 @@ def wrapper(cls):
script=script,
)

def get_base_cls(cls):
"""
Returns the BaseComponent. If the implementation inherits from several classes,
the Fondant base class is selected. If more than one Fondant base component
is implemented, an exception is raised.
"""
base_component_module = inspect.getmodule(Component).__name__
base_component_cls_list = [
base
for base in cls.__bases__
if base.__module__ == base_component_module
]
if len(base_component_cls_list) > 1:
msg = (
f"Multiple base classes detected. Only one component should be inherited or"
f" implemented."
f"Found classes: {', '.join([cls.__name__ for cls in base_component_cls_list])}"
)
raise ValueError(
msg,
)
return base_component_cls_list[0]

def validate_signatures(base_component_cls, cls_implementation):
"""
Compare the signature of overridden methods in a class with their counterparts
in the BaseComponent classes.
"""
for function_name in dir(cls_implementation):
if not function_name.startswith("__") and function_name in dir(
base_component_cls,
):
type_cls_implementation = inspect.signature(
getattr(cls_implementation, function_name, None),
)
type_base_cls = inspect.signature(
getattr(base_component_cls, function_name, None),
)
if type_cls_implementation != type_base_cls:
msg = (
f"Invalid function definition of function {function_name}. "
f"The expected function signature is {type_base_cls}"
)
raise ValueError(
msg,
)

def validate_abstract_methods_are_implemented(cls):
"""
Function to validate that a class has overridden every required function marked as
abstract.
"""
abstract_methods = [
name
for name, value in inspect.getmembers(cls)
if getattr(value, "__isabstractmethod__", False)
]
if len(abstract_methods) >= 1:
msg = (
f"Every required function must be overridden in the PythonComponent. "
f"Missing implementations for the following functions: {abstract_methods}"
)
raise ValueError(
msg,
)

validate_abstract_methods_are_implemented(cls)
base_component_cls = get_base_cls(cls)
validate_signatures(base_component_cls, cls)

# updated=() is needed to prevent an attempt to update the class's __dict__
@wraps(cls, updated=())
class PythonComponentOp(cls, PythonComponent):
Expand All @@ -49,6 +120,10 @@ def image(cls) -> Image:

return PythonComponentOp

# Call wrapper with function (`args[0]`) when no additional arguments were passed
if args:
return wrapper(args[0])

return wrapper


Expand Down
13 changes: 11 additions & 2 deletions tests/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import copy
from pathlib import Path

import dask.dataframe as dd
import pandas as pd
import pyarrow as pa
import pytest
import yaml
Expand Down Expand Up @@ -71,8 +73,15 @@ def test_component_op(
def test_component_op_python_component(default_pipeline_args):
@lightweight_component()
class Foo(DaskLoadComponent):
def load(self) -> str:
return ["bar"]
def load(self) -> dd.DataFrame:
df = pd.DataFrame(
{
"x": [1, 2, 3],
"y": [4, 5, 6],
},
index=pd.Index(["a", "b", "c"], name="id"),
)
return dd.from_pandas(df, npartitions=1)

component = ComponentOp.from_ref(Foo, produces={"bar": pa.string()})
assert component.component_spec._specification == {
Expand Down
107 changes: 107 additions & 0 deletions tests/pipeline/test_python_component.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import re
import textwrap

import dask.dataframe as dd
Expand Down Expand Up @@ -105,6 +106,7 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
consumes={"x": pa.int32(), "y": pa.int32()},
arguments={"n": 1},
)

assert len(pipeline._graph.keys()) == 1 + 1
assert pipeline._graph["AddN"]["dependencies"] == ["CreateData"]
operation_spec = pipeline._graph["AddN"]["operation"].operation_spec.to_json()
Expand Down Expand Up @@ -140,3 +142,108 @@ def load(self) -> str:
ref=Foo,
produces={"x": pa.int32(), "y": pa.int32()},
)


def test_valid_load_component():
@lightweight_component(
base_image="python:3.8-slim-buster",
)
class CreateData(DaskLoadComponent):
def load(self) -> dd.DataFrame:
df = pd.DataFrame(
{
"x": [1, 2, 3],
"y": [4, 5, 6],
},
index=pd.Index(["a", "b", "c"], name="id"),
)
return dd.from_pandas(df, npartitions=1)

CreateData(produces={}, consumes={})


def test_invalid_load_component():
with pytest.raises( # noqa: PT012
ValueError,
match="Every required function must be overridden in the PythonComponent. "
"Missing implementations for the following functions: \\['load'\\]",
):

@lightweight_component(
base_image="python:3.8-slim-buster",
)
class CreateData(DaskLoadComponent):
def custom_load(self) -> int:
return 1

CreateData(produces={}, consumes={})


def test_invalid_load_transform_component():
with pytest.raises( # noqa: PT012
ValueError,
match="Multiple base classes detected. Only one component should be inherited "
"or implemented.Found classes: DaskLoadComponent, PandasTransformComponent",
):

@lightweight_component(
base_image="python:3.8-slim-buster",
)
class CreateData(DaskLoadComponent, PandasTransformComponent):
def load(self) -> dd.DataFrame:
pass

def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
pass

CreateData(produces={}, consumes={})


def test_invalid_load_component_wrong_return_type():
with pytest.raises( # noqa: PT012
ValueError,
match=re.escape(
"Invalid function definition of function load. "
"The expected function signature "
"is (self) -> dask.dataframe.core.DataFrame",
),
):

@lightweight_component(
base_image="python:3.8-slim-buster",
)
class CreateData(DaskLoadComponent):
def load(self) -> int:
return 1

CreateData(produces={}, consumes={})


def test_lightweight_component_decorator_without_parentheses():
@lightweight_component
class CreateData(DaskLoadComponent):
def load(self) -> dd.DataFrame:
return None

pipeline = Pipeline(
name="dummy-pipeline",
base_path="./data",
)

pipeline.read(
ref=CreateData,
)

assert len(pipeline._graph.keys()) == 1
operation_spec = pipeline._graph["CreateData"]["operation"].operation_spec.to_json()
assert json.loads(operation_spec) == {
"specification": {
"name": "CreateData",
"image": "fondant:latest",
"description": "python component",
"consumes": {"additionalProperties": True},
"produces": {"additionalProperties": True},
},
"consumes": {},
"produces": {},
}

0 comments on commit f07ea5b

Please sign in to comment.