Skip to content

Commit

Permalink
Addressing comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mrchtr committed Jan 22, 2024
1 parent a181116 commit a3bb42c
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 10 deletions.
41 changes: 31 additions & 10 deletions src/fondant/pipeline/lightweight_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def wrapper(cls):
script=script,
)

validate_required_functions_are_implemented(cls)
base_component_cls = cls.__bases__[0]
validate_return_types(base_component_cls, cls)
validate_abstract_methods_are_implemented(cls)
base_component_cls = get_base_cls(cls)
validate_signatures(base_component_cls, cls)

# updated=() is needed to prevent an attempt to update the class's __dict__
@wraps(cls, updated=())
Expand All @@ -53,9 +53,30 @@ def image(cls) -> Image:

return AppliedPythonComponent

def validate_return_types(base_component_cls, cls_implementation):
def get_base_cls(cls):
"""
Compare the return types of overridden methods in a class with their counterparts
Returns the BaseComponent. If the implementation inherits from several classes,
the Fondant base class is selected. If more than one Fondant base component
is implemented, an exception is raised.
"""
base_component_module = inspect.getmodule(Component).__name__
base_component_cls_list = [
base for base in cls.__bases__ if base.__module__ == base_component_module
]
if len(base_component_cls_list) > 1:
msg = (
f"Multiple base classes detected. Only one component should be inherited or"
f" implemented."
f"Found classes: {', '.join([cls.__name__ for cls in base_component_cls_list])}"
)
raise ValueError(
msg,
)
return base_component_cls_list[0]

def validate_signatures(base_component_cls, cls_implementation):
"""
Compare the signature of overridden methods in a class with their counterparts
in the BaseComponent classes.
"""
for function_name in dir(cls_implementation):
Expand All @@ -64,20 +85,20 @@ def validate_return_types(base_component_cls, cls_implementation):
):
type_cls_implementation = inspect.signature(
getattr(cls_implementation, function_name, None),
).return_annotation
)
type_base_cls = inspect.signature(
getattr(base_component_cls, function_name, None),
).return_annotation
)
if type_cls_implementation != type_base_cls:
msg = (
f"Invalid function definition of function {function_name}. The return type "
f"has to be {type_base_cls}"
f"Invalid function definition of function {function_name}. The expected "
f"function signature is {type_base_cls}"
)
raise ValueError(
msg,
)

def validate_required_functions_are_implemented(cls):
def validate_abstract_methods_are_implemented(cls):
"""
Function to validate that a class has overridden every required function marked as
abstract.
Expand Down
20 changes: 20 additions & 0 deletions tests/pipeline/test_python_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,26 @@ def custom_load(self) -> int:
CreateData(produces={}, consumes={})


def test_invalid_load_transform_component():
with pytest.raises( # noqa: PT012
ValueError,
match="Multiple base classes detected. Only one component should be inherited "
"or implemented.Found classes: DaskLoadComponent, PandasTransformComponent",
):

@lightweight_component(
base_image="python:3.8-slim-buster",
)
class CreateData(DaskLoadComponent, PandasTransformComponent):
def load(self) -> dd.DataFrame:
pass

def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
pass

CreateData(produces={}, consumes={})


def test_invalid_load_component_wrong_return_type():
with pytest.raises( # noqa: PT012
ValueError,
Expand Down

0 comments on commit a3bb42c

Please sign in to comment.