Skip to content
This repository has been archived by the owner on Oct 10, 2024. It is now read-only.

Commit

Permalink
use polyfield for union
Browse files Browse the repository at this point in the history
marshmallow-union is not supported anymore and have some known issues
see:
lovasoa#67

Author advise to switch to marshmallow-polyfield
  • Loading branch information
CedricCabessa committed May 4, 2020
1 parent b974a43 commit 5133691
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 11 deletions.
78 changes: 72 additions & 6 deletions marshmallow_dataclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,13 +405,46 @@ def field_for_schema(
metadata["required"] = False
return field_for_schema(subtyp, metadata=metadata, base_schema=base_schema)
elif typing_inspect.is_union_type(typ):
subfields = [
field_for_schema(subtyp, metadata=metadata, base_schema=base_schema)
for subtyp in arguments
]
import marshmallow_union

return marshmallow_union.Union(subfields, **metadata)
def deserialization_disambiguation(obj_dict, base_dict):
for subtype in arguments:
@dataclass
class dclass:
field: subtype

try:
candidate = SchemaPolyfieldProxy(dclass)
candidate.check_deserialization(obj_dict)
return candidate
except Exception:
pass
else:
raise marshmallow.exceptions.ValidationError(
"cannot deserialize")

def serialization_disambiguation(obj, base_obj):
for subtype in arguments:
@dataclass
class dclass:
field: subtype

try:
candidate = SchemaPolyfieldProxy(dclass)
candidate.check_serialization(obj)
return candidate
except Exception:
pass
else:
raise marshmallow.exceptions.ValidationError(
"cannot serialize")

import marshmallow_polyfield

return marshmallow_polyfield.PolyField(
deserialization_schema_selector=deserialization_disambiguation,
serialization_schema_selector=serialization_disambiguation,
**metadata,
)

# typing.NewType returns a function with a __supertype__ attribute
newtype_supertype = getattr(typ, "__supertype__", None)
Expand Down Expand Up @@ -534,6 +567,39 @@ def new_type(x: _U):
return new_type


# extends Schema for instance check by polyfield
class SchemaPolyfieldProxy(marshmallow.Schema):
""" Proxy class that implement Schema interface to proxify call to a
dataclass. It is used in order to disambiguate Union subtype.
By convention, we assume the dataclass has one field called "field"
"""
def __init__(self, dataclass):
self.schema = dataclass.Schema()
self.dataclass = dataclass

@property
def context(self):
return self.schema.context

def dump(self, value):
return self.schema.dump(self.dataclass(value))["field"]

def load(self, value):
return self.schema.load({"field": value}).field

def check_deserialization(self, obj_dict):
load = self.schema.load({"field": obj_dict})
dump = self.schema.dump(load)["field"]
if type(dump) != type(obj_dict):
raise TypeError("types do not match")

def check_serialization(self, obj):
dump = self.schema.dump(self.dataclass(obj))
load = self.schema.load(dump)
if type(load.field) != type(obj):
raise TypeError("types do not match")


if __name__ == "__main__":
import doctest

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

EXTRAS_REQUIRE = {
"enum": ["marshmallow-enum"],
"union": ["marshmallow-union"],
"union": ["marshmallow-polyfield"],
':python_version == "3.6"': ["dataclasses"],
"lint": ["pre-commit~=1.18"],
"docs": ["sphinx"],
Expand Down
6 changes: 2 additions & 4 deletions tests/test_field_for_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,11 @@ class Color(Enum):
)

def test_union(self):
import marshmallow_union
import marshmallow_polyfield

self.assertFieldsEqual(
field_for_schema(Union[int, str]),
marshmallow_union.Union(
fields=[fields.Integer(), fields.String()], required=True
),
marshmallow_polyfield.PolyField(required=True),
)

def test_newtype(self):
Expand Down
137 changes: 137 additions & 0 deletions tests/test_union.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import unittest
import marshmallow
from marshmallow_dataclass import dataclass
from typing import List, Union, Dict


class TestClassSchema(unittest.TestCase):
def test_simple_union(self):
@dataclass
class Dclass:
value: Union[int, str]

schema = Dclass.Schema()
data_in = {"value": "42"}
self.assertEqual(schema.dump(schema.load(data_in)), data_in)

data_in = {"value": 42}
self.assertEqual(schema.dump(schema.load(data_in)), data_in)

def test_list_union_builtin(self):
@dataclass
class Dclass:
value: List[Union[int, str]]

schema = Dclass.Schema()
data_in = {"value": ["hello", 42]}
self.assertEqual(schema.dump(schema.load(data_in)), data_in)

def test_list_union_object(self):
@dataclass
class Elm1:
elm1: str

@dataclass
class Elm2:
elm2: str

@dataclass
class Dclass:
value: List[Union[Elm1, Elm2]]

schema = Dclass.Schema()
data_in = {"value": [{"elm1": "foo"}, {"elm2": "bar"}]}
load = schema.load(data_in)
self.assertIsInstance(load, Dclass)
self.assertIsInstance(load.value[0], Elm1)
self.assertIsInstance(load.value[1], Elm2)
self.assertEqual(schema.dump(load), data_in)

def test_union_list(self):
@dataclass
class Elm1:
elm1: int

@dataclass
class Elm2:
elm2: int

@dataclass
class TestDataClass:
value: Union[List[Elm1], List[Elm2]]

schema = TestDataClass.Schema()

data_in = {"value": [{"elm1": 10}, {"elm1": 11}]}
load = schema.load(data_in)
self.assertIsInstance(load.value[0], Elm1)
self.assertEqual(schema.dump(load), data_in)

data_in = {"value": [{"elm2": 10}, {"elm2": 11}]}
load = schema.load(data_in)
self.assertIsInstance(load.value[0], Elm2)
self.assertEqual(schema.dump(load), data_in)

dictwrong_in = {"value": [{"elm1": 10}, {"elm2": 11}]}
with self.assertRaises(marshmallow.exceptions.ValidationError):
schema.load(dictwrong_in)

def test_many_nested_union(self):
@dataclass
class Elm1:
elm1: str

@dataclass
class Dclass:
value: List[Union[List[Union[int, str, Elm1]], int]]

schema = Dclass.Schema()
data_in = {"value": [42, ["hello", 13, {"elm1": "foo"}]]}

self.assertEqual(schema.dump(schema.load(data_in)), data_in)
with self.assertRaises(marshmallow.exceptions.ValidationError):
schema.load({"value": [42, ["hello", 13, {"elm2": "foo"}]]})

def test_union_dict(self):
@dataclass
class Dclass:
value: List[Union[Dict[int, Union[int, str]], Union[int, str]]]

schema = Dclass.Schema()
data_in = {"value": [42, {12: 13, 13: "hello"}, "foo"]}

self.assertEqual(schema.dump(schema.load(data_in)), data_in)

with self.assertRaises(marshmallow.exceptions.ValidationError):
schema.load({"value": [(42,), {12: 13, 13: "hello"}, "foo"]})

def test_union_list_dict(self):
@dataclass
class Elm:
elm: int

@dataclass
class Dclass:
value: Union[List[int], Dict[str, Elm]]

schema = Dclass.Schema()

data_in = {
"value": {"a": {"elm": 10}, "b": {"elm": 10}},
}
self.assertEqual(schema.dump(schema.load(data_in)), data_in)

data_in = {
"value": [1, 2, 3, 4]
}
self.assertEqual(schema.dump(schema.load(data_in)), data_in)

def test_union_noschema(self):
@dataclass
class Dclass:
value: Union[int, str]

schema = Dclass.Schema()
data_in = {"value": [1.4, 4.2]}
with self.assertRaises(marshmallow.exceptions.ValidationError):
self.assertEqual(schema.dump(schema.load(data_in)), data_in)

0 comments on commit 5133691

Please sign in to comment.