Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stop requiring users to import dataclasses_json or DataClassJSONMixin for dataclass #2279

Merged
merged 30 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
2aedeed
remove dataclass_json
Future-Outlier Mar 18, 2024
58773b3
lint, use setattr and getattr, remove annotations
Future-Outlier Mar 18, 2024
aad91c3
update tests
Future-Outlier Mar 18, 2024
8a7b06d
remove annotations
Future-Outlier Mar 18, 2024
c0b5290
fix test
Future-Outlier Mar 18, 2024
fdbf58d
nit
Future-Outlier Mar 18, 2024
9ee2eb4
test dolt
Future-Outlier Mar 19, 2024
ce8edf5
trigger ci
Future-Outlier Mar 19, 2024
a714b83
test
Future-Outlier Mar 19, 2024
ce067a3
remove dolt ci test
Future-Outlier Mar 19, 2024
dce34ab
fix tests
Future-Outlier Mar 19, 2024
399f975
use old dolt
Future-Outlier Mar 19, 2024
dde160b
fix tests
Future-Outlier Mar 19, 2024
610e739
revert
Future-Outlier Mar 19, 2024
1819532
remove failed test
Future-Outlier Mar 19, 2024
8dff896
fix dict error
Future-Outlier Mar 19, 2024
da3768d
revert dolt plugin
Future-Outlier Mar 19, 2024
2757f34
revert and merge master
Future-Outlier Mar 19, 2024
102085b
make dolt plugin commented
Future-Outlier Mar 19, 2024
30223e4
update pyproject.toml
Future-Outlier Mar 21, 2024
9b1e617
Update Thomas's and Eduardo's advices
Future-Outlier Mar 22, 2024
38986f0
print
Future-Outlier Mar 22, 2024
5de561e
revert print
Future-Outlier Mar 22, 2024
bc3f58b
add encoder and decoder registry
Future-Outlier Mar 22, 2024
d6d4965
dolt revert
Future-Outlier Mar 22, 2024
7f759ac
add encoder and decoder registry tests
Future-Outlier Mar 22, 2024
100c838
Update Thomas's advice
Future-Outlier Apr 2, 2024
c1a404d
add comments by Thomas's suggestion
Future-Outlier Apr 3, 2024
0555f6e
lint
Future-Outlier Apr 3, 2024
0a33f53
Merge branch 'master' into remove-dataclass_json
Future-Outlier Apr 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 22 additions & 30 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from google.protobuf.message import Message
from google.protobuf.struct_pb2 import Struct
from marshmallow_enum import EnumField, LoadDumpOptions
from mashumaro.codecs.json import JSONDecoder, JSONEncoder
from mashumaro.mixins.json import DataClassJSONMixin
from typing_extensions import Annotated, get_args, get_origin

Expand Down Expand Up @@ -326,13 +327,8 @@ class Test(DataClassJsonMixin):

def __init__(self):
super().__init__("Object-Dataclass-Transformer", object)
self._serializable_classes = [DataClassJSONMixin, DataClassJsonMixin]
try:
from mashumaro.mixins.orjson import DataClassORJSONMixin

self._serializable_classes.append(DataClassORJSONMixin)
except ModuleNotFoundError:
pass
self._encoder: Dict[Type, JSONEncoder] = {}
self._decoder: Dict[Type, JSONDecoder] = {}

def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):
# Skip iterating all attributes in the dataclass if the type of v already matches the expected_type
Expand Down Expand Up @@ -425,11 +421,6 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
f"Type {t} cannot be parsed."
)

if not self.is_serializable_class(t):
raise AssertionError(
f"Dataclass {t} should be decorated with @dataclass_json or mixin with DataClassJSONMixin to be "
f"serialized correctly"
)
schema = None
try:
if issubclass(t, DataClassJsonMixin):
Expand Down Expand Up @@ -473,9 +464,6 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:

return _type_models.LiteralType(simple=_type_models.SimpleType.STRUCT, metadata=schema, structure=ts)

def is_serializable_class(self, class_: Type[T]) -> bool:
return any(issubclass(class_, serializable_class) for serializable_class in self._serializable_classes)

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
if isinstance(python_val, dict):
json_str = json.dumps(python_val)
Expand All @@ -486,14 +474,15 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
f"{type(python_val)} is not of type @dataclass, only Dataclasses are supported for "
f"user defined datatypes in Flytekit"
)
if not self.is_serializable_class(type(python_val)):
raise TypeTransformerFailedError(
f"Dataclass {python_type} should be decorated with @dataclass_json or inherit DataClassJSONMixin to be "
f"serialized correctly"
)

self._serialize_flyte_type(python_val, python_type)

json_str = python_val.to_json() # type: ignore
if hasattr(python_val, "to_json"):
json_str = python_val.to_json()
Copy link
Member

