-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add component argument inference (#763)
Fixes #751 This PR introduces functionality to infer the arguments from a `Component` class. The result is a dictionary with the argument names as keys, and `Argument` instances as values, which is the format of [`component_spec.args`.](https://github.com/ml6team/fondant/blob/8e828441eec8ff91074e5c8ccf16fe405b719594/src/fondant/core/component_spec.py#L193) We can leverage this behavior for Lightweight Python components as described in #750. Did some TDD here, let me know if I missed any cases.
- Loading branch information
1 parent
b62c4f2
commit 0861022
Showing
3 changed files
with
354 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import inspect | ||
import typing as t | ||
|
||
from fondant.component import Component | ||
from fondant.core.component_spec import Argument | ||
from fondant.core.exceptions import UnsupportedTypeAnnotation | ||
|
||
BUILTIN_TYPES = [str, int, float, bool, dict, list] | ||
|
||
|
||
def annotation_to_type(annotation: t.Any) -> t.Type: | ||
"""Extract the simple built-in type from an Annotation. | ||
Examples: | ||
dict[str, int] -> dict | ||
t.Optional[str] -> str | ||
Args: | ||
annotation: Annotation of an argument as returned by inspect.signature | ||
Raises: | ||
UnsupportedTypeAnnotation: If the annotation is not simple or not based on a built-in type | ||
""" | ||
# If no type annotation is present, default to str | ||
if annotation == inspect.Parameter.empty: | ||
return str | ||
|
||
# Unpack the annotation until we get a simple type. | ||
# This removes complex structures such as Optional | ||
while t.get_origin(annotation) not in [*BUILTIN_TYPES, None]: | ||
# Filter out NoneType values (Optional[x] is represented as Union[x, NoneType] | ||
annotation_args = [ | ||
arg for arg in t.get_args(annotation) if arg is not type(None) | ||
] | ||
|
||
# Multiple arguments remaining (eg. Union[str, int]) | ||
# Raise error since we cannot infer type unambiguously | ||
if len(annotation_args) > 1: | ||
msg = ( | ||
f"Fondant only supports simple types for component arguments." | ||
f"Expected one of {BUILTIN_TYPES}, received {annotation} instead." | ||
) | ||
raise UnsupportedTypeAnnotation(msg) | ||
|
||
annotation = annotation_args[0] | ||
|
||
# Remove any subscription (eg. dict[str, int] -> dict) | ||
annotation = t.get_origin(annotation) or annotation | ||
|
||
# Check for classes not supported as argument | ||
if annotation not in BUILTIN_TYPES: | ||
msg = ( | ||
f"Fondant only supports builtin types for component arguments." | ||
f"Expected one of {BUILTIN_TYPES}, received {annotation} instead." | ||
) | ||
raise UnsupportedTypeAnnotation(msg) | ||
|
||
return annotation | ||
|
||
|
||
def is_optional(parameter: inspect.Parameter) -> bool: | ||
"""Check if an inspect.Parameter is optional. We check this based on the presence of a | ||
default value instead of based on the type, since this is more trustworthy. | ||
""" | ||
return parameter.default != inspect.Parameter.empty | ||
|
||
|
||
def get_default(parameter: inspect.Parameter) -> t.Any: | ||
"""Get the default value from an inspect.Parameter.""" | ||
if parameter.default == inspect.Parameter.empty: | ||
return None | ||
return parameter.default | ||
|
||
|
||
def parameter_to_argument(parameter: inspect.Parameter) -> Argument: | ||
"""Translate an inspect.Parameter into a Fondant Argument.""" | ||
return Argument( | ||
name=parameter.name, | ||
type=annotation_to_type(parameter.annotation), | ||
optional=is_optional(parameter), | ||
default=get_default(parameter), | ||
) | ||
|
||
|
||
def infer_arguments(component: t.Type[Component]) -> t.Dict[str, Argument]: | ||
"""Infer the user arguments from a Python Component class. | ||
Default arguments are skipped. | ||
Args: | ||
component: Component class to inspect. | ||
""" | ||
signature = inspect.signature(component) | ||
|
||
arguments = {} | ||
for name, parameter in signature.parameters.items(): | ||
# Skip non-user arguments | ||
if name in ["self", "consumes", "produces", "kwargs"]: | ||
continue | ||
|
||
arguments[name] = parameter_to_argument(parameter) | ||
|
||
return arguments |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,247 @@ | ||
import sys | ||
import typing as t | ||
|
||
import pytest | ||
from fondant.component import PandasTransformComponent | ||
from fondant.core.component_spec import Argument | ||
from fondant.core.exceptions import UnsupportedTypeAnnotation | ||
from fondant.pipeline.argument_inference import infer_arguments | ||
|
||
|
||
def test_no_init(): | ||
class TestComponent(PandasTransformComponent): | ||
pass | ||
|
||
assert infer_arguments(TestComponent) == {} | ||
|
||
|
||
def test_no_arguments(): | ||
class TestComponent(PandasTransformComponent): | ||
def __init__(self, **kwargs): | ||
pass | ||
|
||
assert infer_arguments(TestComponent) == {} | ||
|
||
|
||
def test_missing_types(): | ||
class TestComponent(PandasTransformComponent): | ||
def __init__(self, *, argument, **kwargs): | ||
pass | ||
|
||
assert infer_arguments(TestComponent) == { | ||
"argument": Argument( | ||
name="argument", | ||
type=str, | ||
optional=False, | ||
), | ||
} | ||
|
||
|
||
def test_types(): | ||
class TestComponent(PandasTransformComponent): | ||
def __init__( | ||
self, | ||
*, | ||
str_argument: str, | ||
int_argument: int, | ||
float_argument: float, | ||
bool_argument: bool, | ||
dict_argument: dict, | ||
list_argument: list, | ||
**kwargs, | ||
): | ||
pass | ||
|
||
assert infer_arguments(TestComponent) == { | ||
"str_argument": Argument( | ||
name="str_argument", | ||
type=str, | ||
optional=False, | ||
), | ||
"int_argument": Argument( | ||
name="int_argument", | ||
type=int, | ||
optional=False, | ||
), | ||
"float_argument": Argument( | ||
name="float_argument", | ||
type=float, | ||
optional=False, | ||
), | ||
"bool_argument": Argument( | ||
name="bool_argument", | ||
type=bool, | ||
optional=False, | ||
), | ||
"dict_argument": Argument( | ||
name="dict_argument", | ||
type=dict, | ||
optional=False, | ||
), | ||
"list_argument": Argument( | ||
name="list_argument", | ||
type=list, | ||
optional=False, | ||
), | ||
} | ||
|
||
|
||
def test_optional_types(): | ||
class TestComponent(PandasTransformComponent): | ||
def __init__( | ||
self, | ||
*, | ||
str_argument: t.Optional[str] = "", | ||
int_argument: t.Optional[int] = 1, | ||
float_argument: t.Optional[float] = 1.0, | ||
bool_argument: t.Optional[bool] = False, | ||
dict_argument: t.Optional[dict] = None, | ||
list_argument: t.Optional[list] = None, | ||
**kwargs, | ||
): | ||
pass | ||
|
||
assert infer_arguments(TestComponent) == { | ||
"str_argument": Argument( | ||
name="str_argument", | ||
type=str, | ||
optional=True, | ||
default="", | ||
), | ||
"int_argument": Argument( | ||
name="int_argument", | ||
type=int, | ||
optional=True, | ||
default=1, | ||
), | ||
"float_argument": Argument( | ||
name="float_argument", | ||
type=float, | ||
optional=True, | ||
default=1.0, | ||
), | ||
"bool_argument": Argument( | ||
name="bool_argument", | ||
type=bool, | ||
optional=True, | ||
default=False, | ||
), | ||
"dict_argument": Argument( | ||
name="dict_argument", | ||
type=dict, | ||
optional=True, | ||
default=None, | ||
), | ||
"list_argument": Argument( | ||
name="list_argument", | ||
type=list, | ||
optional=True, | ||
default=None, | ||
), | ||
} | ||
|
||
|
||
def test_parametrized_types_old(): | ||
class TestComponent(PandasTransformComponent): | ||
def __init__( | ||
self, | ||
*, | ||
dict_argument: t.Dict[str, t.Any], | ||
list_argument: t.Optional[t.List[int]] = None, | ||
**kwargs, | ||
): | ||
pass | ||
|
||
assert infer_arguments(TestComponent) == { | ||
"dict_argument": Argument( | ||
name="dict_argument", | ||
type=dict, | ||
optional=False, | ||
default=None, | ||
), | ||
"list_argument": Argument( | ||
name="list_argument", | ||
type=list, | ||
optional=True, | ||
default=None, | ||
), | ||
} | ||
|
||
|
||
@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") | ||
def test_parametrized_types_new(): | ||
class TestComponent(PandasTransformComponent): | ||
def __init__( | ||
self, | ||
*, | ||
dict_argument: dict[str, t.Any], | ||
list_argument: list[int] | None = None, | ||
**kwargs, | ||
): | ||
pass | ||
|
||
assert infer_arguments(TestComponent) == { | ||
"dict_argument": Argument( | ||
name="dict_argument", | ||
type=dict, | ||
optional=False, | ||
default=None, | ||
), | ||
"list_argument": Argument( | ||
name="list_argument", | ||
type=list, | ||
optional=True, | ||
default=None, | ||
), | ||
} | ||
|
||
|
||
def test_unsupported_complex_type(): | ||
class TestComponent(PandasTransformComponent): | ||
def __init__( | ||
self, | ||
*, | ||
union_argument: t.Union[str, int], | ||
**kwargs, | ||
): | ||
pass | ||
|
||
with pytest.raises( | ||
UnsupportedTypeAnnotation, | ||
match="Fondant only supports simple types", | ||
): | ||
infer_arguments(TestComponent) | ||
|
||
|
||
def test_unsupported_custom_type(): | ||
class CustomClass: | ||
pass | ||
|
||
class TestComponent(PandasTransformComponent): | ||
def __init__( | ||
self, | ||
*, | ||
class_argument: CustomClass, | ||
**kwargs, | ||
): | ||
pass | ||
|
||
with pytest.raises( | ||
UnsupportedTypeAnnotation, | ||
match="Fondant only supports builtin types", | ||
): | ||
infer_arguments(TestComponent) | ||
|
||
|
||
def test_consumes_produces(): | ||
class TestComponent(PandasTransformComponent): | ||
def __init__(self, *, argument, consumes, **kwargs): | ||
pass | ||
|
||
assert infer_arguments(TestComponent) == { | ||
"argument": Argument( | ||
name="argument", | ||
type=str, | ||
optional=False, | ||
), | ||
} |