diff --git a/sdk/RELEASE.md b/sdk/RELEASE.md index bda352b19a9..47b6a2da06b 100644 --- a/sdk/RELEASE.md +++ b/sdk/RELEASE.md @@ -15,6 +15,7 @@ * v2 compiler to throw no task defined error. [\#6545](https://github.com/kubeflow/pipelines/pull/6545) * Improve output parameter type checking in V2 SDK. [\#6566](https://github.com/kubeflow/pipelines/pull/6566) +* Use `Annotated` rather than `Union` for `Input` and `Output`. [\#6573](https://github.com/kubeflow/pipelines/pull/6573) ## Documentation Updates diff --git a/sdk/python/kfp/v2/components/types/type_annotations.py b/sdk/python/kfp/v2/components/types/type_annotations.py index 0ebbea46c2c..b736c4448e9 100644 --- a/sdk/python/kfp/v2/components/types/type_annotations.py +++ b/sdk/python/kfp/v2/components/types/type_annotations.py @@ -19,6 +19,11 @@ import re from typing import TypeVar, Union +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated + T = TypeVar('T') @@ -56,37 +61,18 @@ class OutputAnnotation(): pass -# TODO: Use typing.Annotated instead of this hack. -# With typing.Annotated (Python 3.9+ or typing_extensions package), the -# following would look like: -# Input = typing.Annotated[T, InputAnnotation] -# Output = typing.Annotated[T, OutputAnnotation] - # Input represents an Input artifact of type T. -Input = Union[T, InputAnnotation] +Input = Annotated[T, InputAnnotation] # Output represents an Output artifact of type T. -Output = Union[T, OutputAnnotation] +Output = Annotated[T, OutputAnnotation] def is_artifact_annotation(typ) -> bool: - if hasattr(typ, '_subs_tree'): # Python 3.6 - subs_tree = typ._subs_tree() - return len( - subs_tree) == 3 and subs_tree[0] == Union and subs_tree[2] in [ - InputAnnotation, OutputAnnotation - ] - - if not hasattr(typ, '__origin__'): - return False - - if typ.__origin__ != Union and type(typ.__origin__) != type(Union): - return False - - if not hasattr(typ, '__args__') or len(typ.__args__) != 2: + if not hasattr(typ, '__metadata__'): return False - if typ.__args__[1] not in [InputAnnotation, OutputAnnotation]: + if typ.__metadata__[0] not in [InputAnnotation, OutputAnnotation]: return False return True @@ -97,11 +83,7 @@ def is_input_artifact(typ) -> bool: if not is_artifact_annotation(typ): return False - if hasattr(typ, '_subs_tree'): # Python 3.6 - subs_tree = typ._subs_tree() - return len(subs_tree) == 3 and subs_tree[2] == InputAnnotation - - return typ.__args__[1] == InputAnnotation + return typ.__metadata__[0] == InputAnnotation def is_output_artifact(typ) -> bool: @@ -109,11 +91,7 @@ def is_output_artifact(typ) -> bool: if not is_artifact_annotation(typ): return False - if hasattr(typ, '_subs_tree'): # Python 3.6 - subs_tree = typ._subs_tree() - return len(subs_tree) == 3 and subs_tree[2] == OutputAnnotation - - return typ.__args__[1] == OutputAnnotation + return typ.__metadata__[0] == OutputAnnotation def get_io_artifact_class(typ): @@ -122,12 +100,6 @@ def get_io_artifact_class(typ): if typ == Input or typ == Output: return None - if hasattr(typ, '_subs_tree'): # Python 3.6 - subs_tree = typ._subs_tree() - if len(subs_tree) != 3: - return None - return subs_tree[1] - return typ.__args__[0] @@ -135,13 +107,7 @@ def get_io_artifact_annotation(typ): if not is_artifact_annotation(typ): return None - if hasattr(typ, '_subs_tree'): # Python 3.6 - subs_tree = typ._subs_tree() - if len(subs_tree) != 3: - return None - return subs_tree[2] - - return typ.__args__[1] + return typ.__metadata__[0] def maybe_strip_optional_from_annotation(annotation: T) -> T: diff --git a/sdk/python/requirements.in b/sdk/python/requirements.in index c389eb87aed..bda5ca6be6d 100644 --- a/sdk/python/requirements.in +++ b/sdk/python/requirements.in @@ -34,3 +34,4 @@ fire>=0.3.1,<1 google-api-python-client>=1.7.8,<2 dataclasses>=0.8,<1; python_version<"3.7" pydantic>=1.8.2,<2 +typing-extensions>=3.10.0.2,<4 diff --git a/sdk/python/requirements.txt b/sdk/python/requirements.txt index 9cd6d620c78..12bb102e4a6 100644 --- a/sdk/python/requirements.txt +++ b/sdk/python/requirements.txt @@ -4,7 +4,6 @@ # # pip-compile requirements.in # - absl-py==0.11.0 # via -r requirements.in attrs==20.3.0 @@ -141,8 +140,10 @@ tabulate==0.8.9 # via -r requirements.in termcolor==1.1.0 # via fire -typing-extensions==3.10.0.0 - # via pydantic +typing-extensions==3.10.0.2 + # via + # -r requirements.in + # pydantic uritemplate==3.0.1 # via google-api-python-client urllib3==1.26.5 diff --git a/sdk/python/setup.py b/sdk/python/setup.py index fdf1c99dcc9..acdb6f03b86 100644 --- a/sdk/python/setup.py +++ b/sdk/python/setup.py @@ -52,6 +52,7 @@ 'protobuf>=3.13.0,<4', # Standard library backports 'dataclasses;python_version<"3.7"', + 'typing-extensions>=3.10.0.2,<4;python_version<"3.9"', 'pydantic>=1.8.2,<2', ]