@thomasjpfan thomasjpfan Apr 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wild-endeavor With this check, we are backward compatible. If one used DataClassJsonMixin or @dataclass_json, then to_json is defined and called, which matches master's behavior. XREF: https://github.com/lidatong/dataclasses-json/blob/8512afc0a87053dbde52af0519c74198fa3bb873/dataclasses_json/api.py#L26

@Future-Outlier Can you include a comment here about how this preserves backward compatibility?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, @wild-endeavor and @thomasjpfan

In flytekit-python, we use the function to_json to serialize dataclasses and from_json to deserialize them. Previously, we required users in older flytekit releases to add the dataclass_json decorator or inherit from the DataclassJsonMixin class.
This was because both provided the necessary methods to serialize (convert a dataclass to bytes) and deserialize (convert bytes back to a dataclass).

In this PR, the mashumaro module introduces two classes, JSONEncoder and JSONDecoder, for serializing and deserializing dataclasses.
These new classes eliminate the reliance on the to_json and from_json methods.

Initially, we used if hasattr(python_val, "to_json"): to check for the method's presence.
Therefore, introducing these changes will not cause any breaking changes.
This means that dataclasses inheriting from DataclassJsonMixin will continue to use to_json and from_json for serialization and deserialization when using this version of flytekit.

to_json REF: https://github.com/flyteorg/flytekit/blob/master/flytekit/core/type_engine.py#L498
from_json REF: https://github.com/flyteorg/flytekit/blob/master/flytekit/core/type_engine.py#L750

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To clarify, I wanted to have a short comment in the code that states how the "to_json" check helps preserve backward compatibility.

else:
if not self._encoder.get(python_type):
self._encoder[python_type] = JSONEncoder(python_type)
json_str = self._encoder[python_type].encode(python_val)
pingsutw marked this conversation as resolved.
Show resolved Hide resolved

return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) # type: ignore

Expand Down Expand Up @@ -720,7 +709,7 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any:

return val

def _fix_dataclass_int(self, dc_type: Type[DataClassJsonMixin], dc: DataClassJsonMixin) -> DataClassJsonMixin:
def _fix_dataclass_int(self, dc_type: Type, dc: dataclasses.dataclass) -> dataclasses.dataclass: # type: ignore
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if there is a better way to write the type annotation here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can dc_type be Type[dataclasses.dataclass]? (I think this is required for dataclasses.fields to work.)

Given how dynamic the code is, I think we have to go with dc: typing.Any for now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right!
I have applied your suggestion, thank you.

