From 0237b69c1a315f48cd7580fae193144f1fd596ec Mon Sep 17 00:00:00 2001 From: Matthias Richter Date: Mon, 22 Jan 2024 08:18:40 +0100 Subject: [PATCH] Add support for parentheses in decorator calls --- src/fondant/pipeline/lightweight_component.py | 133 +++++++++--------- tests/pipeline/test_python_component.py | 38 ++++- 2 files changed, 106 insertions(+), 65 deletions(-) diff --git a/src/fondant/pipeline/lightweight_component.py b/src/fondant/pipeline/lightweight_component.py index 2a541887..cd0755c0 100644 --- a/src/fondant/pipeline/lightweight_component.py +++ b/src/fondant/pipeline/lightweight_component.py @@ -27,6 +27,7 @@ def image(cls) -> Image: def lightweight_component( + *args, extra_requires: t.Optional[t.List[str]] = None, base_image: t.Optional[str] = None, ): @@ -40,6 +41,72 @@ def wrapper(cls): script=script, ) + def get_base_cls(cls): + """ + Returns the BaseComponent. If the implementation inherits from several classes, + the Fondant base class is selected. If more than one Fondant base component + is implemented, an exception is raised. + """ + base_component_module = inspect.getmodule(Component).__name__ + base_component_cls_list = [ + base + for base in cls.__bases__ + if base.__module__ == base_component_module + ] + if len(base_component_cls_list) > 1: + msg = ( + f"Multiple base classes detected. Only one component should be inherited or" + f" implemented." + f"Found classes: {', '.join([cls.__name__ for cls in base_component_cls_list])}" + ) + raise ValueError( + msg, + ) + return base_component_cls_list[0] + + def validate_signatures(base_component_cls, cls_implementation): + """ + Compare the signature of overridden methods in a class with their counterparts + in the BaseComponent classes. + """ + for function_name in dir(cls_implementation): + if not function_name.startswith("__") and function_name in dir( + base_component_cls, + ): + type_cls_implementation = inspect.signature( + getattr(cls_implementation, function_name, None), + ) + type_base_cls = inspect.signature( + getattr(base_component_cls, function_name, None), + ) + if type_cls_implementation != type_base_cls: + msg = ( + f"Invalid function definition of function {function_name}. " + f"The expected function signature is {type_base_cls}" + ) + raise ValueError( + msg, + ) + + def validate_abstract_methods_are_implemented(cls): + """ + Function to validate that a class has overridden every required function marked as + abstract. + """ + abstract_methods = [ + name + for name, value in inspect.getmembers(cls) + if getattr(value, "__isabstractmethod__", False) + ] + if len(abstract_methods) >= 1: + msg = ( + f"Every required function must be overridden in the PythonComponent. " + f"Missing implementations for the following functions: {abstract_methods}" + ) + raise ValueError( + msg, + ) + validate_abstract_methods_are_implemented(cls) base_component_cls = get_base_cls(cls) validate_signatures(base_component_cls, cls) @@ -53,69 +120,9 @@ def image(cls) -> Image: return AppliedPythonComponent - def get_base_cls(cls): - """ - Returns the BaseComponent. If the implementation inherits from several classes, - the Fondant base class is selected. If more than one Fondant base component - is implemented, an exception is raised. - """ - base_component_module = inspect.getmodule(Component).__name__ - base_component_cls_list = [ - base for base in cls.__bases__ if base.__module__ == base_component_module - ] - if len(base_component_cls_list) > 1: - msg = ( - f"Multiple base classes detected. Only one component should be inherited or" - f" implemented." - f"Found classes: {', '.join([cls.__name__ for cls in base_component_cls_list])}" - ) - raise ValueError( - msg, - ) - return base_component_cls_list[0] - - def validate_signatures(base_component_cls, cls_implementation): - """ - Compare the signature of overridden methods in a class with their counterparts - in the BaseComponent classes. - """ - for function_name in dir(cls_implementation): - if not function_name.startswith("__") and function_name in dir( - base_component_cls, - ): - type_cls_implementation = inspect.signature( - getattr(cls_implementation, function_name, None), - ) - type_base_cls = inspect.signature( - getattr(base_component_cls, function_name, None), - ) - if type_cls_implementation != type_base_cls: - msg = ( - f"Invalid function definition of function {function_name}. The expected " - f"function signature is {type_base_cls}" - ) - raise ValueError( - msg, - ) - - def validate_abstract_methods_are_implemented(cls): - """ - Function to validate that a class has overridden every required function marked as - abstract. - """ - abstract_methods = [ - name - for name, value in inspect.getmembers(cls) - if getattr(value, "__isabstractmethod__", False) - ] - if len(abstract_methods) >= 1: - msg = ( - f"Every required function must be overridden in the PythonComponent. " - f"Missing implementations for the following functions: {abstract_methods}" - ) - raise ValueError( - msg, - ) + # Call wrapper with function (`args[0]`) when no additional arguments were passed + if args: + return wrapper(args[0]) return wrapper diff --git a/tests/pipeline/test_python_component.py b/tests/pipeline/test_python_component.py index fc2aec7a..572f3b5d 100644 --- a/tests/pipeline/test_python_component.py +++ b/tests/pipeline/test_python_component.py @@ -1,4 +1,5 @@ import json +import re import textwrap import dask.dataframe as dd @@ -196,8 +197,11 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: def test_invalid_load_component_wrong_return_type(): with pytest.raises( # noqa: PT012 ValueError, - match="Invalid function definition of function load. The return type " - "has to be ", + match=re.escape( + "Invalid function definition of function load. " + "The expected function signature " + "is (self) -> dask.dataframe.core.DataFrame", + ), ): @lightweight_component( @@ -208,3 +212,33 @@ def load(self) -> int: return 1 CreateData(produces={}, consumes={}) + + +def test_lightweight_component_decorator_without_parentheses(): + @lightweight_component + class CreateData(DaskLoadComponent): + def load(self) -> dd.DataFrame: + return None + + pipeline = Pipeline( + name="dummy-pipeline", + base_path="./data", + ) + + pipeline.read( + ref=CreateData, + ) + + assert len(pipeline._graph.keys()) == 1 + operation_spec = pipeline._graph["CreateData"]["operation"].operation_spec.to_json() + assert json.loads(operation_spec) == { + "specification": { + "name": "CreateData", + "image": "fondant:latest", + "description": "python component", + "consumes": {"additionalProperties": True}, + "produces": {"additionalProperties": True}, + }, + "consumes": {}, + "produces": {}, + }