From 5133691aff260627b1daf81fc71cc4eff442c473 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Cabessa?= Date: Mon, 27 Apr 2020 10:33:53 +0200 Subject: [PATCH] use polyfield for union marshmallow-union is not supported anymore and have some known issues see: https://github.com/lovasoa/marshmallow_dataclass/issues/67 Author advise to switch to marshmallow-polyfield --- marshmallow_dataclass/__init__.py | 78 +++++++++++++++-- setup.py | 2 +- tests/test_field_for_schema.py | 6 +- tests/test_union.py | 137 ++++++++++++++++++++++++++++++ 4 files changed, 212 insertions(+), 11 deletions(-) create mode 100644 tests/test_union.py diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 9b223a5..ac490e7 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -405,13 +405,46 @@ def field_for_schema( metadata["required"] = False return field_for_schema(subtyp, metadata=metadata, base_schema=base_schema) elif typing_inspect.is_union_type(typ): - subfields = [ - field_for_schema(subtyp, metadata=metadata, base_schema=base_schema) - for subtyp in arguments - ] - import marshmallow_union - return marshmallow_union.Union(subfields, **metadata) + def deserialization_disambiguation(obj_dict, base_dict): + for subtype in arguments: + @dataclass + class dclass: + field: subtype + + try: + candidate = SchemaPolyfieldProxy(dclass) + candidate.check_deserialization(obj_dict) + return candidate + except Exception: + pass + else: + raise marshmallow.exceptions.ValidationError( + "cannot deserialize") + + def serialization_disambiguation(obj, base_obj): + for subtype in arguments: + @dataclass + class dclass: + field: subtype + + try: + candidate = SchemaPolyfieldProxy(dclass) + candidate.check_serialization(obj) + return candidate + except Exception: + pass + else: + raise marshmallow.exceptions.ValidationError( + "cannot serialize") + + import marshmallow_polyfield + + return marshmallow_polyfield.PolyField( + deserialization_schema_selector=deserialization_disambiguation, + serialization_schema_selector=serialization_disambiguation, + **metadata, + ) # typing.NewType returns a function with a __supertype__ attribute newtype_supertype = getattr(typ, "__supertype__", None) @@ -534,6 +567,39 @@ def new_type(x: _U): return new_type +# extends Schema for instance check by polyfield +class SchemaPolyfieldProxy(marshmallow.Schema): + """ Proxy class that implement Schema interface to proxify call to a + dataclass. It is used in order to disambiguate Union subtype. + By convention, we assume the dataclass has one field called "field" + """ + def __init__(self, dataclass): + self.schema = dataclass.Schema() + self.dataclass = dataclass + + @property + def context(self): + return self.schema.context + + def dump(self, value): + return self.schema.dump(self.dataclass(value))["field"] + + def load(self, value): + return self.schema.load({"field": value}).field + + def check_deserialization(self, obj_dict): + load = self.schema.load({"field": obj_dict}) + dump = self.schema.dump(load)["field"] + if type(dump) != type(obj_dict): + raise TypeError("types do not match") + + def check_serialization(self, obj): + dump = self.schema.dump(self.dataclass(obj)) + load = self.schema.load(dump) + if type(load.field) != type(obj): + raise TypeError("types do not match") + + if __name__ == "__main__": import doctest diff --git a/setup.py b/setup.py index a31ffc5..cd7095e 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ EXTRAS_REQUIRE = { "enum": ["marshmallow-enum"], - "union": ["marshmallow-union"], + "union": ["marshmallow-polyfield"], ':python_version == "3.6"': ["dataclasses"], "lint": ["pre-commit~=1.18"], "docs": ["sphinx"], diff --git a/tests/test_field_for_schema.py b/tests/test_field_for_schema.py index c597e87..dcdfbab 100644 --- a/tests/test_field_for_schema.py +++ b/tests/test_field_for_schema.py @@ -89,13 +89,11 @@ class Color(Enum): ) def test_union(self): - import marshmallow_union + import marshmallow_polyfield self.assertFieldsEqual( field_for_schema(Union[int, str]), - marshmallow_union.Union( - fields=[fields.Integer(), fields.String()], required=True - ), + marshmallow_polyfield.PolyField(required=True), ) def test_newtype(self): diff --git a/tests/test_union.py b/tests/test_union.py new file mode 100644 index 0000000..b90061f --- /dev/null +++ b/tests/test_union.py @@ -0,0 +1,137 @@ +import unittest +import marshmallow +from marshmallow_dataclass import dataclass +from typing import List, Union, Dict + + +class TestClassSchema(unittest.TestCase): + def test_simple_union(self): + @dataclass + class Dclass: + value: Union[int, str] + + schema = Dclass.Schema() + data_in = {"value": "42"} + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + data_in = {"value": 42} + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + def test_list_union_builtin(self): + @dataclass + class Dclass: + value: List[Union[int, str]] + + schema = Dclass.Schema() + data_in = {"value": ["hello", 42]} + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + def test_list_union_object(self): + @dataclass + class Elm1: + elm1: str + + @dataclass + class Elm2: + elm2: str + + @dataclass + class Dclass: + value: List[Union[Elm1, Elm2]] + + schema = Dclass.Schema() + data_in = {"value": [{"elm1": "foo"}, {"elm2": "bar"}]} + load = schema.load(data_in) + self.assertIsInstance(load, Dclass) + self.assertIsInstance(load.value[0], Elm1) + self.assertIsInstance(load.value[1], Elm2) + self.assertEqual(schema.dump(load), data_in) + + def test_union_list(self): + @dataclass + class Elm1: + elm1: int + + @dataclass + class Elm2: + elm2: int + + @dataclass + class TestDataClass: + value: Union[List[Elm1], List[Elm2]] + + schema = TestDataClass.Schema() + + data_in = {"value": [{"elm1": 10}, {"elm1": 11}]} + load = schema.load(data_in) + self.assertIsInstance(load.value[0], Elm1) + self.assertEqual(schema.dump(load), data_in) + + data_in = {"value": [{"elm2": 10}, {"elm2": 11}]} + load = schema.load(data_in) + self.assertIsInstance(load.value[0], Elm2) + self.assertEqual(schema.dump(load), data_in) + + dictwrong_in = {"value": [{"elm1": 10}, {"elm2": 11}]} + with self.assertRaises(marshmallow.exceptions.ValidationError): + schema.load(dictwrong_in) + + def test_many_nested_union(self): + @dataclass + class Elm1: + elm1: str + + @dataclass + class Dclass: + value: List[Union[List[Union[int, str, Elm1]], int]] + + schema = Dclass.Schema() + data_in = {"value": [42, ["hello", 13, {"elm1": "foo"}]]} + + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + with self.assertRaises(marshmallow.exceptions.ValidationError): + schema.load({"value": [42, ["hello", 13, {"elm2": "foo"}]]}) + + def test_union_dict(self): + @dataclass + class Dclass: + value: List[Union[Dict[int, Union[int, str]], Union[int, str]]] + + schema = Dclass.Schema() + data_in = {"value": [42, {12: 13, 13: "hello"}, "foo"]} + + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + with self.assertRaises(marshmallow.exceptions.ValidationError): + schema.load({"value": [(42,), {12: 13, 13: "hello"}, "foo"]}) + + def test_union_list_dict(self): + @dataclass + class Elm: + elm: int + + @dataclass + class Dclass: + value: Union[List[int], Dict[str, Elm]] + + schema = Dclass.Schema() + + data_in = { + "value": {"a": {"elm": 10}, "b": {"elm": 10}}, + } + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + data_in = { + "value": [1, 2, 3, 4] + } + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + def test_union_noschema(self): + @dataclass + class Dclass: + value: Union[int, str] + + schema = Dclass.Schema() + data_in = {"value": [1.4, 4.2]} + with self.assertRaises(marshmallow.exceptions.ValidationError): + self.assertEqual(schema.dump(schema.load(data_in)), data_in)