"""
This is a performance penalty to convert to the right types, but this is expected by the user and hence
needs to be done
Expand All @@ -729,8 +718,9 @@ def _fix_dataclass_int(self, dc_type: Type[DataClassJsonMixin], dc: DataClassJso
# https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#google.protobuf.Value
# Thus we will have to walk the given dataclass and typecast values to int, where expected.
for f in dataclasses.fields(dc_type):
val = dc.__getattribute__(f.name)
dc.__setattr__(f.name, self._fix_val_int(f.type, val))
val = getattr(dc, f.name)
setattr(dc, f.name, self._fix_val_int(f.type, val))
pingsutw marked this conversation as resolved.
Show resolved Hide resolved

return dc

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
Expand All @@ -739,13 +729,15 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
f"{expected_python_type} is not of type @dataclass, only Dataclasses are supported for "
"user defined datatypes in Flytekit"
)
if not self.is_serializable_class(expected_python_type):
raise TypeTransformerFailedError(
f"Dataclass {expected_python_type} should be decorated with @dataclass_json or mixin with DataClassJSONMixin to be "
f"serialized correctly"
)

json_str = _json_format.MessageToJson(lv.scalar.generic)
dc = expected_python_type.from_json(json_str) # type: ignore

if hasattr(expected_python_type, "from_json"):
dc = expected_python_type.from_json(json_str) # type: ignore
else:
if not self._decoder.get(expected_python_type):
self._decoder[expected_python_type] = JSONDecoder(expected_python_type)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here: use a regular try: ... except KeyError:

dc = self._decoder[expected_python_type].decode(json_str)

dc = self._fix_structured_dataset_type(expected_python_type, dc)
return self._fix_dataclass_int(expected_python_type, self._deserialize_flyte_type(dc, expected_python_type))
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ dependencies = [
"markdown-it-py",
"marshmallow-enum",
"marshmallow-jsonschema>=0.12.0",
"mashumaro>=3.9.1",
"mashumaro>=3.11",
"protobuf!=4.25.0",
"pyarrow",
"pygments",
Expand Down
67 changes: 58 additions & 9 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dataclasses import asdict, dataclass, field
from datetime import timedelta
from enum import Enum, auto
from typing import Optional, Type
from typing import List, Optional, Type

import mock
import pyarrow as pa
Expand All @@ -25,13 +25,11 @@
from mashumaro.mixins.orjson import DataClassORJSONMixin
from typing_extensions import Annotated, get_args, get_origin

from flytekit import kwtypes
from flytekit import dynamic, kwtypes, task, workflow
from flytekit.core.annotation import FlyteAnnotation
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import flyte_tmp_dir
from flytekit.core.dynamic_workflow_task import dynamic
from flytekit.core.hash import HashMethod
from flytekit.core.task import task
from flytekit.core.type_engine import (
DataclassTransformer,
DictTransformer,
Expand Down Expand Up @@ -2499,16 +2497,29 @@ class DatumMashumaro(DataClassJSONMixin):

@dataclass_json
@dataclass
class Datum(DataClassJSONMixin):
class DatumDataclassJson(DataClassJSONMixin):
x: int
y: Color

transformer = DataclassTransformer()
@dataclass
class DatumDataclass:
x: int
y: Color
Comment on lines +2522 to +2524
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add another test to see what happens with typing.Union? Specifically:

@dataclass
class DatumDataUnion:
    path: typing.Union[str, os.PathLike]

If the path started out as a pathlib.Path(...), does it deserialize into a string or a Path object?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found that a Path object is not serializable, but I will test it with other cases!

@dataclass
    class DatumDataUnion(DataClassJSONMixin):
        path: typing.Union[str, os.PathLike]

lt = TypeEngine.to_literal_type(DatumDataUnion)
    datum_dataunion = DatumDataUnion(Path("/tmp"))
    lv = transformer.to_literal(ctx, datum_dataunion, DatumDataUnion, lt)
    gt = transformer.guess_python_type(lt)
    pv = transformer.to_python_value(ctx, lv, expected_python_type=gt)
    assert datum_dataunion.path == pv.path

screenshots:
image

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found that TypeTransformer can't handle typing.Union[str, os.PathLike] or typing.Union[str, FlyteFile].
I believe os.PathLike was deliberately left unsupported here; see:
https://docs.flyte.org/en/latest/api/flytekit/generated/flytekit.types.file.FlyteFile.html#flytekit.types.file.FlyteFile.path

Supporting typing.Union[str, FlyteFile] in Flyte could be an enhancement.
For more details, I need to trace more code to find the root cause of why we can't support these cases now, but it is definitely not because of this PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, thank you for checking. When I suggested typing.Union[str, os.PathLike], I was thinking about FlyteFile. But FlyteFile has its own type transformer, so it's okay.

In any case, I'm okay with the current scope of this PR.


transformer = TypeEngine.get_transformer(DatumDataclass)
ctx = FlyteContext.current_context()

lt = TypeEngine.to_literal_type(Datum)
datum = Datum(5, Color.RED)
lv = transformer.to_literal(ctx, datum, Datum, lt)
lt = TypeEngine.to_literal_type(DatumDataclass)
datum_dataclass = DatumDataclass(5, Color.RED)
lv = transformer.to_literal(ctx, datum_dataclass, DatumDataclass, lt)
gt = transformer.guess_python_type(lt)
pv = transformer.to_python_value(ctx, lv, expected_python_type=gt)
assert datum_dataclass.x == pv.x
assert datum_dataclass.y.value == pv.y

lt = TypeEngine.to_literal_type(DatumDataclassJson)
datum = DatumDataclassJson(5, Color.RED)
lv = transformer.to_literal(ctx, datum, DatumDataclassJson, lt)
gt = transformer.guess_python_type(lt)
pv = transformer.to_python_value(ctx, lv, expected_python_type=gt)
assert datum.x == pv.x
Expand All @@ -2533,6 +2544,44 @@ class Datum(DataClassJSONMixin):
assert datum_mashumaro_orjson.z.isoformat() == pv.z


def test_dataclass_encoder_and_decoder_registry():
iterations = 10

@dataclass
class Datum:
x: int
y: str
z: dict[int, int]
w: List[int]

@task
def create_dataclasses() -> List[Datum]:
return [Datum(x=1, y="1", z={1: 1}, w=[1, 1, 1, 1])]

@task
def concat_dataclasses(x: List[Datum], y: List[Datum]) -> List[Datum]:
return x + y

@dynamic
def dynamic_wf() -> List[Datum]:
all_dataclasses: List[Datum] = []
for _ in range(iterations):
data = create_dataclasses()
all_dataclasses = concat_dataclasses(x=all_dataclasses, y=data)
return all_dataclasses

@workflow
def wf() -> List[Datum]:
return dynamic_wf()

datum_list = wf()
assert len(datum_list) == iterations

transformer = TypeEngine.get_transformer(Datum)
assert transformer._encoder.get(Datum)
assert transformer._decoder.get(Datum)


def test_ListTransformer_get_sub_type():
assert ListTransformer.get_sub_type_or_none(typing.List[str]) is str

Expand Down
12 changes: 0 additions & 12 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,18 +1102,6 @@ def wf():
wf()


def test_wf_custom_types_missing_dataclass_json():
with pytest.raises(AssertionError):

@dataclass
class MyCustomType(object):
pass

@task
def t1(a: int) -> MyCustomType:
return MyCustomType()


def test_wf_custom_types():
@dataclass
class MyCustomType(DataClassJsonMixin):
Expand Down
Loading