Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TypeTransformer] Support non-Any Python types as Any input in workflows #2432

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,6 +1176,22 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type
modify_literal_uris(lv)
if hash is not None:
lv.hash = hash

metadata = lv.metadata or {}
# print("type engine python type", python_type.__name__)
# f"{Datum.__module__}.{Datum.__qualname__}"
print("Python Type:", python_type)
try:
# print("python_type:", python_type)
# print("python_dotted_path:", f"{python_type.__module__}.{python_type.__qualname__}")
# metadata.update({"python_dotted_path": f"{python_type.__module__}.{python_type.__qualname__}"})
metadata.update({"python_type": str(python_type)})

lv.set_metadata(metadata=metadata)
except AttributeError as e:
logger.warning(f"Attribute error occurred: {e}")
print("@@@ final metadata:", metadata)
# print("@@@ type engine final literal:", lv)
return lv

@classmethod
Expand Down Expand Up @@ -1407,7 +1423,8 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
t = self.get_sub_type(python_type)
lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore
return Literal(collection=LiteralCollection(literals=lit_list))

# literal: List[int]
#
def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore
try:
lits = lv.collection.literals
Expand Down
7 changes: 6 additions & 1 deletion flytekit/interaction/click_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import cloudpickle
import rich_click as click
import yaml
from dataclasses_json import DataClassJsonMixin
from dataclasses_json import DataClassJsonMixin, dataclass_json
from pytimeparse import parse

