Skip to content

Commit

Permalink
support any and non any workflow
Browse files Browse the repository at this point in the history
Signed-off-by: Future-Outlier <[email protected]>
  • Loading branch information
Future-Outlier committed May 22, 2024
1 parent 01b8842 commit fb11101
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 29 deletions.
5 changes: 0 additions & 5 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,7 +1166,6 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T
"""
Converts a Literal value with an expected python type into a python value.
"""
# print("Expected Python Type: ", expected_python_type)
transformer = cls.get_transformer(expected_python_type)
return transformer.to_python_value(ctx, lv, expected_python_type)

Expand Down Expand Up @@ -1223,13 +1222,9 @@ def literal_map_to_kwargs(
kwargs = {}
for i, k in enumerate(lm.literals):
try:
# print("converting input: ", k, " with value: ", lm.literals[k])
# print("Type 1: ", python_interface_inputs[k])
kwargs[k] = TypeEngine.to_python_value(ctx, lm.literals[k], python_interface_inputs[k])
# print("kwargs[k]:", kwargs[k])
except TypeTransformerFailedError as exc:
raise TypeTransformerFailedError(f"Error converting input '{k}' at position {i}:\n {exc}") from exc

return kwargs

@classmethod
Expand Down
23 changes: 0 additions & 23 deletions flytekit/types/pickle/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,30 +87,7 @@ def assert_type(self, t: Type[T], v: T):
# Every type can serialize to pickle, so we don't need to check the type here.
...

# def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
# try:
# uri = lv.scalar.blob.uri
# return FlytePickle.from_pickle(uri)
# except Exception as e:
# from datetime import datetime, timedelta

# if lv.scalar:
# if lv.scalar.primitive:
# if lv.scalar.primitive.integer:
# return TypeEngine.to_python_value(ctx, lv, int)
# elif lv.scalar.primitive.float_value:
# return TypeEngine.to_python_value(ctx, lv, float)
# elif lv.scalar.primitive.string_value:
# return TypeEngine.to_python_value(ctx, lv, str)
# elif lv.scalar.primitive.boolean:
# return TypeEngine.to_python_value(ctx, lv, bool)
# elif lv.scalar.primitive.datetime:
# return TypeEngine.to_python_value(ctx, lv, datetime)
# elif lv.scalar.primitive.duration:
# return TypeEngine.to_python_value(ctx, lv, timedelta)
# raise None
def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
print("lv:", lv)
primitive = lv.scalar.primitive
if primitive:
from datetime import datetime, timedelta
Expand Down
75 changes: 74 additions & 1 deletion tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dataclasses import asdict, dataclass, field
from datetime import timedelta
from enum import Enum, auto
from typing import List, Optional, Type
from typing import Any, List, Optional, Tuple, Type

import mock
import pyarrow as pa
Expand Down Expand Up @@ -2017,6 +2017,79 @@ def __init__(self, number: int):
TypeEngine.to_literal(ctx, 1, typing.Optional[typing.Any], lt)


def test_non_any_as_any_input_workflow():
@task
def foo(a: Any) -> int:
if type(a) == int:
return a + 1
return 0

@workflow
def wf_int(a: int) -> int:
return foo(a=a)

@workflow
def wf_float(a: float) -> int:
return foo(a=a)

@workflow
def wf_str(a: str) -> int:
return foo(a=a)

@workflow
def wf_bool(a: bool) -> int:
return foo(a=a)

@workflow
def wf_datetime(a: datetime.datetime) -> int:
return foo(a=a)

@workflow
def wf_duration(a: datetime.timedelta) -> int:
return foo(a=a)

assert wf_int(a=1) == 2
assert wf_float(a=1.0) == 0
assert wf_str(a="1") == 0
assert wf_bool(a=True) == 0
assert wf_datetime(a=datetime.datetime.now()) == 0
assert wf_duration(a=datetime.timedelta(seconds=1)) == 0


def test_non_any_as_any_output_workflow():
now = datetime.datetime.now(datetime.timezone.utc)

@task
def foo_int() -> int:
return 1

@task
def foo_float() -> float:
return 1.0

@task
def foo_str() -> str:
return "1"

@task
def foo_bool() -> bool:
return True

@task
def foo_datetime() -> datetime.datetime:
return now

@task
def foo_duration() -> datetime.timedelta:
return datetime.timedelta(seconds=1)

@workflow
def wf() -> Tuple[Any, Any, Any, Any, Any, Any]:
return foo_int(), foo_float(), foo_str(), foo_bool(), foo_datetime(), foo_duration()

assert wf() == (1, 1.0, "1", True, now, datetime.timedelta(seconds=1))


def test_enum_in_dataclass():
@dataclass
class Datum(DataClassJsonMixin):
Expand Down

0 comments on commit fb11101

Please sign in to comment.