From 04ddea813f77df4030337a85f85e42755f401bee Mon Sep 17 00:00:00 2001 From: Dean Gurvitz Date: Tue, 8 Aug 2023 01:37:48 +0300 Subject: [PATCH] Fix: if field has custom decoder, schema takes it into account (#462) --- dataclasses_json/mm.py | 14 ++++++++++++-- tests/entities.py | 10 ++++++++++ tests/test_schema.py | 7 ++++++- 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/dataclasses_json/mm.py b/dataclasses_json/mm.py index 61c4045e..8e889222 100644 --- a/dataclasses_json/mm.py +++ b/dataclasses_json/mm.py @@ -253,10 +253,10 @@ def inner(type_, options): origin = getattr(type_, '__origin__', type_) args = [inner(a, {}) for a in getattr(type_, '__args__', []) if a is not type(None)] - + if type_ == Ellipsis: return type_ - + if _is_optional(type_): options["allow_none"] = True if origin is tuple: @@ -318,6 +318,16 @@ def schema(cls, mixin, infer_missing): options['data_key'] = metadata.letter_case(field.name) t = build_type(type_, options, mixin, field, cls) + if field.metadata.get('dataclasses_json', {}).get('decoder'): + # If the field defines a custom decoder, it should completely replace the Marshmallow field's conversion + # logic. + # From Marshmallow's documentation for the _deserialize method: + # "Deserialize value. Concrete :class:`Field` classes should implement this method. " + # This is the method that Field implementations override to perform the actual deserialization logic. + # In this case we specifically override this method instead of `deserialize` to minimize potential + # side effects, and only cancel the actual value deserialization. + t._deserialize = lambda v, *_a, **_kw: v + # if type(t) is not fields.Field: # If we use `isinstance` we would return nothing. if field.type != typing.Optional[CatchAllVar]: schema[field.name] = t diff --git a/tests/entities.py b/tests/entities.py index 61b8af23..480c8a07 100644 --- a/tests/entities.py +++ b/tests/entities.py @@ -271,6 +271,16 @@ class DataClassWithErroneousDecode: id: float = field(metadata=config(decoder=lambda: None)) +def split_str(data: str, *_args, **_kwargs): + return data.split(',') + + +@dataclass_json +@dataclass +class DataClassDifferentTypeDecode: + lst: List[str] = field(default=None, metadata=config(decoder=split_str)) + + @dataclass_json @dataclass class DataClassMappingBadDecode: diff --git a/tests/test_schema.py b/tests/test_schema.py index 1fa302de..1c8a6f55 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -2,7 +2,8 @@ import pytest from .entities import (DataClassDefaultListStr, DataClassDefaultOptionalList, DataClassList, DataClassOptional, - DataClassWithNestedOptional, DataClassWithNestedOptionalAny, DataClassWithNestedAny) + DataClassWithNestedOptional, DataClassWithNestedOptionalAny, DataClassWithNestedAny, + DataClassDifferentTypeDecode) from .test_letter_case import CamelCasePerson, KebabCasePerson, SnakeCasePerson, FieldNamePerson test_do_list = """[{}, {"children": [{"name": "a"}, {"name": "b"}]}]""" @@ -47,3 +48,7 @@ def test_nested_optional_any(self): def test_nested_any_accepts_optional(self): DataClassWithNestedAny.schema().loads(nested_optional_data) assert True + + def test_accounts_for_decode(self): + assert DataClassDifferentTypeDecode.schema().load({'lst': '1,2,3'}) == \ + DataClassDifferentTypeDecode(lst=['1', '2', '3'])