Skip to content

Commit

Permalink
use polyfield for union
Browse files Browse the repository at this point in the history
marshmallow-union is not supported anymore and have some known issues
see:
lovasoa#67

Author advise to switch to marshmallow-polyfield
  • Loading branch information
CedricCabessa committed Apr 27, 2020
1 parent bfda341 commit 5d20f20
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 13 deletions.
63 changes: 57 additions & 6 deletions marshmallow_dataclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,13 +404,47 @@ def field_for_schema(
metadata["required"] = False
return field_for_schema(subtyp, metadata=metadata, base_schema=base_schema)
elif typing_inspect.is_union_type(typ):
subfields = [
field_for_schema(subtyp, metadata=metadata, base_schema=base_schema)
for subtyp in arguments
]
import marshmallow_union

return marshmallow_union.Union(subfields, **metadata)
def deserialization_disambiguation(obj_dict, base_dict):
for argument in arguments:
field = field_for_schema(argument) # todo: precise options
try:
field.deserialize(obj_dict, base_dict)
# deserialization support field, don't bother creating a
# schema
return field
except marshmallow.exceptions.ValidationError:
pass
else:
raise marshmallow.exceptions.ValidationError(
"cannot deserialize")

def serialization_disambiguation(obj, base_obj):
for subtype in arguments:
@dataclass
class dclass:
field: subtype

try:
schema = dclass.Schema()
dump = schema.dump(dclass(obj))
load = schema.load(dump)
if type(load.field) != type(obj):
continue
return SchemaPolyfieldProxy(dclass)
except Exception:
pass
else:
raise marshmallow.exceptions.ValidationError(
"cannot serialize")

import marshmallow_polyfield

return marshmallow_polyfield.PolyField(
deserialization_schema_selector=deserialization_disambiguation,
serialization_schema_selector=serialization_disambiguation,
**metadata,
)

# typing.NewType returns a function with a __supertype__ attribute
newtype_supertype = getattr(typ, "__supertype__", None)
Expand Down Expand Up @@ -533,6 +567,23 @@ def new_type(x: _U):
return new_type


class SchemaPolyfieldProxy:
""" Proxy class that implement Schema interface to proxify call to a
dataclass. It is used in order to disambiguate Union subtype.
By convention, we assume the dataclass has one field called "field"
"""
def __init__(self, dataclass):
self.schema = dataclass.Schema()
self.dataclass = dataclass

@property
def context(self):
return self.schema.context

def dump(self, value):
return self.schema.dump(self.dataclass(value))["field"]


if __name__ == "__main__":
import doctest

Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import setup, find_packages

VERSION = "7.5.1"
VERSION = "7.5.2-dev1"

CLASSIFIERS = [
"Development Status :: 4 - Beta",
Expand All @@ -14,7 +14,7 @@

EXTRAS_REQUIRE = {
"enum": ["marshmallow-enum"],
"union": ["marshmallow-union"],
"union": ["marshmallow-polyfield"],
':python_version == "3.6"': ["dataclasses"],
"lint": ["pre-commit~=1.18"],
"docs": ["sphinx"],
Expand Down
7 changes: 2 additions & 5 deletions tests/test_field_for_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def attrs(x):
for k, v in x.__dict__.items()
if not k.startswith("_")
}

self.assertEqual(attrs(a), attrs(b))

def test_int(self):
Expand Down Expand Up @@ -87,13 +86,11 @@ class Color(Enum):
)

def test_union(self):
import marshmallow_union
import marshmallow_polyfield

self.assertFieldsEqual(
field_for_schema(Union[int, str]),
marshmallow_union.Union(
fields=[fields.Integer(), fields.String()], required=True
),
marshmallow_polyfield.PolyField(required=True),
)

def test_newtype(self):
Expand Down
137 changes: 137 additions & 0 deletions tests/test_union.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import unittest
import marshmallow
from marshmallow_dataclass import dataclass
from typing import List, Union, Dict


class TestClassSchema(unittest.TestCase):
def test_simple_union(self):
@dataclass
class Dclass:
value: Union[int, str]

schema = Dclass.Schema()
data_in = {"value": "foo"}
self.assertEqual(schema.dump(schema.load(data_in)), data_in)

data_in = {"value": 42}
self.assertEqual(schema.dump(schema.load(data_in)), data_in)

def test_list_union_builtin(self):
@dataclass
class Dclass:
value: List[Union[int, str]]

schema = Dclass.Schema()
data_in = {"value": ["hello", 42]}
self.assertEqual(schema.dump(schema.load(data_in)), data_in)

def test_list_union_object(self):
@dataclass
class Elm1:
elm1: str

@dataclass
class Elm2:
elm2: str

@dataclass
class Dclass:
value: List[Union[Elm1, Elm2]]

schema = Dclass.Schema()
data_in = {"value": [{"elm1": "foo"}, {"elm2": "bar"}]}
load = schema.load(data_in)
self.assertIsInstance(load, Dclass)
self.assertIsInstance(load.value[0], Elm1)
self.assertIsInstance(load.value[1], Elm2)
self.assertEqual(schema.dump(load), data_in)

def test_union_list(self):
@dataclass
class Elm1:
elm1: int

@dataclass
class Elm2:
elm2: int

@dataclass
class TestDataClass:
value: Union[List[Elm1], List[Elm2]]

schema = TestDataClass.Schema()

data_in = {"value": [{"elm1": 10}, {"elm1": 11}]}
load = schema.load(data_in)
self.assertIsInstance(load.value[0], Elm1)
self.assertEqual(schema.dump(load), data_in)

data_in = {"value": [{"elm2": 10}, {"elm2": 11}]}
load = schema.load(data_in)
self.assertIsInstance(load.value[0], Elm2)
self.assertEqual(schema.dump(load), data_in)

dictwrong_in = {"value": [{"elm1": 10}, {"elm2": 11}]}
with self.assertRaises(marshmallow.exceptions.ValidationError):
schema.load(dictwrong_in)

def test_many_nested_union(self):
@dataclass
class Elm1:
elm1: str

@dataclass
class Dclass:
value: List[Union[List[Union[int, str, Elm1]], int]]

schema = Dclass.Schema()
data_in = {"value": [42, ["hello", 13, {"elm1": "foo"}]]}

self.assertEqual(schema.dump(schema.load(data_in)), data_in)
with self.assertRaises(marshmallow.exceptions.ValidationError):
schema.load({"value": [42, ["hello", 13, {"elm2": "foo"}]]})

def test_union_dict(self):
@dataclass
class Dclass:
value: List[Union[Dict[int, Union[int, str]], Union[int, str]]]

schema = Dclass.Schema()
data_in = {"value": [42, {12: 13, 13: "hello"}, "foo"]}

self.assertEqual(schema.dump(schema.load(data_in)), data_in)

with self.assertRaises(marshmallow.exceptions.ValidationError):
schema.load({"value": [(42,), {12: 13, 13: "hello"}, "foo"]})

def test_union_list_dict(self):
@dataclass
class Elm:
elm: int

@dataclass
class Dclass:
value: Union[List[int], Dict[str, Elm]]

schema = Dclass.Schema()

data_in = {
"value": {"a": {"elm": 10}, "b": {"elm": 10}},
}
self.assertEqual(schema.dump(schema.load(data_in)), data_in)

data_in = {
"value": [1, 2, 3, 4]
}
self.assertEqual(schema.dump(schema.load(data_in)), data_in)

def test_union_noschema(self):
@dataclass
class Dclass:
value: Union[int, str]

schema = Dclass.Schema()
data_in = {"value": [1.4, 4.2]}
with self.assertRaises(marshmallow.exceptions.ValidationError):
self.assertEqual(schema.dump(schema.load(data_in)), data_in)

0 comments on commit 5d20f20

Please sign in to comment.