From 59519a14088d7ece582dca1b46d92f3c251d2945 Mon Sep 17 00:00:00 2001 From: isra17 Date: Fri, 11 Nov 2022 12:36:58 -0500 Subject: [PATCH] Support TypedDict field as Dict[str, Any] --- changelog.d/237.change.rst | 1 + src/desert/_make.py | 25 ++++++++++++++++++++++++- tests/test_make.py | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 changelog.d/237.change.rst diff --git a/changelog.d/237.change.rst b/changelog.d/237.change.rst new file mode 100644 index 0000000..a604775 --- /dev/null +++ b/changelog.d/237.change.rst @@ -0,0 +1 @@ +Add `TypedDict` subclass support to fields. These are treated the same as `Dict[str, Any]`. diff --git a/src/desert/_make.py b/src/desert/_make.py index 4c4892e..95566c0 100644 --- a/src/desert/_make.py +++ b/src/desert/_make.py @@ -60,6 +60,7 @@ class User: import decimal import enum import inspect +import sys import typing as t import uuid @@ -305,11 +306,18 @@ def field_for_schema( field = field_for_schema(newtype_supertype, default=default) # enumerations - if type(typ) is enum.EnumMeta: + elif type(typ) is enum.EnumMeta: import marshmallow_enum field = marshmallow_enum.EnumField(typ, metadata=metadata) + # TypedDict + elif _is_typeddict(typ): + field = marshmallow.fields.Dict( + keys=marshmallow.fields.String, + values=marshmallow.fields.Raw, + ) + # Nested dataclasses forward_reference = getattr(typ, "__forward_arg__", None) @@ -370,6 +378,21 @@ def _get_field_default( raise TypeError(field) +def _is_typeddict(typ: t.Any) -> bool: + if typing_inspect.typed_dict_keys(typ) is not None: + return True + + # typing_inspect misses some case. + if sys.version_info >= (3, 10): + return t.is_typeddict(typ) + + # python>=3.8; <3.10: Reimplement t.is_typeddict + if sys.version_info >= (3, 8): + return isinstance(typ, t._TypedDictMeta) + + return False + + @attr.frozen class _DesertSentinel: pass diff --git a/tests/test_make.py b/tests/test_make.py index 9a2991c..d8d25ce 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -18,6 +18,11 @@ import desert +typed_dict_classes: t.List[t.Any] = [typing_extensions.TypedDict] +if sys.version_info >= (3, 8): + typed_dict_classes.append(t.TypedDict) + + @attr.frozen(order=False) class DataclassModule: """Implementation of a dataclass module like attr or dataclasses.""" @@ -45,6 +50,13 @@ def dataclass_param(request: _pytest.fixtures.SubRequest) -> DataclassModule: return module +@pytest.fixture( + params=typed_dict_classes, ids=[x.__module__ for x in typed_dict_classes] +) +def typed_dict_class(request: _pytest.fixtures.SubRequest) -> t.Any: + return request.param + + class AssertLoadDumpProtocol(typing_extensions.Protocol): def __call__( self, schema: marshmallow.Schema, loaded: t.Any, dumped: t.Dict[t.Any, t.Any] @@ -437,6 +449,27 @@ class A: assert_dump_load(schema=schema, loaded=loaded, dumped=dumped) +def test_typed_dict( + module: DataclassModule, + assert_dump_load: AssertLoadDumpProtocol, + typed_dict_class: t.Type[t.Any], +) -> None: + """Test dataclasses with basic TypedDict support""" + + class B(typed_dict_class): # type: ignore[valid-type, misc] + x: int + + @module.dataclass + class A: + x: B + + schema = desert.schema_class(A)() + dumped = {"x": {"x": 1}} + loaded = A(x={"x": 1}) # type: ignore[call-arg] + + assert_dump_load(schema=schema, loaded=loaded, dumped=dumped) + + @pytest.mark.xfail( strict=True, reason=(