diff --git a/tests/flytekit/unit/core/test_typing_annotation.py b/tests/flytekit/unit/core/test_typing_annotation.py new file mode 100644 index 0000000000..cfcfd6b50d --- /dev/null +++ b/tests/flytekit/unit/core/test_typing_annotation.py @@ -0,0 +1,61 @@ +import typing +from collections import OrderedDict + +from flytekit.common.translator import get_serializable +from flytekit.core import context_manager +from flytekit.core.annotation import FlyteAnnotation +from flytekit.core.context_manager import Image, ImageConfig +from flytekit.core.task import task +from flytekit.models.annotation import TypeAnnotation + +default_img = Image(name="default", fqn="test", tag="tag") +serialization_settings = context_manager.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), +) +entity_mapping = OrderedDict() + + +@task +def x(a: typing.Annotated[int, FlyteAnnotation({"foo": {"bar": 1}})], b: str): + ... + + +@task +def y0(a: typing.List[typing.Annotated[int, FlyteAnnotation({"foo": {"bar": 1}})]]): + ... + + +@task +def y1(a: typing.Annotated[typing.List[int], FlyteAnnotation({"foo": {"bar": 1}})]): + ... + + +def test_get_variable_descriptions(): + + x_tsk = get_serializable(entity_mapping, serialization_settings, x) + x_input_vars = x_tsk.template.interface.inputs + + a_ann = x_input_vars["a"].type.annotation + assert isinstance(a_ann, TypeAnnotation) + assert a_ann.annotations["foo"] == {"bar": 1} + + b_ann = x_input_vars["b"].type.annotation + assert b_ann is None + + # Annotated simple type within list generic + y0_tsk = get_serializable(entity_mapping, serialization_settings, y0) + y0_input_vars = y0_tsk.template.interface.inputs + y0_a_ann = y0_input_vars["a"].type.collection_type.annotation + assert isinstance(y0_a_ann, TypeAnnotation) + assert y0_a_ann.annotations["foo"] == {"bar": 1} + + # Annotated list generic + y1_tsk = get_serializable(entity_mapping, serialization_settings, y1) + y1_input_vars = y1_tsk.template.interface.inputs + y1_a_ann = y1_input_vars["a"].type.annotation + assert isinstance(y1_a_ann, TypeAnnotation) + assert y1_a_ann.annotations["foo"] == {"bar": 1}