Skip to content

Commit

Permalink
Fix serialization of optional flyte types
Browse files: browse the repository at this point in the history
Signed-off-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
eapolinario committed Dec 24, 2022
1 parent c193d4e commit 7c20187
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 36 deletions.
6 changes: 6 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,12 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A
from flytekit.types.schema.types import FlyteSchema
from flytekit.types.structured.structured_dataset import StructuredDataset

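# Handle Optional: a None value is returned unchanged; any other value is
# serialized using the first Union argument, get_args(python_type)[0].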
if get_origin(python_type) is typing.Union and type(None) in get_args(python_type):
if python_val is None:
return None
return self._serialize_flyte_type(python_val, get_args(python_type)[0])

if hasattr(python_type, "__origin__") and python_type.__origin__ is list:
return [self._serialize_flyte_type(v, python_type.__args__[0]) for v in python_val]

Expand Down
94 changes: 58 additions & 36 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from datetime import timedelta
from enum import Enum

import mock
import pandas as pd
import pyarrow as pa
import pytest
Expand Down Expand Up @@ -569,7 +570,10 @@ def test_dataclass_int_preserving():
assert ot == o


def test_optional_flytefile_in_dataclass():
@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data")
def test_optional_flytefile_in_dataclass(mock_upload_dir):
mock_upload_dir.return_value = True

@dataclass_json
@dataclass
class A(object):
Expand All @@ -594,42 +598,60 @@ class TestFileStruct(object):
i_prime: typing.Optional[A] = A(a=99)

remote_path = "s3://tmp/file"
f1 = FlyteFile(remote_path)
o = TestFileStruct(
a=f1,
b=f1,
b_prime=None,
c=f1,
d=[f1],
e=[f1],
e_prime=[None],
f={"a": f1},
g={"a": f1},
g_prime={"a": None},
h=f1,
i=A(a=42),
)

ctx = FlyteContext.current_context()
tf = DataclassTransformer()
lt = tf.get_literal_type(TestFileStruct)
lv = tf.to_literal(ctx, o, TestFileStruct, lt)
ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct)
with tempfile.TemporaryFile() as f:
f.write(b"abc")
f1 = FlyteFile("f1", remote_path=remote_path)
o = TestFileStruct(
a=f1,
b=f1,
b_prime=None,
c=f1,
d=[f1],
e=[f1],
e_prime=[None],
f={"a": f1},
g={"a": f1},
g_prime={"a": None},
h=f1,
i=A(a=42),
)

assert o.a.path == ot.a.remote_source
assert o.b.path == ot.b.remote_source
assert ot.b_prime is None
assert o.c.path == ot.c.remote_source
assert o.d[0].path == ot.d[0].remote_source
assert o.e[0].path == ot.e[0].remote_source
assert o.e_prime == [None]
assert o.f["a"].path == ot.f["a"].remote_source
assert o.g["a"].path == ot.g["a"].remote_source
assert o.g_prime == {"a": None}
assert o.h.path == ot.h.remote_source
assert ot.h_prime is None
assert o.i == ot.i
assert o.i_prime == A(a=99)
ctx = FlyteContext.current_context()
tf = DataclassTransformer()
lt = tf.get_literal_type(TestFileStruct)
lv = tf.to_literal(ctx, o, TestFileStruct, lt)

assert lv.scalar.generic["a"].fields["path"].string_value == remote_path
assert lv.scalar.generic["b"].fields["path"].string_value == remote_path
assert lv.scalar.generic["b_prime"] is None
assert lv.scalar.generic["c"].fields["path"].string_value == remote_path
assert lv.scalar.generic["d"].values[0].struct_value.fields["path"].string_value == remote_path
assert lv.scalar.generic["e"].values[0].struct_value.fields["path"].string_value == remote_path
assert lv.scalar.generic["e_prime"].values[0].WhichOneof("kind") == "null_value"
assert lv.scalar.generic["f"]["a"].fields["path"].string_value == remote_path
assert lv.scalar.generic["g"]["a"].fields["path"].string_value == remote_path
assert lv.scalar.generic["g_prime"]["a"] is None
assert lv.scalar.generic["h"].fields["path"].string_value == remote_path
assert lv.scalar.generic["h_prime"] is None
assert lv.scalar.generic["i"].fields["a"].number_value == 42
assert lv.scalar.generic["i_prime"].fields["a"].number_value == 99

ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct)

assert o.a.path == ot.a.remote_source
assert o.b.path == ot.b.remote_source
assert ot.b_prime is None
assert o.c.path == ot.c.remote_source
assert o.d[0].path == ot.d[0].remote_source
assert o.e[0].path == ot.e[0].remote_source
assert o.e_prime == [None]
assert o.f["a"].path == ot.f["a"].remote_source
assert o.g["a"].path == ot.g["a"].remote_source
assert o.g_prime == {"a": None}
assert o.h.path == ot.h.remote_source
assert ot.h_prime is None
assert o.i == ot.i
assert o.i_prime == A(a=99)


def test_flyte_file_in_dataclass():
Expand Down

0 comments on commit 7c20187

Please sign in to comment.