Skip to content
This repository has been archived by the owner on Oct 10, 2024. It is now read-only.

use polyfield for union #2

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ pip3 install marshmallow-dataclass
You may optionally install the following extras:

- `enum`, for translating python enums to [marshmallow-enum](https://github.com/justanr/marshmallow_enum).
- `union`, for translating python [`Union` types](https://docs.python.org/3/library/typing.html#typing.Union) into [`marshmallow-union`](https://pypi.org/project/marshmallow-union/) fields.
- `union`, for translating python [`Union` types](https://docs.python.org/3/library/typing.html#typing.Union) into [`marshmallow-polyfield`](https://pypi.org/project/marshmallow-polyfield/) fields.

```shell
pip3 install "marshmallow-dataclass[enum,union]"
Expand Down
8 changes: 2 additions & 6 deletions marshmallow_dataclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,13 +405,9 @@ def field_for_schema(
metadata["required"] = False
return field_for_schema(subtyp, metadata=metadata, base_schema=base_schema)
elif typing_inspect.is_union_type(typ):
subfields = [
field_for_schema(subtyp, metadata=metadata, base_schema=base_schema)
for subtyp in arguments
]
import marshmallow_union
from .polyfield import field_for_union
CedricCabessa marked this conversation as resolved.
Show resolved Hide resolved

return marshmallow_union.Union(subfields, **metadata)
return field_for_union(arguments, **metadata)

# typing.NewType returns a function with a __supertype__ attribute
newtype_supertype = getattr(typ, "__supertype__", None)
Expand Down
6 changes: 6 additions & 0 deletions marshmallow_dataclass/mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from mypy import nodes
from mypy.plugin import DynamicClassDefContext, Plugin
from mypy.plugins import dataclasses

import marshmallow_dataclass

Expand All @@ -21,6 +22,11 @@ def get_dynamic_class_hook(
return new_type_hook
return None

def get_class_decorator_hook(self, fullname: str):
    """Select the class-decorator callback for a fully qualified decorator name.

    Only ``marshmallow_dataclass.dataclass`` is handled: for it, the
    ``dataclass_class_maker_callback`` of the imported ``dataclasses`` plugin
    module is returned (presumably so decorated classes are checked like
    standard dataclasses -- confirm against the mypy plugin docs).

    :param fullname: fully qualified name of the decorator being inspected.
    :return: the callback for a matching decorator, otherwise ``None``.
    """
    # Guard clause: every decorator other than ours is left alone.
    if fullname != "marshmallow_dataclass.dataclass":
        return None
    return dataclasses.dataclass_class_maker_callback


def new_type_hook(ctx: DynamicClassDefContext) -> None:
"""
Expand Down
86 changes: 86 additions & 0 deletions marshmallow_dataclass/polyfield.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import marshmallow
from . import dataclass


# extends Schema for instance check by polyfield
class SchemaPolyfieldProxy(marshmallow.Schema):
    """Schema-shaped adapter around a one-field wrapper dataclass.

    Each candidate subtype of a ``Union`` is wrapped in a dataclass whose only
    attribute is named ``"field"``. This proxy forwards ``dump``/``load`` to
    that dataclass' schema and unwraps the ``"field"`` entry, so the union
    value itself is what gets (de)serialized. The ``check_*`` methods do a
    round trip to decide whether a value really belongs to this subtype.
    """

    def __init__(self, dataclass):
        # NOTE(review): the base ``Schema.__init__`` is intentionally not
        # called here; only the wrapped schema does real work.
        self.dataclass = dataclass
        self.schema = dataclass.Schema()

    @property
    def context(self):
        """Expose the wrapped schema's context as this proxy's context."""
        return self.schema.context

    def dump(self, value):
        """Serialize ``value`` by wrapping it and returning the ``"field"`` entry."""
        wrapped = self.dataclass(value)
        serialized = self.schema.dump(wrapped)
        return serialized["field"]

    def load(self, value):
        """Deserialize ``value`` and return the wrapper's ``field`` attribute."""
        wrapper = self.schema.load({"field": value})
        return wrapper.field

    def check_deserialization(self, obj_dict):
        """Raise ``TypeError`` unless ``obj_dict`` keeps its exact type after load + dump.

        Errors raised by the wrapped schema's ``load``/``dump`` propagate
        unchanged.
        """
        loaded = self.schema.load({"field": obj_dict})
        round_tripped = self.schema.dump(loaded)["field"]
        actual, expected = type(round_tripped), type(obj_dict)
        # Exact type comparison on purpose: a value that was silently coerced
        # to another type must not be accepted as a match for this subtype.
        if actual != expected:
            raise TypeError(
                "types do not match ({} is not {})".format(actual, expected)
            )

    def check_serialization(self, obj):
        """Raise ``TypeError`` unless ``obj`` keeps its exact type after dump + load.

        Errors raised by the wrapped schema's ``dump``/``load`` propagate
        unchanged.
        """
        dumped = self.schema.dump(self.dataclass(obj))
        reloaded = self.schema.load(dumped)
        actual, expected = type(reloaded.field), type(obj)
        if actual != expected:
            raise TypeError(
                "types do not match ({} is not {})".format(actual, expected)
            )


def field_for_union(arguments, **metadata):
def deserialization_disambiguation(obj_dict, base_dict):
for subtype in arguments:

@dataclass
class dclass:
field: subtype

try:
candidate = SchemaPolyfieldProxy(dclass)
candidate.check_deserialization(obj_dict)
return candidate
except Exception:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this be more precise? like (TypeError, ValidationError) (if indeed .dump and .load raise ValidationErrors)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we try every schema, and each of them can trigger an error (we pick the first one that "works")

For example, when we give a string to an int schema we get `ValueError: invalid literal for int() with base 10: 'hello'`, and when we give an int to a dict schema we get `AttributeError: 'int' object has no attribute 'keys'`.

I suppose a custom schema can raise any type of error.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah I see, ok makes sense

pass
else:
raise marshmallow.exceptions.ValidationError(
"cannot deserialize union %s" % " ".join([str(a) for a in arguments])
)

def serialization_disambiguation(obj, base_obj):
    """Return a proxy schema for the first union member able to serialize ``obj``.

    Every type in ``arguments`` is tried in order: it is wrapped in a
    throw-away one-field dataclass, and the resulting proxy is accepted if
    its serialization round-trip check succeeds.

    :param obj: the value about to be serialized.
    :param base_obj: second argument of the selector signature; unused here.
    :return: the first ``SchemaPolyfieldProxy`` whose check passes.
    :raises marshmallow.exceptions.ValidationError: if no member matches.
    """
    for member_type in arguments:

        @dataclass
        class dclass:
            field: member_type

        try:
            proxy = SchemaPolyfieldProxy(dclass)
            proxy.check_serialization(obj)
        except Exception:
            # A non-matching member can fail with many different exception
            # types (e.g. ValueError, AttributeError), so any failure simply
            # means "try the next member".
            continue
        return proxy

    member_names = " ".join(map(str, arguments))
    raise marshmallow.exceptions.ValidationError(
        "cannot serialize union %s" % member_names
    )

import marshmallow_polyfield

return marshmallow_polyfield.PolyField(
deserialization_schema_selector=deserialization_disambiguation,
serialization_schema_selector=serialization_disambiguation,
**metadata,
)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

EXTRAS_REQUIRE = {
"enum": ["marshmallow-enum"],
"union": ["marshmallow-union"],
"union": ["marshmallow-polyfield"],
':python_version == "3.6"': ["dataclasses"],
"lint": ["pre-commit~=1.18"],
"docs": ["sphinx"],
Expand Down
6 changes: 2 additions & 4 deletions tests/test_field_for_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,11 @@ class Color(Enum):
)

def test_union(self):
import marshmallow_union
import marshmallow_polyfield

self.assertFieldsEqual(
field_for_schema(Union[int, str]),
marshmallow_union.Union(
fields=[fields.Integer(), fields.String()], required=True
),
marshmallow_polyfield.PolyField(required=True),
YBadiss marked this conversation as resolved.
Show resolved Hide resolved
)

def test_newtype(self):
Expand Down
13 changes: 13 additions & 0 deletions tests/test_mypy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,16 @@

User(id=42, email="user@example.com") # E: Argument "id" to "User" has incompatible type "int"; expected "str"
User(id="a"*32, email=["not", "a", "string"]) # E: Argument "email" to "User" has incompatible type "List[str]"; expected "str"
- case: marshmallow_dataclass_keyword_arguments
mypy_config: |
follow_imports = silent
plugins = marshmallow_dataclass.mypy
main: |
from marshmallow_dataclass import dataclass

@dataclass
class User:
id: int
name: str

user = User(id=4, name='Johny')
133 changes: 133 additions & 0 deletions tests/test_union.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import unittest
import marshmallow
from marshmallow_dataclass import dataclass
from typing import List, Union, Dict


class TestClassSchema(unittest.TestCase):
    """Load/dump tests for dataclass fields annotated with ``typing.Union``."""

    def assert_round_trip(self, schema, data_in):
        """Assert that loading ``data_in`` and dumping the result reproduces it."""
        self.assertEqual(schema.dump(schema.load(data_in)), data_in)

    def test_simple_union(self):
        @dataclass
        class Dclass:
            value: Union[int, str]

        schema = Dclass.Schema()
        # A numeric string must stay a string and must not be coerced to int.
        self.assert_round_trip(schema, {"value": "42"})
        self.assert_round_trip(schema, {"value": 42})

    def test_list_union_builtin(self):
        @dataclass
        class Dclass:
            value: List[Union[int, str]]

        schema = Dclass.Schema()
        self.assert_round_trip(schema, {"value": ["hello", 42]})

    def test_list_union_object(self):
        @dataclass
        class Elm1:
            elm1: str

        @dataclass
        class Elm2:
            elm2: str

        @dataclass
        class Dclass:
            value: List[Union[Elm1, Elm2]]

        schema = Dclass.Schema()
        data_in = {"value": [{"elm1": "foo"}, {"elm2": "bar"}]}
        load = schema.load(data_in)
        # Each list item must be resolved to its own union member.
        self.assertIsInstance(load, Dclass)
        self.assertIsInstance(load.value[0], Elm1)
        self.assertIsInstance(load.value[1], Elm2)
        self.assertEqual(schema.dump(load), data_in)

    def test_union_list(self):
        @dataclass
        class Elm1:
            elm1: int

        @dataclass
        class Elm2:
            elm2: int

        @dataclass
        class TestDataClass:
            value: Union[List[Elm1], List[Elm2]]

        schema = TestDataClass.Schema()

        data_in = {"value": [{"elm1": 10}, {"elm1": 11}]}
        load = schema.load(data_in)
        self.assertIsInstance(load.value[0], Elm1)
        self.assertEqual(schema.dump(load), data_in)

        data_in = {"value": [{"elm2": 10}, {"elm2": 11}]}
        load = schema.load(data_in)
        self.assertIsInstance(load.value[0], Elm2)
        self.assertEqual(schema.dump(load), data_in)

        # A list mixing both element types matches neither union member.
        dictwrong_in = {"value": [{"elm1": 10}, {"elm2": 11}]}
        with self.assertRaises(marshmallow.exceptions.ValidationError):
            schema.load(dictwrong_in)

    def test_many_nested_union(self):
        @dataclass
        class Elm1:
            elm1: str

        @dataclass
        class Dclass:
            value: List[Union[List[Union[int, str, Elm1]], int]]

        schema = Dclass.Schema()
        self.assert_round_trip(
            schema, {"value": [42, ["hello", 13, {"elm1": "foo"}]]}
        )

        # "elm2" is not a field of Elm1, so the innermost union cannot match.
        with self.assertRaises(marshmallow.exceptions.ValidationError):
            schema.load({"value": [42, ["hello", 13, {"elm2": "foo"}]]})

    def test_union_dict(self):
        @dataclass
        class Dclass:
            value: List[Union[Dict[int, Union[int, str]], Union[int, str]]]

        schema = Dclass.Schema()
        self.assert_round_trip(
            schema, {"value": [42, {12: 13, 13: "hello"}, "foo"]}
        )

        # A tuple is none of dict, int or str.
        with self.assertRaises(marshmallow.exceptions.ValidationError):
            schema.load({"value": [(42,), {12: 13, 13: "hello"}, "foo"]})

    def test_union_list_dict(self):
        @dataclass
        class Elm:
            elm: int

        @dataclass
        class Dclass:
            value: Union[List[int], Dict[str, Elm]]

        schema = Dclass.Schema()
        self.assert_round_trip(
            schema, {"value": {"a": {"elm": 10}, "b": {"elm": 10}}}
        )
        self.assert_round_trip(schema, {"value": [1, 2, 3, 4]})

    def test_union_noschema(self):
        @dataclass
        class Dclass:
            value: Union[int, str]

        schema = Dclass.Schema()
        data_in = {"value": [1.4, 4.2]}
        # Only the load is expected to fail. The former assertEqual wrapped
        # inside assertRaises could never run, so the test now pins the
        # failure to ``schema.load`` like the other negative tests here.
        with self.assertRaises(marshmallow.exceptions.ValidationError):
            schema.load(data_in)