Skip to content

Commit

Permalink
Improve how union fields are handled (#93)
Browse files Browse the repository at this point in the history
* Improve how union fields are handled

See #86
See #67

* Fix the tests
  • Loading branch information
lovasoa authored Jul 28, 2020
1 parent 2b5d5e5 commit a313210
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 14 deletions.
23 changes: 15 additions & 8 deletions marshmallow_dataclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class User:
})
Schema: ClassVar[Type[Schema]] = Schema # For the type checker
"""
import dataclasses
import inspect
from enum import EnumMeta
from functools import lru_cache
Expand All @@ -54,6 +53,7 @@ class User:
overload,
)

import dataclasses
import marshmallow
import typing_inspect

Expand Down Expand Up @@ -450,13 +450,20 @@ def field_for_schema(
metadata["required"] = False
return field_for_schema(subtyp, metadata=metadata, base_schema=base_schema)
elif typing_inspect.is_union_type(typ):
subfields = [
field_for_schema(subtyp, metadata=metadata, base_schema=base_schema)
for subtyp in arguments
]
import marshmallow_union

return marshmallow_union.Union(subfields, **metadata)
from . import union_field

return union_field.Union(
[
(
subtyp,
field_for_schema(
subtyp, metadata=metadata, base_schema=base_schema
),
)
for subtyp in arguments
],
**metadata,
)

# typing.NewType returns a function with a __supertype__ attribute
newtype_supertype = getattr(typ, "__supertype__", None)
Expand Down
59 changes: 59 additions & 0 deletions marshmallow_dataclass/union_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import copy
from typing import List, Tuple, Any

import typeguard
from marshmallow import fields, Schema, ValidationError


class Union(fields.Field):
"""A union field, composed other `Field` classes or instances.
This field serializes elements based on their type, with one of its child fields.
Example: ::
number_or_string = UnionField([
(float, fields.Float()),
(str, fields.Str())
])
:param union_fields: A list of types and their associated field instance.
:param kwargs: The same keyword arguments that :class:`Field` receives.
"""

def __init__(self, union_fields: List[Tuple[type, fields.Field]], **kwargs):
super().__init__(**kwargs)
self.union_fields = union_fields

def _bind_to_schema(self, field_name: str, schema: Schema) -> None:
super()._bind_to_schema(field_name, schema)
new_union_fields = []
for typ, field in self.union_fields:
field = copy.deepcopy(field)
field._bind_to_schema(field_name, self)
new_union_fields.append((typ, field))

self.union_fields = new_union_fields

def _serialize(self, value: Any, attr: str, obj, **kwargs) -> Any:
errors = []
for typ, field in self.union_fields:
try:
typeguard.check_type(attr, value, typ)
return field._serialize(value, attr, obj, **kwargs)
except TypeError as e:
errors.append(e)
raise TypeError(
f"Unable to serialize value with any of the fields in the union: {errors}"
)

def _deserialize(self, value: Any, attr: str, data, **kwargs) -> Any:
errors = []
for typ, field in self.union_fields:
try:
result = field.deserialize(value, **kwargs)
typeguard.check_type(attr, result, typ)
return result
except (TypeError, ValidationError) as e:
errors.append(e)

raise ValidationError(errors)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

EXTRAS_REQUIRE = {
"enum": ["marshmallow-enum"],
"union": ["marshmallow-union"],
"union": ["typeguard"],
':python_version == "3.6"': ["dataclasses"],
"lint": ["pre-commit~=1.18"],
"docs": ["sphinx"],
Expand Down
12 changes: 7 additions & 5 deletions tests/test_field_for_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from marshmallow import fields, Schema

from marshmallow_dataclass import field_for_schema, dataclass
from marshmallow_dataclass import field_for_schema, dataclass, union_field


class TestFieldForSchema(unittest.TestCase):
Expand Down Expand Up @@ -89,12 +89,14 @@ class Color(Enum):
)

def test_union(self):
import marshmallow_union

self.assertFieldsEqual(
field_for_schema(Union[int, str]),
marshmallow_union.Union(
fields=[fields.Integer(), fields.String()], required=True
union_field.Union(
[
(int, fields.Integer(required=True)),
(str, fields.String(required=True)),
],
required=True,
),
)

Expand Down
135 changes: 135 additions & 0 deletions tests/test_union.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import unittest
from typing import List, Union, Dict

import marshmallow

from marshmallow_dataclass import dataclass


class TestClassSchema(unittest.TestCase):
def test_simple_union(self):
@dataclass
class IntOrStr:
value: Union[int, str]

schema = IntOrStr.Schema()
data_in = {"value": "hello"}
loaded = schema.load(data_in)
self.assertEqual(loaded, IntOrStr(value="hello"))
self.assertEqual(schema.dump(loaded), data_in)

data_in = {"value": 42}
self.assertEqual(schema.dump(schema.load(data_in)), data_in)

def test_list_union_builtin(self):
@dataclass
class Dclass2:
value: List[Union[int, str]]

schema = Dclass2.Schema()
data_in = {"value": ["hello", 42]}
self.assertEqual(schema.dump(schema.load(data_in)), data_in)

def test_list_union_object(self):
@dataclass
class Elm1:
elm1: str

@dataclass
class Elm2:
elm2: str

@dataclass
class Dclass:
value: List[Union[Elm1, Elm2]]

schema = Dclass.Schema()
data_in = {"value": [{"elm1": "foo"}, {"elm2": "bar"}]}
load = schema.load(data_in)
self.assertEqual(load, Dclass(value=[Elm1(elm1="foo"), Elm2(elm2="bar")]))
self.assertEqual(schema.dump(load), data_in)

def test_union_list(self):
@dataclass
class Elm1:
elm1: int

@dataclass
class Elm2:
elm2: int

@dataclass
class TestDataClass:
value: Union[List[Elm1], List[Elm2]]

schema = TestDataClass.Schema()

data_in = {"value": [{"elm1": 10}, {"elm1": 11}]}
load = schema.load(data_in)
self.assertEqual(load, TestDataClass(value=[Elm1(elm1=10), Elm1(elm1=11)]))
self.assertEqual(schema.dump(load), data_in)

data_in = {"value": [{"elm2": 10}, {"elm2": 11}]}
load = schema.load(data_in)
self.assertEqual(load, TestDataClass(value=[Elm2(elm2=10), Elm2(elm2=11)]))
self.assertEqual(schema.dump(load), data_in)

dictwrong_in = {"value": [{"elm1": 10}, {"elm2": 11}]}
with self.assertRaises(marshmallow.exceptions.ValidationError):
schema.load(dictwrong_in)

def test_many_nested_union(self):
@dataclass
class Elm1:
elm1: str

@dataclass
class Dclass:
value: List[Union[List[Union[int, str, Elm1]], int]]

schema = Dclass.Schema()
data_in = {"value": [42, ["hello", 13, {"elm1": "foo"}]]}

self.assertEqual(schema.dump(schema.load(data_in)), data_in)
with self.assertRaises(marshmallow.exceptions.ValidationError):
schema.load({"value": [42, ["hello", 13, {"elm2": "foo"}]]})

def test_union_dict(self):
@dataclass
class Dclass:
value: List[Union[Dict[int, Union[int, str]], Union[int, str]]]

schema = Dclass.Schema()
data_in = {"value": [42, {12: 13, 13: "hello"}, "foo"]}

self.assertEqual(schema.dump(schema.load(data_in)), data_in)

with self.assertRaises(marshmallow.exceptions.ValidationError):
schema.load({"value": [(42,), {12: 13, 13: "hello"}, "foo"]})

def test_union_list_dict(self):
@dataclass
class Elm:
elm: int

@dataclass
class Dclass:
value: Union[List[int], Dict[str, Elm]]

schema = Dclass.Schema()

data_in = {"value": {"a": {"elm": 10}, "b": {"elm": 10}}}
self.assertEqual(schema.dump(schema.load(data_in)), data_in)

data_in = {"value": [1, 2, 3, 4]}
self.assertEqual(schema.dump(schema.load(data_in)), data_in)

def test_union_noschema(self):
@dataclass
class Dclass:
value: Union[int, str]

schema = Dclass.Schema()
data_in = {"value": [1.4, 4.2]}
with self.assertRaises(marshmallow.exceptions.ValidationError):
self.assertEqual(schema.dump(schema.load(data_in)), data_in)

0 comments on commit a313210

Please sign in to comment.