diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 0426022750..43c11065ba 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -491,7 +491,7 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any: if val is None: return val - if t == int or t == typing.Optional[int]: + if t == int: return int(val) if isinstance(val, list): @@ -503,6 +503,13 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any: # Handle nested Dict. e.g. {1: {2: 3}, 4: {5: 6}}) return {self._fix_val_int(ktype, k): self._fix_val_int(vtype, v) for k, v in val.items()} + if get_origin(t) is typing.Union and type(None) in get_args(t): + # Handle optional type. e.g. Optional[int], Optional[dataclass] + # Marshmallow doesn't support union type, so the type here is always an optional type. + # https://github.com/marshmallow-code/marshmallow/issues/1191#issuecomment-480831796 + # Note: Union[None, int] is also an optional type, but Marshmallow does not support it. + return self._fix_val_int(get_args(t)[0], val) + if dataclasses.is_dataclass(t): return self._fix_dataclass_int(t, val) # type: ignore diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index abf9c47cf8..0fe2513908 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -148,6 +148,7 @@ def test_list_of_dataclass_getting_python_value(): @dataclass_json @dataclass() class Bar(object): + v: typing.Union[int, None] w: typing.Optional[str] x: float y: str @@ -163,7 +164,7 @@ class Foo(object): y: typing.Dict[str, str] z: Bar - foo = Foo(u=5, v=None, w=1, x=[1], y={"hello": "10"}, z=Bar(w=None, x=1.0, y="hello", z={"world": False})) + foo = Foo(u=5, v=None, w=1, x=[1], y={"hello": "10"}, z=Bar(v=3, w=None, x=1.0, y="hello", z={"world": False})) generic = _json_format.Parse(typing.cast(DataClassJsonMixin, foo).to_json(), _struct.Struct()) lv = Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(generic=generic))])) @@ -174,7 +175,6 @@ class Foo(object): foo_class = convert_json_schema_to_python_class(schema["definitions"], "FooSchema") guessed_pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[foo_class]) - print("=====") pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[Foo]) assert isinstance(guessed_pv, list) assert guessed_pv[0].u == pv[0].u @@ -186,7 +186,9 @@ class Foo(object): assert type(guessed_pv[0].u) == int assert guessed_pv[0].v is None assert type(guessed_pv[0].w) == int + assert type(guessed_pv[0].z.v) == int assert type(guessed_pv[0].z.x) == float + assert guessed_pv[0].z.v == pv[0].z.v assert guessed_pv[0].z.y == pv[0].z.y assert guessed_pv[0].z.z == pv[0].z.z assert pv[0] == dataclass_from_dict(Foo, asdict(guessed_pv[0]))