Skip to content

Commit

Permalink
Handle Optional[FlyteFile] in Dataclass type transformer (#1393)
Browse files Browse the repository at this point in the history
* Add support for Optional to dataclass transformer

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add one more test

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add one more test

Signed-off-by: Eduardo Apolinario <[email protected]>

* Fix serialization of optional flyte types

Signed-off-by: Eduardo Apolinario <[email protected]>

Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
eapolinario and eapolinario authored Dec 25, 2022
1 parent 425f488 commit 242a11c
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 1 deletion.
14 changes: 13 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,12 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A
from flytekit.types.schema.types import FlyteSchema
from flytekit.types.structured.structured_dataset import StructuredDataset

# Handle Optional
if get_origin(python_type) is typing.Union and type(None) in get_args(python_type):
if python_val is None:
return None
return self._serialize_flyte_type(python_val, get_args(python_type)[0])

if hasattr(python_type, "__origin__") and python_type.__origin__ is list:
return [self._serialize_flyte_type(v, python_type.__args__[0]) for v in python_val]

Expand Down Expand Up @@ -400,12 +406,18 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A
python_val.__setattr__(v.name, self._serialize_flyte_type(val, field_type))
return python_val

def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> T:
def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> Optional[T]:
from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer
from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer
from flytekit.types.schema.types import FlyteSchema, FlyteSchemaTransformer
from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine

# Handle Optional
if get_origin(expected_python_type) is typing.Union and type(None) in get_args(expected_python_type):
if python_val is None:
return None
return self._deserialize_flyte_type(python_val, get_args(expected_python_type)[0])

if hasattr(expected_python_type, "__origin__") and expected_python_type.__origin__ is list:
return [self._deserialize_flyte_type(v, expected_python_type.__args__[0]) for v in python_val] # type: ignore

Expand Down
85 changes: 85 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from datetime import timedelta
from enum import Enum

import mock
import pandas as pd
import pyarrow as pa
import pytest
Expand Down Expand Up @@ -569,6 +570,90 @@ def test_dataclass_int_preserving():
assert ot == o


@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data")
def test_optional_flytefile_in_dataclass(mock_upload_dir):
mock_upload_dir.return_value = True

@dataclass_json
@dataclass
class A(object):
a: int

@dataclass_json
@dataclass
class TestFileStruct(object):
a: FlyteFile
b: typing.Optional[FlyteFile]
b_prime: typing.Optional[FlyteFile]
c: typing.Union[FlyteFile, None]
d: typing.List[FlyteFile]
e: typing.List[typing.Optional[FlyteFile]]
e_prime: typing.List[typing.Optional[FlyteFile]]
f: typing.Dict[str, FlyteFile]
g: typing.Dict[str, typing.Optional[FlyteFile]]
g_prime: typing.Dict[str, typing.Optional[FlyteFile]]
h: typing.Optional[FlyteFile] = None
h_prime: typing.Optional[FlyteFile] = None
i: typing.Optional[A] = None
i_prime: typing.Optional[A] = A(a=99)

remote_path = "s3://tmp/file"
with tempfile.TemporaryFile() as f:
f.write(b"abc")
f1 = FlyteFile("f1", remote_path=remote_path)
o = TestFileStruct(
a=f1,
b=f1,
b_prime=None,
c=f1,
d=[f1],
e=[f1],
e_prime=[None],
f={"a": f1},
g={"a": f1},
g_prime={"a": None},
h=f1,
i=A(a=42),
)

ctx = FlyteContext.current_context()
tf = DataclassTransformer()
lt = tf.get_literal_type(TestFileStruct)
lv = tf.to_literal(ctx, o, TestFileStruct, lt)

assert lv.scalar.generic["a"].fields["path"].string_value == remote_path
assert lv.scalar.generic["b"].fields["path"].string_value == remote_path
assert lv.scalar.generic["b_prime"] is None
assert lv.scalar.generic["c"].fields["path"].string_value == remote_path
assert lv.scalar.generic["d"].values[0].struct_value.fields["path"].string_value == remote_path
assert lv.scalar.generic["e"].values[0].struct_value.fields["path"].string_value == remote_path
assert lv.scalar.generic["e_prime"].values[0].WhichOneof("kind") == "null_value"
assert lv.scalar.generic["f"]["a"].fields["path"].string_value == remote_path
assert lv.scalar.generic["g"]["a"].fields["path"].string_value == remote_path
assert lv.scalar.generic["g_prime"]["a"] is None
assert lv.scalar.generic["h"].fields["path"].string_value == remote_path
assert lv.scalar.generic["h_prime"] is None
assert lv.scalar.generic["i"].fields["a"].number_value == 42
assert lv.scalar.generic["i_prime"].fields["a"].number_value == 99

ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct)

assert o.a.path == ot.a.remote_source
assert o.b.path == ot.b.remote_source
assert ot.b_prime is None
assert o.c.path == ot.c.remote_source
assert o.d[0].path == ot.d[0].remote_source
assert o.e[0].path == ot.e[0].remote_source
assert o.e_prime == [None]
assert o.f["a"].path == ot.f["a"].remote_source
assert o.g["a"].path == ot.g["a"].remote_source
assert o.g_prime == {"a": None}
assert o.h.path == ot.h.remote_source
assert ot.h_prime is None
assert o.i == ot.i
assert o.i_prime == A(a=99)


def test_flyte_file_in_dataclass():
@dataclass_json
@dataclass
Expand Down

0 comments on commit 242a11c

Please sign in to comment.