Skip to content

Commit

Permalink
Add support for parentheses in decorator calls
Browse files Browse the repository at this point in the history
  • Loading branch information
mrchtr committed Jan 22, 2024
1 parent a3bb42c commit 0237b69
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 65 deletions.
133 changes: 70 additions & 63 deletions src/fondant/pipeline/lightweight_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def image(cls) -> Image:


def lightweight_component(
*args,
extra_requires: t.Optional[t.List[str]] = None,
base_image: t.Optional[str] = None,
):
Expand All @@ -40,6 +41,72 @@ def wrapper(cls):
script=script,
)

def get_base_cls(cls):
"""
Returns the BaseComponent. If the implementation inherits from several classes,
the Fondant base class is selected. If more than one Fondant base component
is implemented, an exception is raised.
"""
base_component_module = inspect.getmodule(Component).__name__
base_component_cls_list = [
base
for base in cls.__bases__
if base.__module__ == base_component_module
]
if len(base_component_cls_list) > 1:
msg = (
f"Multiple base classes detected. Only one component should be inherited or"
f" implemented."
f"Found classes: {', '.join([cls.__name__ for cls in base_component_cls_list])}"
)
raise ValueError(
msg,
)
return base_component_cls_list[0]

def validate_signatures(base_component_cls, cls_implementation):
"""
Compare the signature of overridden methods in a class with their counterparts
in the BaseComponent classes.
"""
for function_name in dir(cls_implementation):
if not function_name.startswith("__") and function_name in dir(
base_component_cls,
):
type_cls_implementation = inspect.signature(
getattr(cls_implementation, function_name, None),
)
type_base_cls = inspect.signature(
getattr(base_component_cls, function_name, None),
)
if type_cls_implementation != type_base_cls:
msg = (
f"Invalid function definition of function {function_name}. "
f"The expected function signature is {type_base_cls}"
)
raise ValueError(
msg,
)

def validate_abstract_methods_are_implemented(cls):
"""
Function to validate that a class has overridden every required function marked as
abstract.
"""
abstract_methods = [
name
for name, value in inspect.getmembers(cls)
if getattr(value, "__isabstractmethod__", False)
]
if len(abstract_methods) >= 1:
msg = (
f"Every required function must be overridden in the PythonComponent. "
f"Missing implementations for the following functions: {abstract_methods}"
)
raise ValueError(
msg,
)

validate_abstract_methods_are_implemented(cls)
base_component_cls = get_base_cls(cls)
validate_signatures(base_component_cls, cls)
Expand All @@ -53,69 +120,9 @@ def image(cls) -> Image:

return AppliedPythonComponent

def get_base_cls(cls):
"""
Returns the BaseComponent. If the implementation inherits from several classes,
the Fondant base class is selected. If more than one Fondant base component
is implemented, an exception is raised.
"""
base_component_module = inspect.getmodule(Component).__name__
base_component_cls_list = [
base for base in cls.__bases__ if base.__module__ == base_component_module
]
if len(base_component_cls_list) > 1:
msg = (
f"Multiple base classes detected. Only one component should be inherited or"
f" implemented."
f"Found classes: {', '.join([cls.__name__ for cls in base_component_cls_list])}"
)
raise ValueError(
msg,
)
return base_component_cls_list[0]

def validate_signatures(base_component_cls, cls_implementation):
"""
Compare the signature of overridden methods in a class with their counterparts
in the BaseComponent classes.
"""
for function_name in dir(cls_implementation):
if not function_name.startswith("__") and function_name in dir(
base_component_cls,
):
type_cls_implementation = inspect.signature(
getattr(cls_implementation, function_name, None),
)
type_base_cls = inspect.signature(
getattr(base_component_cls, function_name, None),
)
if type_cls_implementation != type_base_cls:
msg = (
f"Invalid function definition of function {function_name}. The expected "
f"function signature is {type_base_cls}"
)
raise ValueError(
msg,
)

def validate_abstract_methods_are_implemented(cls):
"""
Function to validate that a class has overridden every required function marked as
abstract.
"""
abstract_methods = [
name
for name, value in inspect.getmembers(cls)
if getattr(value, "__isabstractmethod__", False)
]
if len(abstract_methods) >= 1:
msg = (
f"Every required function must be overridden in the PythonComponent. "
f"Missing implementations for the following functions: {abstract_methods}"
)
raise ValueError(
msg,
)
# Call wrapper with function (`args[0]`) when no additional arguments were passed
if args:
return wrapper(args[0])

return wrapper

Expand Down
38 changes: 36 additions & 2 deletions tests/pipeline/test_python_component.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import re
import textwrap

import dask.dataframe as dd
Expand Down Expand Up @@ -196,8 +197,11 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
def test_invalid_load_component_wrong_return_type():
with pytest.raises( # noqa: PT012
ValueError,
match="Invalid function definition of function load. The return type "
"has to be <class 'dask.dataframe.core.DataFrame'>",
match=re.escape(
"Invalid function definition of function load. "
"The expected function signature "
"is (self) -> dask.dataframe.core.DataFrame",
),
):

@lightweight_component(
Expand All @@ -208,3 +212,33 @@ def load(self) -> int:
return 1

CreateData(produces={}, consumes={})


def test_lightweight_component_decorator_without_parentheses():
@lightweight_component
class CreateData(DaskLoadComponent):
def load(self) -> dd.DataFrame:
return None

pipeline = Pipeline(
name="dummy-pipeline",
base_path="./data",
)

pipeline.read(
ref=CreateData,
)

assert len(pipeline._graph.keys()) == 1
operation_spec = pipeline._graph["CreateData"]["operation"].operation_spec.to_json()
assert json.loads(operation_spec) == {
"specification": {
"name": "CreateData",
"image": "fondant:latest",
"description": "python component",
"consumes": {"additionalProperties": True},
"produces": {"additionalProperties": True},
},
"consumes": {},
"produces": {},
}

0 comments on commit 0237b69

Please sign in to comment.