From 5bb085238c9624568cbfcbc8e219b1aff2ad85a7 Mon Sep 17 00:00:00 2001 From: Vincent Chen <62143443+mao3267@users.noreply.github.com> Date: Thu, 7 Nov 2024 02:12:16 +0800 Subject: [PATCH] Type Mismatching while Serializing Dataclass with Union (#2859) Signed-off-by: mao3267 Signed-off-by: Katrina Rogan --- flytekit/core/type_engine.py | 27 +++++++++++++++++--- tests/flytekit/unit/core/test_dataclass.py | 14 ++++++++++ tests/flytekit/unit/core/test_flytetypes.py | 17 ++++++++++++ tests/flytekit/unit/core/test_type_engine.py | 4 +++ 4 files changed, 59 insertions(+), 3 deletions(-) create mode 100644 tests/flytekit/unit/core/test_flytetypes.py diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 6e3e652307..5102a7df74 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -781,12 +781,33 @@ def t1() -> DC: """ from flytekit.types.directory import FlyteDirectory from flytekit.types.file import FlyteFile + from flytekit.types.structured import StructuredDataset # Handle Optional if UnionTransformer.is_optional_type(python_type): - if python_val is None: - return None - return self._make_dataclass_serializable(python_val, get_args(python_type)[0]) + + def get_expected_type(python_val: T, types: tuple) -> Type[T | None]: + if len(set(types) & {FlyteFile, FlyteDirectory, StructuredDataset}) > 1: + raise ValueError( + "Cannot have more than one Flyte type in the Union when attempting to use the string shortcut. Please specify the full object (e.g. FlyteFile(...)) instead of just passing a string." + ) + + for t in types: + try: + trans = TypeEngine.get_transformer(t) # type: ignore + if trans: + trans.assert_type(t, python_val) + return t + except Exception: + continue + return type(None) + + # Get the expected type in the Union type + expected_type = type(None) + if python_val is not None: + expected_type = get_expected_type(python_val, get_args(python_type)) # type: ignore + + return self._make_dataclass_serializable(python_val, expected_type) if hasattr(python_type, "__origin__") and get_origin(python_type) is list: if python_val is None: diff --git a/tests/flytekit/unit/core/test_dataclass.py b/tests/flytekit/unit/core/test_dataclass.py index 58dfcd1e45..4e098c254b 100644 --- a/tests/flytekit/unit/core/test_dataclass.py +++ b/tests/flytekit/unit/core/test_dataclass.py @@ -1118,3 +1118,17 @@ def empty_nested_dc_wf() -> NestedFlyteTypes: empty_nested_flyte_types = empty_nested_dc_wf() DataclassTransformer().assert_type(NestedFlyteTypes, empty_nested_flyte_types) + +def test_dataclass_serialize_with_multiple_dataclass_union(): + @dataclass + class A(): + x: int + + @dataclass + class B(): + x: FlyteFile + + b = B(x="s3://my-bucket/my-file") + res = DataclassTransformer()._make_dataclass_serializable(b, Union[None, A, B]) + + assert res.x.path == "s3://my-bucket/my-file" diff --git a/tests/flytekit/unit/core/test_flytetypes.py b/tests/flytekit/unit/core/test_flytetypes.py new file mode 100644 index 0000000000..366c3547c7 --- /dev/null +++ b/tests/flytekit/unit/core/test_flytetypes.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass +from flytekit.types.file import FlyteFile +from flytekit.types.structured.structured_dataset import StructuredDataset +from flytekit.core.type_engine import DataclassTransformer +from typing import Union +import pytest +import re + +def test_dataclass_union_with_multiple_flytetypes_error(): + @dataclass + class DC(): + x: Union[None, StructuredDataset, FlyteFile] + + + dc = DC(x="s3://my-bucket/my-file") + with pytest.raises(ValueError, match=re.escape("Cannot have more than one Flyte type in the Union when attempting to use the string shortcut. Please specify the full object (e.g. FlyteFile(...)) instead of just passing a string.")): + DataclassTransformer()._make_dataclass_serializable(dc, DC) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 0db5f10a46..1d6552a7c5 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -957,6 +957,7 @@ class TestFileStruct(DataClassJsonMixin): b: typing.Optional[FlyteFile] b_prime: typing.Optional[FlyteFile] c: typing.Union[FlyteFile, None] + c_prime: typing.Union[None, FlyteFile] d: typing.List[FlyteFile] e: typing.List[typing.Optional[FlyteFile]] e_prime: typing.List[typing.Optional[FlyteFile]] @@ -979,6 +980,7 @@ class TestFileStruct(DataClassJsonMixin): b=f1, b_prime=None, c=f1, + c_prime=f1, d=[f1], e=[f1], e_prime=[None], @@ -1001,6 +1003,7 @@ class TestFileStruct(DataClassJsonMixin): assert dict_obj["b"]["path"] == remote_path assert dict_obj["b_prime"] is None assert dict_obj["c"]["path"] == remote_path + assert dict_obj["c_prime"]["path"] == remote_path assert dict_obj["d"][0]["path"] == remote_path assert dict_obj["e"][0]["path"] == remote_path assert dict_obj["e_prime"][0] is None @@ -1018,6 +1021,7 @@ class TestFileStruct(DataClassJsonMixin): assert o.b.remote_path == ot.b.remote_source assert ot.b_prime is None assert o.c.remote_path == ot.c.remote_source + assert o.c_prime.remote_path == ot.c_prime.remote_source assert o.d[0].remote_path == ot.d[0].remote_source assert o.e[0].remote_path == ot.e[0].remote_source assert o.e_prime == [None]