Skip to content

Commit

Permalink
Support TypedDict field as Dict[str, Any]
Browse files Browse the repository at this point in the history
  • Loading branch information
isra17 committed Nov 27, 2022
1 parent ca7b924 commit de8324f
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 1 deletion.
1 change: 1 addition & 0 deletions changelog.d/237.change.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `TypedDict` subclass support to fields. These are treated the same as `Dict[str, Any]`.
25 changes: 24 additions & 1 deletion src/desert/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class User:
import decimal
import enum
import inspect
import sys
import typing as t
import uuid

Expand Down Expand Up @@ -305,11 +306,18 @@ def field_for_schema(
field = field_for_schema(newtype_supertype, default=default)

# enumerations
if type(typ) is enum.EnumMeta:
elif type(typ) is enum.EnumMeta:
import marshmallow_enum

field = marshmallow_enum.EnumField(typ, metadata=metadata)

# TypedDict
elif _is_typeddict(typ):
field = marshmallow.fields.Dict(
keys=marshmallow.fields.String,
values=marshmallow.fields.Raw,
)

# Nested dataclasses
forward_reference = getattr(typ, "__forward_arg__", None)

Expand Down Expand Up @@ -370,6 +378,21 @@ def _get_field_default(
raise TypeError(field)


def _is_typeddict(typ: t.Any) -> bool:
if typing_inspect.typed_dict_keys(typ) is not None:
return True

# typing_inspect misses some case.
if sys.version_info >= (3, 10):
return t.is_typeddict(typ)

# python>=3.8; <3.10: Reimplement t.is_typeddict
if sys.version_info >= (3, 8):
return isinstance(typ, t._TypedDictMeta) # type: ignore[attr-defined]

return False


@attr.frozen
class _DesertSentinel:
pass
Expand Down
33 changes: 33 additions & 0 deletions tests/test_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
import desert


typed_dict_classes: t.List[t.Type] = [typing_extensions.TypedDict]
if sys.version_info >= (3, 8):
typed_dict_classes.append(t.TypedDict)


@attr.frozen(order=False)
class DataclassModule:
"""Implementation of a dataclass module like attr or dataclasses."""
Expand Down Expand Up @@ -45,6 +50,13 @@ def dataclass_param(request: _pytest.fixtures.SubRequest) -> DataclassModule:
return module


@pytest.fixture(
params=typed_dict_classes, ids=[x.__module__ for x in typed_dict_classes]
)
def typed_dict_class(request: _pytest.fixtures.SubRequest) -> t.Type:
return request.param


class AssertLoadDumpProtocol(typing_extensions.Protocol):
def __call__(
self, schema: marshmallow.Schema, loaded: t.Any, dumped: t.Dict[t.Any, t.Any]
Expand Down Expand Up @@ -437,6 +449,27 @@ class A:
assert_dump_load(schema=schema, loaded=loaded, dumped=dumped)


def test_typed_dict(
module: DataclassModule,
assert_dump_load: AssertLoadDumpProtocol,
typed_dict_class: None,
) -> None:
"""Test dataclasses with basic TypedDict support"""

class B(typed_dict_class):
x: int

@module.dataclass
class A:
x: B

schema = desert.schema_class(A)()
dumped = {"x": {"x": 1}}
loaded = A(x={"x": 1}) # type: ignore[call-arg]

assert_dump_load(schema=schema, loaded=loaded, dumped=dumped)


@pytest.mark.xfail(
strict=True,
reason=(
Expand Down

0 comments on commit de8324f

Please sign in to comment.