Skip to content

Commit

Permalink
Fix the type of optional[int] in nested dataclass (#1148)
Browse files Browse the repository at this point in the history
* Fix the type of optional[int] in nested dataclass

Signed-off-by: Kevin Su <[email protected]>

* update tests

Signed-off-by: Kevin Su <[email protected]>

* update comments

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Sep 9, 2022
1 parent 3e080b5 commit 4368e98
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
9 changes: 8 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) ->
def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any:
if val is None:
return val
if t == int or t == typing.Optional[int]:
if t == int:
return int(val)

if isinstance(val, list):
Expand All @@ -503,6 +503,13 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any:
# Handle nested Dict. e.g. {1: {2: 3}, 4: {5: 6}})
return {self._fix_val_int(ktype, k): self._fix_val_int(vtype, v) for k, v in val.items()}

if get_origin(t) is typing.Union and type(None) in get_args(t):
# Handle optional type. e.g. Optional[int], Optional[dataclass]
# Marshmallow doesn't support union type, so the type here is always an optional type.
# https://github.com/marshmallow-code/marshmallow/issues/1191#issuecomment-480831796
# Note: Union[None, int] is also an optional type, but Marshmallow does not support it.
return self._fix_val_int(get_args(t)[0], val)

if dataclasses.is_dataclass(t):
return self._fix_dataclass_int(t, val) # type: ignore

Expand Down
6 changes: 4 additions & 2 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def test_list_of_dataclass_getting_python_value():
@dataclass_json
@dataclass()
class Bar(object):
v: typing.Union[int, None]
w: typing.Optional[str]
x: float
y: str
Expand All @@ -163,7 +164,7 @@ class Foo(object):
y: typing.Dict[str, str]
z: Bar

foo = Foo(u=5, v=None, w=1, x=[1], y={"hello": "10"}, z=Bar(w=None, x=1.0, y="hello", z={"world": False}))
foo = Foo(u=5, v=None, w=1, x=[1], y={"hello": "10"}, z=Bar(v=3, w=None, x=1.0, y="hello", z={"world": False}))
generic = _json_format.Parse(typing.cast(DataClassJsonMixin, foo).to_json(), _struct.Struct())
lv = Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(generic=generic))]))

Expand All @@ -174,7 +175,6 @@ class Foo(object):
foo_class = convert_json_schema_to_python_class(schema["definitions"], "FooSchema")

guessed_pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[foo_class])
print("=====")
pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[Foo])
assert isinstance(guessed_pv, list)
assert guessed_pv[0].u == pv[0].u
Expand All @@ -186,7 +186,9 @@ class Foo(object):
assert type(guessed_pv[0].u) == int
assert guessed_pv[0].v is None
assert type(guessed_pv[0].w) == int
assert type(guessed_pv[0].z.v) == int
assert type(guessed_pv[0].z.x) == float
assert guessed_pv[0].z.v == pv[0].z.v
assert guessed_pv[0].z.y == pv[0].z.y
assert guessed_pv[0].z.z == pv[0].z.z
assert pv[0] == dataclass_from_dict(Foo, asdict(guessed_pv[0]))
Expand Down

0 comments on commit 4368e98

Please sign in to comment.