diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 403da98..b410208 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -404,13 +404,47 @@ def field_for_schema( metadata["required"] = False return field_for_schema(subtyp, metadata=metadata, base_schema=base_schema) elif typing_inspect.is_union_type(typ): - subfields = [ - field_for_schema(subtyp, metadata=metadata, base_schema=base_schema) - for subtyp in arguments - ] - import marshmallow_union - return marshmallow_union.Union(subfields, **metadata) + def deserialization_disambiguation(obj_dict, base_dict): + for argument in arguments: + field = field_for_schema(argument) # todo: precise options + try: + field.deserialize(obj_dict, base_dict) + # deserialization support field, don't bother creating a + # schema + return field + except marshmallow.exceptions.ValidationError: + pass + else: + raise marshmallow.exceptions.ValidationError( + "cannot deserialize") + + def serialization_disambiguation(obj, base_obj): + for subtype in arguments: + @dataclass + class dclass: + field: subtype + + try: + schema = dclass.Schema() + dump = schema.dump(dclass(obj)) + load = schema.load(dump) + if type(load.field) != type(obj): + continue + return SchemaPolyfieldProxy(dclass) + except Exception: + pass + else: + raise marshmallow.exceptions.ValidationError( + "cannot serialize") + + import marshmallow_polyfield + + return marshmallow_polyfield.PolyField( + deserialization_schema_selector=deserialization_disambiguation, + serialization_schema_selector=serialization_disambiguation, + **metadata, + ) # typing.NewType returns a function with a __supertype__ attribute newtype_supertype = getattr(typ, "__supertype__", None) @@ -533,6 +567,23 @@ def new_type(x: _U): return new_type +class SchemaPolyfieldProxy: + """ Proxy class that implement Schema interface to proxify call to a + dataclass. It is used in order to disambiguate Union subtype. + By convention, we assume the dataclass has one field called "field" + """ + def __init__(self, dataclass): + self.schema = dataclass.Schema() + self.dataclass = dataclass + + @property + def context(self): + return self.schema.context + + def dump(self, value): + return self.schema.dump(self.dataclass(value))["field"] + + if __name__ == "__main__": import doctest diff --git a/setup.py b/setup.py index 7aaad78..246d87b 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import setup, find_packages -VERSION = "7.5.1" +VERSION = "7.5.2-dev1" CLASSIFIERS = [ "Development Status :: 4 - Beta", @@ -14,7 +14,7 @@ EXTRAS_REQUIRE = { "enum": ["marshmallow-enum"], - "union": ["marshmallow-union"], + "union": ["marshmallow-polyfield"], ':python_version == "3.6"': ["dataclasses"], "lint": ["pre-commit~=1.18"], "docs": ["sphinx"], diff --git a/tests/test_field_for_schema.py b/tests/test_field_for_schema.py index 8f356e9..5e4afcf 100644 --- a/tests/test_field_for_schema.py +++ b/tests/test_field_for_schema.py @@ -19,7 +19,6 @@ def attrs(x): for k, v in x.__dict__.items() if not k.startswith("_") } - self.assertEqual(attrs(a), attrs(b)) def test_int(self): @@ -87,13 +86,11 @@ class Color(Enum): ) def test_union(self): - import marshmallow_union + import marshmallow_polyfield self.assertFieldsEqual( field_for_schema(Union[int, str]), - marshmallow_union.Union( - fields=[fields.Integer(), fields.String()], required=True - ), + marshmallow_polyfield.PolyField(required=True), ) def test_newtype(self): diff --git a/tests/test_union.py b/tests/test_union.py new file mode 100644 index 0000000..b27ddbc --- /dev/null +++ b/tests/test_union.py @@ -0,0 +1,137 @@ +import unittest +import marshmallow +from marshmallow_dataclass import dataclass +from typing import List, Union, Dict + + +class TestClassSchema(unittest.TestCase): + def test_simple_union(self): + @dataclass + class Dclass: + value: Union[int, str] + + schema = Dclass.Schema() + data_in = {"value": "foo"} + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + data_in = {"value": 42} + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + def test_list_union_builtin(self): + @dataclass + class Dclass: + value: List[Union[int, str]] + + schema = Dclass.Schema() + data_in = {"value": ["hello", 42]} + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + def test_list_union_object(self): + @dataclass + class Elm1: + elm1: str + + @dataclass + class Elm2: + elm2: str + + @dataclass + class Dclass: + value: List[Union[Elm1, Elm2]] + + schema = Dclass.Schema() + data_in = {"value": [{"elm1": "foo"}, {"elm2": "bar"}]} + load = schema.load(data_in) + self.assertIsInstance(load, Dclass) + self.assertIsInstance(load.value[0], Elm1) + self.assertIsInstance(load.value[1], Elm2) + self.assertEqual(schema.dump(load), data_in) + + def test_union_list(self): + @dataclass + class Elm1: + elm1: int + + @dataclass + class Elm2: + elm2: int + + @dataclass + class TestDataClass: + value: Union[List[Elm1], List[Elm2]] + + schema = TestDataClass.Schema() + + data_in = {"value": [{"elm1": 10}, {"elm1": 11}]} + load = schema.load(data_in) + self.assertIsInstance(load.value[0], Elm1) + self.assertEqual(schema.dump(load), data_in) + + data_in = {"value": [{"elm2": 10}, {"elm2": 11}]} + load = schema.load(data_in) + self.assertIsInstance(load.value[0], Elm2) + self.assertEqual(schema.dump(load), data_in) + + dictwrong_in = {"value": [{"elm1": 10}, {"elm2": 11}]} + with self.assertRaises(marshmallow.exceptions.ValidationError): + schema.load(dictwrong_in) + + def test_many_nested_union(self): + @dataclass + class Elm1: + elm1: str + + @dataclass + class Dclass: + value: List[Union[List[Union[int, str, Elm1]], int]] + + schema = Dclass.Schema() + data_in = {"value": [42, ["hello", 13, {"elm1": "foo"}]]} + + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + with self.assertRaises(marshmallow.exceptions.ValidationError): + schema.load({"value": [42, ["hello", 13, {"elm2": "foo"}]]}) + + def test_union_dict(self): + @dataclass + class Dclass: + value: List[Union[Dict[int, Union[int, str]], Union[int, str]]] + + schema = Dclass.Schema() + data_in = {"value": [42, {12: 13, 13: "hello"}, "foo"]} + + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + with self.assertRaises(marshmallow.exceptions.ValidationError): + schema.load({"value": [(42,), {12: 13, 13: "hello"}, "foo"]}) + + def test_union_list_dict(self): + @dataclass + class Elm: + elm: int + + @dataclass + class Dclass: + value: Union[List[int], Dict[str, Elm]] + + schema = Dclass.Schema() + + data_in = { + "value": {"a": {"elm": 10}, "b": {"elm": 10}}, + } + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + data_in = { + "value": [1, 2, 3, 4] + } + self.assertEqual(schema.dump(schema.load(data_in)), data_in) + + def test_union_noschema(self): + @dataclass + class Dclass: + value: Union[int, str] + + schema = Dclass.Schema() + data_in = {"value": [1.4, 4.2]} + with self.assertRaises(marshmallow.exceptions.ValidationError): + self.assertEqual(schema.dump(schema.load(data_in)), data_in)