from flytekit import BlobType, FlyteContext, FlyteContextManager, Literal, LiteralType, StructuredDataset
Expand Down Expand Up @@ -273,6 +273,11 @@ def convert(

if is_pydantic_basemodel(self._python_type):
return self._python_type.parse_raw(json.dumps(parsed_value)) # type: ignore

# Ensure that the python type has `from_json` function
if not hasattr(self._python_type, "from_json"):
self._python_type = dataclass_json(self._python_type)

return cast(DataClassJsonMixin, self._python_type).from_json(json.dumps(parsed_value))


Expand Down
1 change: 1 addition & 0 deletions flytekit/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class SimpleType(object):
BINARY = _types_pb2.BINARY
ERROR = _types_pb2.ERROR
STRUCT = _types_pb2.STRUCT
ANY = _types_pb2.ANY


class SchemaType(_common.FlyteIdlEntity):
Expand Down
53 changes: 45 additions & 8 deletions flytekit/types/pickle/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from flytekit.core.type_engine import TypeEngine, TypeTransformer
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
from flytekit.models.types import LiteralType, SimpleType

T = typing.TypeVar("T")

Expand Down Expand Up @@ -88,10 +88,49 @@ def assert_type(self, t: Type[T], v: T):
...

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
    """Convert a literal handled by this transformer back into a Python value.

    Blob literals are resolved first:

    * a blob whose format is ``PYTHON_PICKLE_FORMAT`` is unpickled from its URI;
    * a multipart blob is delegated to ``TypeEngine`` as a ``FlyteDirectory``;
    * a single blob is delegated to ``TypeEngine`` as a ``FlyteFile``.

    If that fails for any reason (for example the literal carries no blob),
    the ``python_type`` entry of the literal's metadata, when present, is used
    to pick the Python type and the conversion is delegated to ``TypeEngine``.

    :param ctx: Flyte context forwarded to ``TypeEngine``.
    :param lv: Literal to convert.
    :param expected_python_type: Declared Python type; not consulted here.
    :return: The converted Python value.
    :raises ValueError: if the blob matches none of the known blob kinds.
    :raises Exception: the original failure is re-raised when no usable
        ``python_type`` metadata is available.
    """
    from flytekit import BlobType
    from flytekit.types.directory import FlyteDirectory
    from flytekit.types.file import FlyteFile

    try:
        blob = lv.scalar.blob
        uri = blob.uri
        blob_type = blob.metadata.type
        if blob_type.format == self.PYTHON_PICKLE_FORMAT:
            return FlytePickle.from_pickle(uri)
        if blob_type.dimensionality == BlobType.BlobDimensionality.MULTIPART:
            return TypeEngine.to_python_value(ctx, lv, FlyteDirectory)
        if blob_type.dimensionality == BlobType.BlobDimensionality.SINGLE:
            return TypeEngine.to_python_value(ctx, lv, FlyteFile)
    except Exception:
        metadata = lv.metadata
        recorded_type = metadata.get("python_type") if metadata else None
        if recorded_type:
            # SECURITY: eval() executes arbitrary code. The metadata string
            # travels with the literal, so it must come from a trusted
            # producer. Replace with a safe lookup before shipping.
            # NOTE(review): a value produced by str() on a plain class looks
            # like "<class 'int'>", which eval() cannot parse - confirm the
            # format the producer writes.
            py_type = eval(recorded_type)
            if py_type != typing.Any:
                return TypeEngine.to_python_value(ctx, lv, py_type)
        # No usable fallback: surface the original failure with its traceback.
        raise

    # Reached only when the blob matched none of the branches above; fail
    # loudly instead of silently returning None.
    raise ValueError(f"Transformer {self} cannot convert literal {lv} to a python value")

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
# to blob or bytes
if python_val is None:
raise AssertionError("Cannot pickle None Value.")
meta = BlobMetadata(
Expand All @@ -113,12 +152,10 @@ def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlytePickl
raise ValueError(f"Transformer {self} cannot reverse {literal_type}")

def get_literal_type(self, t: Type[T]) -> LiteralType:
    """Return the literal type used for values handled by this transformer.

    The result is a ``SimpleType.ANY`` literal type whose metadata records the
    string form of ``t`` and whether ``t`` is ``typing.Any``.

    :param t: Python type being described.
    :return: A ``LiteralType`` with ``simple=SimpleType.ANY`` and metadata set.
    """
    lt = LiteralType(simple=SimpleType.ANY)
    # Set both keys in one assignment: assigning ``metadata`` twice would
    # silently drop ``python_class_name``.
    lt.metadata = {
        "python_class_name": str(t),
        "isAny": str(t == typing.Any),
    }
    return lt


Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-huggingface/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# Runtime requirements of the plugin. Each package appears exactly once so the
# resolver sees a single, unambiguous constraint per dependency.
# NOTE(review): the upper bound on ``datasets`` presumably works around an
# incompatibility in newer releases - confirm and link the upstream issue.
plugin_requires = [
    "flytekit>=1.3.0b2,<2.0.0",
    "datasets>=2.4.0,<2.19.2",
]

__version__ = "0.0.0+develop"
Expand Down
81 changes: 79 additions & 2 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dataclasses import asdict, dataclass, field
from datetime import timedelta
from enum import Enum, auto
from typing import List, Optional, Type
from typing import Any, List, Optional, Tuple, Type

import mock
import pyarrow as pa
Expand Down Expand Up @@ -2017,6 +2017,79 @@ def __init__(self, number: int):
TypeEngine.to_literal(ctx, 1, typing.Optional[typing.Any], lt)


def test_non_any_as_any_input_workflow():
    """Concretely typed workflow inputs can be passed to a task parameter typed ``Any``.

    Each workflow forwards its input to ``foo``. ``foo`` returns ``a + 1`` only
    when the received value is exactly an ``int`` and ``0`` otherwise.
    """

    @task
    def foo(a: Any) -> int:
        # Exact type identity on purpose: ``bool`` subclasses ``int``, so an
        # ``isinstance`` check would make the bool case return 2 instead of 0.
        if type(a) is int:
            return a + 1
        return 0

    @workflow
    def wf_int(a: int) -> int:
        return foo(a=a)

    @workflow
    def wf_float(a: float) -> int:
        return foo(a=a)

    @workflow
    def wf_str(a: str) -> int:
        return foo(a=a)

    @workflow
    def wf_bool(a: bool) -> int:
        return foo(a=a)

    @workflow
    def wf_datetime(a: datetime.datetime) -> int:
        return foo(a=a)

    @workflow
    def wf_duration(a: datetime.timedelta) -> int:
        return foo(a=a)

    assert wf_int(a=1) == 2
    assert wf_float(a=1.0) == 0
    assert wf_str(a="1") == 0
    assert wf_bool(a=True) == 0
    assert wf_datetime(a=datetime.datetime.now()) == 0
    assert wf_duration(a=datetime.timedelta(seconds=1)) == 0


def test_non_any_as_any_output_workflow():
    """Task outputs of concrete types can be returned from a workflow as ``Any``.

    Six tasks each return one value of a different concrete type; the workflow
    declares every output as ``Any`` and must hand the values back unchanged.
    """
    utc_now = datetime.datetime.now(datetime.timezone.utc)

    @task
    def emit_int() -> int:
        return 1

    @task
    def emit_float() -> float:
        return 1.0

    @task
    def emit_str() -> str:
        return "1"

    @task
    def emit_bool() -> bool:
        return True

    @task
    def emit_datetime() -> datetime.datetime:
        return utc_now

    @task
    def emit_duration() -> datetime.timedelta:
        return datetime.timedelta(seconds=1)

    @workflow
    def wf() -> Tuple[Any, Any, Any, Any, Any, Any]:
        return (
            emit_int(),
            emit_float(),
            emit_str(),
            emit_bool(),
            emit_datetime(),
            emit_duration(),
        )

    expected = (1, 1.0, "1", True, utc_now, datetime.timedelta(seconds=1))
    assert wf() == expected


def test_enum_in_dataclass():
@dataclass
class Datum(DataClassJsonMixin):
Expand Down Expand Up @@ -2203,7 +2276,11 @@ def t1(a: int) -> int:

return v_1

assert t1(a=3) == 9
@workflow
def wf(a: int) -> int:
return t1(a=a)

assert wf(a=3) == 9


def test_literal_hash_int_can_be_set():
Expand Down