-
Notifications
You must be signed in to change notification settings - Fork 300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Stop requiring users to import dataclasses_json
or DataClassJSONMixin
for dataclass
#2279
Changes from 26 commits
2aedeed
58773b3
aad91c3
8a7b06d
c0b5290
fdbf58d
9ee2eb4
ce8edf5
a714b83
ce067a3
dce34ab
399f975
dde160b
610e739
1819532
8dff896
da3768d
2757f34
102085b
30223e4
9b1e617
38986f0
5de561e
bc3f58b
d6d4965
7f759ac
100c838
c1a404d
0555f6e
0a33f53
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ | |
from google.protobuf.message import Message | ||
from google.protobuf.struct_pb2 import Struct | ||
from marshmallow_enum import EnumField, LoadDumpOptions | ||
from mashumaro.codecs.json import JSONDecoder, JSONEncoder | ||
from mashumaro.mixins.json import DataClassJSONMixin | ||
from typing_extensions import Annotated, get_args, get_origin | ||
|
||
|
@@ -326,13 +327,8 @@ class Test(DataClassJsonMixin): | |
|
||
def __init__(self): | ||
super().__init__("Object-Dataclass-Transformer", object) | ||
self._serializable_classes = [DataClassJSONMixin, DataClassJsonMixin] | ||
try: | ||
from mashumaro.mixins.orjson import DataClassORJSONMixin | ||
|
||
self._serializable_classes.append(DataClassORJSONMixin) | ||
except ModuleNotFoundError: | ||
pass | ||
self._encoder: Dict[Type, JSONEncoder] = {} | ||
self._decoder: Dict[Type, JSONDecoder] = {} | ||
|
||
def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T): | ||
# Skip iterating all attributes in the dataclass if the type of v already matches the expected_type | ||
|
@@ -425,11 +421,6 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: | |
f"Type {t} cannot be parsed." | ||
) | ||
|
||
if not self.is_serializable_class(t): | ||
raise AssertionError( | ||
f"Dataclass {t} should be decorated with @dataclass_json or mixin with DataClassJSONMixin to be " | ||
f"serialized correctly" | ||
) | ||
schema = None | ||
try: | ||
if issubclass(t, DataClassJsonMixin): | ||
|
@@ -473,9 +464,6 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: | |
|
||
return _type_models.LiteralType(simple=_type_models.SimpleType.STRUCT, metadata=schema, structure=ts) | ||
|
||
def is_serializable_class(self, class_: Type[T]) -> bool: | ||
return any(issubclass(class_, serializable_class) for serializable_class in self._serializable_classes) | ||
|
||
def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: | ||
if isinstance(python_val, dict): | ||
json_str = json.dumps(python_val) | ||
|
@@ -486,14 +474,15 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp | |
f"{type(python_val)} is not of type @dataclass, only Dataclasses are supported for " | ||
f"user defined datatypes in Flytekit" | ||
) | ||
if not self.is_serializable_class(type(python_val)): | ||
raise TypeTransformerFailedError( | ||
f"Dataclass {python_type} should be decorated with @dataclass_json or inherit DataClassJSONMixin to be " | ||
f"serialized correctly" | ||
) | ||
|
||
self._serialize_flyte_type(python_val, python_type) | ||
|
||
json_str = python_val.to_json() # type: ignore | ||
if hasattr(python_val, "to_json"): | ||
json_str = python_val.to_json() | ||
else: | ||
if not self._encoder.get(python_type): | ||
self._encoder[python_type] = JSONEncoder(python_type) | ||
json_str = self._encoder[python_type].encode(python_val) | ||
pingsutw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) # type: ignore | ||
|
||
|
@@ -720,7 +709,7 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any: | |
|
||
return val | ||
|
||
def _fix_dataclass_int(self, dc_type: Type[DataClassJsonMixin], dc: DataClassJsonMixin) -> DataClassJsonMixin: | ||
def _fix_dataclass_int(self, dc_type: Type, dc: dataclasses.dataclass) -> dataclasses.dataclass: # type: ignore | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure is here has a better way to write type annotation. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can Given how dynamic the code is, I think we have to go with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're right! |
||
""" | ||
This is a performance penalty to convert to the right types, but this is expected by the user and hence | ||
needs to be done | ||
|
@@ -729,8 +718,9 @@ def _fix_dataclass_int(self, dc_type: Type[DataClassJsonMixin], dc: DataClassJso | |
# https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#google.protobuf.Value | ||
# Thus we will have to walk the given dataclass and typecast values to int, where expected. | ||
for f in dataclasses.fields(dc_type): | ||
val = dc.__getattribute__(f.name) | ||
dc.__setattr__(f.name, self._fix_val_int(f.type, val)) | ||
val = getattr(dc, f.name) | ||
setattr(dc, f.name, self._fix_val_int(f.type, val)) | ||
pingsutw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return dc | ||
|
||
def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: | ||
|
@@ -739,13 +729,15 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: | |
f"{expected_python_type} is not of type @dataclass, only Dataclasses are supported for " | ||
"user defined datatypes in Flytekit" | ||
) | ||
if not self.is_serializable_class(expected_python_type): | ||
raise TypeTransformerFailedError( | ||
f"Dataclass {expected_python_type} should be decorated with @dataclass_json or mixin with DataClassJSONMixin to be " | ||
f"serialized correctly" | ||
) | ||
|
||
json_str = _json_format.MessageToJson(lv.scalar.generic) | ||
dc = expected_python_type.from_json(json_str) # type: ignore | ||
|
||
if hasattr(expected_python_type, "from_json"): | ||
dc = expected_python_type.from_json(json_str) # type: ignore | ||
else: | ||
if not self._decoder.get(expected_python_type): | ||
self._decoder[expected_python_type] = JSONDecoder(expected_python_type) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here regular |
||
dc = self._decoder[expected_python_type].decode(json_str) | ||
|
||
dc = self._fix_structured_dataset_type(expected_python_type, dc) | ||
return self._fix_dataclass_int(expected_python_type, self._deserialize_flyte_type(dc, expected_python_type)) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,7 @@ | |
from dataclasses import asdict, dataclass, field | ||
from datetime import timedelta | ||
from enum import Enum, auto | ||
from typing import Optional, Type | ||
from typing import List, Optional, Type | ||
|
||
import mock | ||
import pyarrow as pa | ||
|
@@ -25,13 +25,11 @@ | |
from mashumaro.mixins.orjson import DataClassORJSONMixin | ||
from typing_extensions import Annotated, get_args, get_origin | ||
|
||
from flytekit import kwtypes | ||
from flytekit import dynamic, kwtypes, task, workflow | ||
from flytekit.core.annotation import FlyteAnnotation | ||
from flytekit.core.context_manager import FlyteContext, FlyteContextManager | ||
from flytekit.core.data_persistence import flyte_tmp_dir | ||
from flytekit.core.dynamic_workflow_task import dynamic | ||
from flytekit.core.hash import HashMethod | ||
from flytekit.core.task import task | ||
from flytekit.core.type_engine import ( | ||
DataclassTransformer, | ||
DictTransformer, | ||
|
@@ -2499,16 +2497,29 @@ class DatumMashumaro(DataClassJSONMixin): | |
|
||
@dataclass_json | ||
@dataclass | ||
class Datum(DataClassJSONMixin): | ||
class DatumDataclassJson(DataClassJSONMixin): | ||
x: int | ||
y: Color | ||
|
||
transformer = DataclassTransformer() | ||
@dataclass | ||
class DatumDataclass: | ||
x: int | ||
y: Color | ||
Comment on lines
+2522
to
+2524
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add another test to see what happens with @dataclass
class DatumDataUnion:
path: typing.Union[str, os.PathLike] If the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I found that @dataclass
class DatumDataUnion(DataClassJSONMixin):
path: typing.Union[str, os.PathLike]
lt = TypeEngine.to_literal_type(DatumDataUnion)
datum_dataunion = DatumDataUnion(Path("/tmp"))
lv = transformer.to_literal(ctx, datum_dataunion, DatumDataUnion, lt)
gt = transformer.guess_python_type(lt)
pv = transformer.to_python_value(ctx, lv, expected_python_type=gt)
assert datum_dataunion.path == pv.path There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I found that TypeTransformer can't handle support for typing.Union[str, os.PathLike] or typing.Union[str, FlyteFile]. Support There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, thank you for checking. When I suggested In any case, I'm okay with the current scope of this PR. |
||
|
||
transformer = TypeEngine.get_transformer(DatumDataclass) | ||
ctx = FlyteContext.current_context() | ||
|
||
lt = TypeEngine.to_literal_type(Datum) | ||
datum = Datum(5, Color.RED) | ||
lv = transformer.to_literal(ctx, datum, Datum, lt) | ||
lt = TypeEngine.to_literal_type(DatumDataclass) | ||
datum_dataclass = DatumDataclass(5, Color.RED) | ||
lv = transformer.to_literal(ctx, datum_dataclass, DatumDataclass, lt) | ||
gt = transformer.guess_python_type(lt) | ||
pv = transformer.to_python_value(ctx, lv, expected_python_type=gt) | ||
assert datum_dataclass.x == pv.x | ||
assert datum_dataclass.y.value == pv.y | ||
|
||
lt = TypeEngine.to_literal_type(DatumDataclassJson) | ||
datum = DatumDataclassJson(5, Color.RED) | ||
lv = transformer.to_literal(ctx, datum, DatumDataclassJson, lt) | ||
gt = transformer.guess_python_type(lt) | ||
pv = transformer.to_python_value(ctx, lv, expected_python_type=gt) | ||
assert datum.x == pv.x | ||
|
@@ -2533,6 +2544,44 @@ class Datum(DataClassJSONMixin): | |
assert datum_mashumaro_orjson.z.isoformat() == pv.z | ||
|
||
|
||
def test_dataclass_encoder_and_decoder_registry(): | ||
iterations = 10 | ||
|
||
@dataclass | ||
class Datum: | ||
x: int | ||
y: str | ||
z: dict[int, int] | ||
w: List[int] | ||
|
||
@task | ||
def create_dataclasses() -> List[Datum]: | ||
return [Datum(x=1, y="1", z={1: 1}, w=[1, 1, 1, 1])] | ||
|
||
@task | ||
def concat_dataclasses(x: List[Datum], y: List[Datum]) -> List[Datum]: | ||
return x + y | ||
|
||
@dynamic | ||
def dynamic_wf() -> List[Datum]: | ||
all_dataclasses: List[Datum] = [] | ||
for _ in range(iterations): | ||
data = create_dataclasses() | ||
all_dataclasses = concat_dataclasses(x=all_dataclasses, y=data) | ||
return all_dataclasses | ||
|
||
@workflow | ||
def wf() -> List[Datum]: | ||
return dynamic_wf() | ||
|
||
datum_list = wf() | ||
assert len(datum_list) == iterations | ||
|
||
transformer = TypeEngine.get_transformer(Datum) | ||
assert transformer._encoder.get(Datum) | ||
assert transformer._decoder.get(Datum) | ||
|
||
|
||
def test_ListTransformer_get_sub_type(): | ||
assert ListTransformer.get_sub_type_or_none(typing.List[str]) is str | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wild-endeavor: With this check, we are backward compatible. If one used `DataClassJsonMixin` or `@dataclass_json`, then `to_json` is defined and called, which matches master's behavior. XREF: https://github.com/lidatong/dataclasses-json/blob/8512afc0a87053dbde52af0519c74198fa3bb873/dataclasses_json/api.py#L26

@Future-Outlier: Can you include a comment here about how this preserves backward compatibility?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, @wild-endeavor and @thomasjpfan,

In flytekit-python, we use the function `to_json` to serialize dataclasses and `from_json` to deserialize them. Previously, we required users of older flytekit releases to add the `@dataclass_json` decorator or to inherit from the `DataClassJsonMixin` class. This was because both provided the necessary methods to serialize (convert a dataclass to bytes) and deserialize (convert bytes back to a dataclass).
In this PR, the `mashumaro` module introduces two classes, `JSONEncoder` and `JSONDecoder`, for serializing and deserializing dataclasses. These new classes eliminate the reliance on the `to_json` and `from_json` methods. Initially, we used `if hasattr(python_val, "to_json"):` to check for the method's presence. Therefore, introducing these changes will not cause any breaking changes. This means that dataclasses inheriting from `DataClassJsonMixin` will continue to use `to_json` and `from_json` for serialization and deserialization when using this version of flytekit.

to_json REF: https://github.com/flyteorg/flytekit/blob/master/flytekit/core/type_engine.py#L498
from_json REF: https://github.com/flyteorg/flytekit/blob/master/flytekit/core/type_engine.py#L750
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
To clarify, I wanted to have a short comment in the code that states how the "to_json" check helps preserve backward compatibility.