diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index a3b1b80dd8..0f651410bf 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -7,7 +7,7 @@ from collections import OrderedDict from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union -from typing_extensions import get_args, get_origin +from typing_extensions import get_args, get_origin, get_type_hints from flytekit.core import context_manager from flytekit.core.docstring import Docstring @@ -283,11 +283,8 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc For now the fancy object, maybe in the future a dumb object. """ - try: - # include_extras can only be used in python >= 3.9 - type_hints = typing.get_type_hints(fn, include_extras=True) - except TypeError: - type_hints = typing.get_type_hints(fn) + + type_hints = get_type_hints(fn, include_extras=True) signature = inspect.signature(fn) return_annotation = type_hints.get("return", None) @@ -395,7 +392,7 @@ def t(a: int, b: str) -> Dict[str, int]: ... bases = return_annotation.__bases__ # type: ignore if len(bases) == 1 and bases[0] == tuple and hasattr(return_annotation, "_fields"): logger.debug(f"Task returns named tuple {return_annotation}") - return dict(typing.get_type_hints(return_annotation)) + return dict(get_type_hints(return_annotation, include_extras=True)) if hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple: # type: ignore # Handle option 3 diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 077b19f8b5..abdf69f5b0 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -16,7 +16,7 @@ from dataclasses_json import dataclass_json from google.protobuf.struct_pb2 import Struct from pandas._testing import assert_frame_equal -from typing_extensions import Annotated +from typing_extensions import Annotated, get_origin import flytekit import flytekit.configuration @@ -89,6 +89,15 @@ def my_task(a: int) -> typing.NamedTuple("OutputsBC", b=typing.ForwardRef("int") assert context_manager.FlyteContextManager.size() == 1 +def test_annotated_namedtuple_output(): + @task + def my_task(a: int) -> typing.NamedTuple("OutputA", a=Annotated[int, "metadata-a"]): + return a + 2 + + assert my_task(a=9) == (11,) + assert get_origin(my_task.python_interface.outputs["a"]) is Annotated + + def test_simple_input_no_output(): @task def my_task(a: int):