Skip to content

Commit

Permalink
Add component argument inference (#763)
Browse files Browse the repository at this point in the history
Fixes #751 

This PR introduces functionality to infer the arguments from a
`Component` class. The result is a dictionary with the argument names as
keys, and `Argument` instances as values, which is the format of
[`component_spec.args`.](https://github.com/ml6team/fondant/blob/8e828441eec8ff91074e5c8ccf16fe405b719594/src/fondant/core/component_spec.py#L193)

We can leverage this behavior for Lightweight Python components as
described in #750.

Did some TDD here, let me know if I missed any cases.
  • Loading branch information
RobbeSneyders authored Jan 9, 2024
1 parent b62c4f2 commit 0861022
Show file tree
Hide file tree
Showing 3 changed files with 354 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/fondant/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ class InvalidPipelineDefinition(ValidationError, FondantException):

class InvalidTypeSchema(ValidationError, FondantException):
"""Thrown when a Type schema definition is invalid."""


class UnsupportedTypeAnnotation(FondantException):
"""Thrown when an unsupported type annotation is encountered during type inference."""
103 changes: 103 additions & 0 deletions src/fondant/pipeline/argument_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import inspect
import typing as t

from fondant.component import Component
from fondant.core.component_spec import Argument
from fondant.core.exceptions import UnsupportedTypeAnnotation

BUILTIN_TYPES = [str, int, float, bool, dict, list]


def annotation_to_type(annotation: t.Any) -> t.Type:
"""Extract the simple built-in type from an Annotation.
Examples:
dict[str, int] -> dict
t.Optional[str] -> str
Args:
annotation: Annotation of an argument as returned by inspect.signature
Raises:
UnsupportedTypeAnnotation: If the annotation is not simple or not based on a built-in type
"""
# If no type annotation is present, default to str
if annotation == inspect.Parameter.empty:
return str

# Unpack the annotation until we get a simple type.
# This removes complex structures such as Optional
while t.get_origin(annotation) not in [*BUILTIN_TYPES, None]:
# Filter out NoneType values (Optional[x] is represented as Union[x, NoneType]
annotation_args = [
arg for arg in t.get_args(annotation) if arg is not type(None)
]

# Multiple arguments remaining (eg. Union[str, int])
# Raise error since we cannot infer type unambiguously
if len(annotation_args) > 1:
msg = (
f"Fondant only supports simple types for component arguments."
f"Expected one of {BUILTIN_TYPES}, received {annotation} instead."
)
raise UnsupportedTypeAnnotation(msg)

annotation = annotation_args[0]

# Remove any subscription (eg. dict[str, int] -> dict)
annotation = t.get_origin(annotation) or annotation

# Check for classes not supported as argument
if annotation not in BUILTIN_TYPES:
msg = (
f"Fondant only supports builtin types for component arguments."
f"Expected one of {BUILTIN_TYPES}, received {annotation} instead."
)
raise UnsupportedTypeAnnotation(msg)

return annotation


def is_optional(parameter: inspect.Parameter) -> bool:
"""Check if an inspect.Parameter is optional. We check this based on the presence of a
default value instead of based on the type, since this is more trustworthy.
"""
return parameter.default != inspect.Parameter.empty


def get_default(parameter: inspect.Parameter) -> t.Any:
"""Get the default value from an inspect.Parameter."""
if parameter.default == inspect.Parameter.empty:
return None
return parameter.default


def parameter_to_argument(parameter: inspect.Parameter) -> Argument:
"""Translate an inspect.Parameter into a Fondant Argument."""
return Argument(
name=parameter.name,
type=annotation_to_type(parameter.annotation),
optional=is_optional(parameter),
default=get_default(parameter),
)


def infer_arguments(component: t.Type[Component]) -> t.Dict[str, Argument]:
"""Infer the user arguments from a Python Component class.
Default arguments are skipped.
Args:
component: Component class to inspect.
"""
signature = inspect.signature(component)

arguments = {}
for name, parameter in signature.parameters.items():
# Skip non-user arguments
if name in ["self", "consumes", "produces", "kwargs"]:
continue

arguments[name] = parameter_to_argument(parameter)

return arguments
247 changes: 247 additions & 0 deletions tests/pipeline/test_argument_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
import sys
import typing as t

import pytest
from fondant.component import PandasTransformComponent
from fondant.core.component_spec import Argument
from fondant.core.exceptions import UnsupportedTypeAnnotation
from fondant.pipeline.argument_inference import infer_arguments


def test_no_init():
class TestComponent(PandasTransformComponent):
pass

assert infer_arguments(TestComponent) == {}


def test_no_arguments():
class TestComponent(PandasTransformComponent):
def __init__(self, **kwargs):
pass

assert infer_arguments(TestComponent) == {}


def test_missing_types():
class TestComponent(PandasTransformComponent):
def __init__(self, *, argument, **kwargs):
pass

assert infer_arguments(TestComponent) == {
"argument": Argument(
name="argument",
type=str,
optional=False,
),
}


def test_types():
class TestComponent(PandasTransformComponent):
def __init__(
self,
*,
str_argument: str,
int_argument: int,
float_argument: float,
bool_argument: bool,
dict_argument: dict,
list_argument: list,
**kwargs,
):
pass

assert infer_arguments(TestComponent) == {
"str_argument": Argument(
name="str_argument",
type=str,
optional=False,
),
"int_argument": Argument(
name="int_argument",
type=int,
optional=False,
),
"float_argument": Argument(
name="float_argument",
type=float,
optional=False,
),
"bool_argument": Argument(
name="bool_argument",
type=bool,
optional=False,
),
"dict_argument": Argument(
name="dict_argument",
type=dict,
optional=False,
),
"list_argument": Argument(
name="list_argument",
type=list,
optional=False,
),
}


def test_optional_types():
class TestComponent(PandasTransformComponent):
def __init__(
self,
*,
str_argument: t.Optional[str] = "",
int_argument: t.Optional[int] = 1,
float_argument: t.Optional[float] = 1.0,
bool_argument: t.Optional[bool] = False,
dict_argument: t.Optional[dict] = None,
list_argument: t.Optional[list] = None,
**kwargs,
):
pass

assert infer_arguments(TestComponent) == {
"str_argument": Argument(
name="str_argument",
type=str,
optional=True,
default="",
),
"int_argument": Argument(
name="int_argument",
type=int,
optional=True,
default=1,
),
"float_argument": Argument(
name="float_argument",
type=float,
optional=True,
default=1.0,
),
"bool_argument": Argument(
name="bool_argument",
type=bool,
optional=True,
default=False,
),
"dict_argument": Argument(
name="dict_argument",
type=dict,
optional=True,
default=None,
),
"list_argument": Argument(
name="list_argument",
type=list,
optional=True,
default=None,
),
}


def test_parametrized_types_old():
class TestComponent(PandasTransformComponent):
def __init__(
self,
*,
dict_argument: t.Dict[str, t.Any],
list_argument: t.Optional[t.List[int]] = None,
**kwargs,
):
pass

assert infer_arguments(TestComponent) == {
"dict_argument": Argument(
name="dict_argument",
type=dict,
optional=False,
default=None,
),
"list_argument": Argument(
name="list_argument",
type=list,
optional=True,
default=None,
),
}


@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
def test_parametrized_types_new():
class TestComponent(PandasTransformComponent):
def __init__(
self,
*,
dict_argument: dict[str, t.Any],
list_argument: list[int] | None = None,
**kwargs,
):
pass

assert infer_arguments(TestComponent) == {
"dict_argument": Argument(
name="dict_argument",
type=dict,
optional=False,
default=None,
),
"list_argument": Argument(
name="list_argument",
type=list,
optional=True,
default=None,
),
}


def test_unsupported_complex_type():
class TestComponent(PandasTransformComponent):
def __init__(
self,
*,
union_argument: t.Union[str, int],
**kwargs,
):
pass

with pytest.raises(
UnsupportedTypeAnnotation,
match="Fondant only supports simple types",
):
infer_arguments(TestComponent)


def test_unsupported_custom_type():
class CustomClass:
pass

class TestComponent(PandasTransformComponent):
def __init__(
self,
*,
class_argument: CustomClass,
**kwargs,
):
pass

with pytest.raises(
UnsupportedTypeAnnotation,
match="Fondant only supports builtin types",
):
infer_arguments(TestComponent)


def test_consumes_produces():
class TestComponent(PandasTransformComponent):
def __init__(self, *, argument, consumes, **kwargs):
pass

assert infer_arguments(TestComponent) == {
"argument": Argument(
name="argument",
type=str,
optional=False,
),
}

0 comments on commit 0861022

Please sign in to comment.