Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(sdk.v2): use Annotated rather than Union hack #6573

Merged
merged 4 commits into from
Sep 16, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 12 additions & 46 deletions sdk/python/kfp/v2/components/types/type_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
import re
from typing import TypeVar, Union

try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated

T = TypeVar('T')


Expand Down Expand Up @@ -56,37 +61,18 @@ class OutputAnnotation():
pass


# TODO: Use typing.Annotated instead of this hack.
# With typing.Annotated (Python 3.9+ or typing_extensions package), the
# following would look like:
# Input = typing.Annotated[T, InputAnnotation]
# Output = typing.Annotated[T, OutputAnnotation]

# Input represents an Input artifact of type T.
Input = Union[T, InputAnnotation]
Input = Annotated[T, InputAnnotation]

# Output represents an Output artifact of type T.
Output = Union[T, OutputAnnotation]
Output = Annotated[T, OutputAnnotation]


def is_artifact_annotation(typ) -> bool:
if hasattr(typ, '_subs_tree'): # Python 3.6
subs_tree = typ._subs_tree()
return len(
subs_tree) == 3 and subs_tree[0] == Union and subs_tree[2] in [
InputAnnotation, OutputAnnotation
]

if not hasattr(typ, '__origin__'):
return False

if typ.__origin__ != Union and type(typ.__origin__) != type(Union):
return False

if not hasattr(typ, '__args__') or len(typ.__args__) != 2:
if not hasattr(typ, '__metadata__'):
return False

if typ.__args__[1] not in [InputAnnotation, OutputAnnotation]:
if typ.__metadata__[0] not in [InputAnnotation, OutputAnnotation]:
return False

return True
Expand All @@ -97,23 +83,15 @@ def is_input_artifact(typ) -> bool:
if not is_artifact_annotation(typ):
return False

if hasattr(typ, '_subs_tree'): # Python 3.6
subs_tree = typ._subs_tree()
return len(subs_tree) == 3 and subs_tree[2] == InputAnnotation

return typ.__args__[1] == InputAnnotation
return typ.__metadata__[0] == InputAnnotation


def is_output_artifact(typ) -> bool:
"""Returns True if typ is of type Output[T]."""
if not is_artifact_annotation(typ):
return False

if hasattr(typ, '_subs_tree'): # Python 3.6
subs_tree = typ._subs_tree()
return len(subs_tree) == 3 and subs_tree[2] == OutputAnnotation

return typ.__args__[1] == OutputAnnotation
return typ.__metadata__[0] == OutputAnnotation


def get_io_artifact_class(typ):
Expand All @@ -122,26 +100,14 @@ def get_io_artifact_class(typ):
if typ == Input or typ == Output:
return None

if hasattr(typ, '_subs_tree'): # Python 3.6
subs_tree = typ._subs_tree()
if len(subs_tree) != 3:
return None
return subs_tree[1]

return typ.__args__[0]


def get_io_artifact_annotation(typ):
if not is_artifact_annotation(typ):
return None

if hasattr(typ, '_subs_tree'): # Python 3.6
subs_tree = typ._subs_tree()
if len(subs_tree) != 3:
return None
return subs_tree[2]

return typ.__args__[1]
return typ.__metadata__[0]


def maybe_strip_optional_from_annotation(annotation: T) -> T:
Expand Down
1 change: 1 addition & 0 deletions sdk/python/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ fire>=0.3.1,<1
google-api-python-client>=1.7.8,<2
dataclasses>=0.8,<1; python_version<"3.7"
pydantic>=1.8.2,<2
typing-extensions>=3.10.0.2,<4
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should add this in setup.py as well, we can't rely on an indirect dependency from pydantic.
Also this better be conditioned by a Python version, 3.9 I think?
See dataclass for example:

# Standard library backports
'dataclasses;python_version<"3.7"',

7 changes: 4 additions & 3 deletions sdk/python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#
# pip-compile requirements.in
#

absl-py==0.11.0
# via -r requirements.in
attrs==20.3.0
Expand Down Expand Up @@ -141,8 +140,10 @@ tabulate==0.8.9
# via -r requirements.in
termcolor==1.1.0
# via fire
typing-extensions==3.10.0.0
# via pydantic
typing-extensions==3.10.0.2
# via
# -r requirements.in
# pydantic
uritemplate==3.0.1
# via google-api-python-client
urllib3==1.26.5
Expand Down