From a313210d1d74fd2a5b1d5775cd9f8be170162d81 Mon Sep 17 00:00:00 2001 From: Ophir LOJKINE Date: Tue, 28 Jul 2020 21:57:02 +0200 Subject: [PATCH] Improve how union fields are handled (#93) * Improve how union fields are handled See #86 See #67 * Fix the tests --- marshmallow_dataclass/__init__.py | 23 +++-- marshmallow_dataclass/union_field.py | 59 ++++++++++++ setup.py | 2 +- tests/test_field_for_schema.py | 12 ++- tests/test_union.py | 135 +++++++++++++++++++++++++++ 5 files changed, 217 insertions(+), 14 deletions(-) create mode 100644 marshmallow_dataclass/union_field.py create mode 100644 tests/test_union.py diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 16b5b86..f3cded0 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -34,7 +34,6 @@ class User: }) Schema: ClassVar[Type[Schema]] = Schema # For the type checker """ -import dataclasses import inspect from enum import EnumMeta from functools import lru_cache @@ -54,6 +53,7 @@ class User: overload, ) +import dataclasses import marshmallow import typing_inspect @@ -450,13 +450,20 @@ def field_for_schema( metadata["required"] = False return field_for_schema(subtyp, metadata=metadata, base_schema=base_schema) elif typing_inspect.is_union_type(typ): - subfields = [ - field_for_schema(subtyp, metadata=metadata, base_schema=base_schema) - for subtyp in arguments - ] - import marshmallow_union - - return marshmallow_union.Union(subfields, **metadata) + from . 
import union_field + + return union_field.Union( + [ + ( + subtyp, + field_for_schema( + subtyp, metadata=metadata, base_schema=base_schema + ), + ) + for subtyp in arguments + ], + **metadata, + ) # typing.NewType returns a function with a __supertype__ attribute newtype_supertype = getattr(typ, "__supertype__", None) diff --git a/marshmallow_dataclass/union_field.py b/marshmallow_dataclass/union_field.py new file mode 100644 index 0000000..18ff131 --- /dev/null +++ b/marshmallow_dataclass/union_field.py @@ -0,0 +1,59 @@ +import copy +from typing import List, Tuple, Any + +import typeguard +from marshmallow import fields, Schema, ValidationError + + +class Union(fields.Field): + """A union field, composed of other `Field` classes or instances. + This field serializes elements based on their type, with one of its child fields. + + Example: :: + + number_or_string = Union([ + (float, fields.Float()), + (str, fields.Str()) + ]) + + :param union_fields: A list of types and their associated field instance. + :param kwargs: The same keyword arguments that :class:`Field` receives. 
+ """ + + def __init__(self, union_fields: List[Tuple[type, fields.Field]], **kwargs): + super().__init__(**kwargs) + self.union_fields = union_fields + + def _bind_to_schema(self, field_name: str, schema: Schema) -> None: + super()._bind_to_schema(field_name, schema) + new_union_fields = [] + for typ, field in self.union_fields: + field = copy.deepcopy(field) + field._bind_to_schema(field_name, self) + new_union_fields.append((typ, field)) + + self.union_fields = new_union_fields + + def _serialize(self, value: Any, attr: str, obj, **kwargs) -> Any: + errors = [] + for typ, field in self.union_fields: + try: + typeguard.check_type(attr, value, typ) + return field._serialize(value, attr, obj, **kwargs) + except TypeError as e: + errors.append(e) + raise TypeError( + f"Unable to serialize value with any of the fields in the union: {errors}" + ) + + def _deserialize(self, value: Any, attr: str, data, **kwargs) -> Any: + errors = [] + for typ, field in self.union_fields: + try: + result = field.deserialize(value, **kwargs) + typeguard.check_type(attr, result, typ) + return result + except (TypeError, ValidationError) as e: + errors.append(e) + + raise ValidationError(errors) diff --git a/setup.py b/setup.py index ee3bbe3..d305633 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ EXTRAS_REQUIRE = { "enum": ["marshmallow-enum"], - "union": ["marshmallow-union"], + "union": ["typeguard"], ':python_version == "3.6"': ["dataclasses"], "lint": ["pre-commit~=1.18"], "docs": ["sphinx"], diff --git a/tests/test_field_for_schema.py b/tests/test_field_for_schema.py index 699bd55..3d0af96 100644 --- a/tests/test_field_for_schema.py +++ b/tests/test_field_for_schema.py @@ -6,7 +6,7 @@ from marshmallow import fields, Schema -from marshmallow_dataclass import field_for_schema, dataclass +from marshmallow_dataclass import field_for_schema, dataclass, union_field class TestFieldForSchema(unittest.TestCase): @@ -89,12 +89,14 @@ class Color(Enum): ) def test_union(self): - import 
marshmallow_union - self.assertFieldsEqual( field_for_schema(Union[int, str]), - marshmallow_union.Union( - fields=[fields.Integer(), fields.String()], required=True + union_field.Union( + [ + (int, fields.Integer(required=True)), + (str, fields.String(required=True)), + ], + required=True, ), ) diff --git a/tests/test_union.py b/tests/test_union.py new file mode 100644 index 0000000..3363779 --- /dev/null +++ b/tests/test_union.py @@ -0,0 +1,135 @@ +import unittest +from typing import List, Union, Dict + +import marshmallow + +from marshmallow_dataclass import dataclass + + +class TestClassSchema(unittest.TestCase): + def test_simple_union(self): + @dataclass + class IntOrStr: + value: Union[int, str] + + schema = IntOrStr.Schema() + data_in = {"value": "hello"} + loaded = schema.load(data_in) + self.assertEqual(loaded, IntOrStr(value="hello")) + self.assertEqual(schema.dump(loaded), data_in) + + data_in = {"value": 42} + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + def test_list_union_builtin(self): + @dataclass + class Dclass2: + value: List[Union[int, str]] + + schema = Dclass2.Schema() + data_in = {"value": ["hello", 42]} + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + def test_list_union_object(self): + @dataclass + class Elm1: + elm1: str + + @dataclass + class Elm2: + elm2: str + + @dataclass + class Dclass: + value: List[Union[Elm1, Elm2]] + + schema = Dclass.Schema() + data_in = {"value": [{"elm1": "foo"}, {"elm2": "bar"}]} + load = schema.load(data_in) + self.assertEqual(load, Dclass(value=[Elm1(elm1="foo"), Elm2(elm2="bar")])) + self.assertEqual(schema.dump(load), data_in) + + def test_union_list(self): + @dataclass + class Elm1: + elm1: int + + @dataclass + class Elm2: + elm2: int + + @dataclass + class TestDataClass: + value: Union[List[Elm1], List[Elm2]] + + schema = TestDataClass.Schema() + + data_in = {"value": [{"elm1": 10}, {"elm1": 11}]} + load = schema.load(data_in) + self.assertEqual(load, 
TestDataClass(value=[Elm1(elm1=10), Elm1(elm1=11)])) + self.assertEqual(schema.dump(load), data_in) + + data_in = {"value": [{"elm2": 10}, {"elm2": 11}]} + load = schema.load(data_in) + self.assertEqual(load, TestDataClass(value=[Elm2(elm2=10), Elm2(elm2=11)])) + self.assertEqual(schema.dump(load), data_in) + + dictwrong_in = {"value": [{"elm1": 10}, {"elm2": 11}]} + with self.assertRaises(marshmallow.exceptions.ValidationError): + schema.load(dictwrong_in) + + def test_many_nested_union(self): + @dataclass + class Elm1: + elm1: str + + @dataclass + class Dclass: + value: List[Union[List[Union[int, str, Elm1]], int]] + + schema = Dclass.Schema() + data_in = {"value": [42, ["hello", 13, {"elm1": "foo"}]]} + + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + with self.assertRaises(marshmallow.exceptions.ValidationError): + schema.load({"value": [42, ["hello", 13, {"elm2": "foo"}]]}) + + def test_union_dict(self): + @dataclass + class Dclass: + value: List[Union[Dict[int, Union[int, str]], Union[int, str]]] + + schema = Dclass.Schema() + data_in = {"value": [42, {12: 13, 13: "hello"}, "foo"]} + + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + with self.assertRaises(marshmallow.exceptions.ValidationError): + schema.load({"value": [(42,), {12: 13, 13: "hello"}, "foo"]}) + + def test_union_list_dict(self): + @dataclass + class Elm: + elm: int + + @dataclass + class Dclass: + value: Union[List[int], Dict[str, Elm]] + + schema = Dclass.Schema() + + data_in = {"value": {"a": {"elm": 10}, "b": {"elm": 10}}} + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + data_in = {"value": [1, 2, 3, 4]} + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + def test_union_noschema(self): + @dataclass + class Dclass: + value: Union[int, str] + + schema = Dclass.Schema() + data_in = {"value": [1.4, 4.2]} + with self.assertRaises(marshmallow.exceptions.ValidationError): + self.assertEqual(schema.dump(schema.load(data_in)), data_in)