Skip to content

Commit

Permalink
Backport generic TypedDicts (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexWaygood authored May 27, 2022
1 parent 7c28357 commit 1baf0a5
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 32 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
- Add `typing_extensions.NamedTuple`, allowing for generic `NamedTuple`s on
Python <3.11 (backport from python/cpython#92027, by Serhiy Storchaka). Patch
by Alex Waygood (@AlexWaygood).
- Adjust `typing_extensions.TypedDict` to allow for generic `TypedDict`s on
Python <3.11 (backport from python/cpython#27663, by Samodya Abey). Patch by
Alex Waygood (@AlexWaygood).

# Release 4.2.0 (April 17, 2022)

Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ Certain objects were changed after they were added to `typing`, and
- `TypedDict` does not store runtime information
about which (if any) keys are non-required in Python 3.8, and does not
honor the `total` keyword with old-style `TypedDict()` in Python
3.9.0 and 3.9.1.
3.9.0 and 3.9.1. `TypedDict` also does not support multiple inheritance
with `typing.Generic` on Python <3.11.
- `get_origin` and `get_args` lack support for `Annotated` in
Python 3.8 and lack support for `ParamSpecArgs` and `ParamSpecKwargs`
in 3.9.
Expand Down
8 changes: 8 additions & 0 deletions src/_typed_dict_test_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from __future__ import annotations

from typing import Generic, Optional, T
from typing_extensions import TypedDict


class FooGeneric(TypedDict, Generic[T]):
a: Optional[T]
138 changes: 138 additions & 0 deletions src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from typing_extensions import assert_type, get_type_hints, get_origin, get_args
from typing_extensions import clear_overloads, get_overloads, overload
from typing_extensions import NamedTuple
from _typed_dict_test_helper import FooGeneric

# Flags used to mark tests that only apply after a specific
# version of the typing module.
Expand Down Expand Up @@ -1664,6 +1665,15 @@ class CustomProtocolWithoutInitB(Protocol):
self.assertEqual(CustomProtocolWithoutInitA.__init__, CustomProtocolWithoutInitB.__init__)


class Point2DGeneric(Generic[T], TypedDict):
a: T
b: T


class BarGeneric(FooGeneric[T], total=False):
b: int


class TypedDictTests(BaseTestCase):

def test_basics_iterable_syntax(self):
Expand Down Expand Up @@ -1769,14 +1779,24 @@ def test_pickle(self):
global EmpD # pickle wants to reference the class by name
EmpD = TypedDict('EmpD', name=str, id=int)
jane = EmpD({'name': 'jane', 'id': 37})
point = Point2DGeneric(a=5.0, b=3.0)
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
# Test non-generic TypedDict
z = pickle.dumps(jane, proto)
jane2 = pickle.loads(z)
self.assertEqual(jane2, jane)
self.assertEqual(jane2, {'name': 'jane', 'id': 37})
ZZ = pickle.dumps(EmpD, proto)
EmpDnew = pickle.loads(ZZ)
self.assertEqual(EmpDnew({'name': 'jane', 'id': 37}), jane)
# and generic TypedDict
y = pickle.dumps(point, proto)
point2 = pickle.loads(y)
self.assertEqual(point, point2)
self.assertEqual(point2, {'a': 5.0, 'b': 3.0})
YY = pickle.dumps(Point2DGeneric, proto)
Point2DGenericNew = pickle.loads(YY)
self.assertEqual(Point2DGenericNew({'a': 5.0, 'b': 3.0}), point)

def test_optional(self):
EmpD = TypedDict('EmpD', name=str, id=int)
Expand Down Expand Up @@ -1854,6 +1874,124 @@ class PointDict3D(PointDict2D, total=False):
assert is_typeddict(PointDict2D) is True
assert is_typeddict(PointDict3D) is True

def test_get_type_hints_generic(self):
self.assertEqual(
get_type_hints(BarGeneric),
{'a': typing.Optional[T], 'b': int}
)

class FooBarGeneric(BarGeneric[int]):
c: str

self.assertEqual(
get_type_hints(FooBarGeneric),
{'a': typing.Optional[T], 'b': int, 'c': str}
)

def test_generic_inheritance(self):
class A(TypedDict, Generic[T]):
a: T

self.assertEqual(A.__bases__, (Generic, dict))
self.assertEqual(A.__orig_bases__, (TypedDict, Generic[T]))
self.assertEqual(A.__mro__, (A, Generic, dict, object))
self.assertEqual(A.__parameters__, (T,))
self.assertEqual(A[str].__parameters__, ())
self.assertEqual(A[str].__args__, (str,))

class A2(Generic[T], TypedDict):
a: T

self.assertEqual(A2.__bases__, (Generic, dict))
self.assertEqual(A2.__orig_bases__, (Generic[T], TypedDict))
self.assertEqual(A2.__mro__, (A2, Generic, dict, object))
self.assertEqual(A2.__parameters__, (T,))
self.assertEqual(A2[str].__parameters__, ())
self.assertEqual(A2[str].__args__, (str,))

class B(A[KT], total=False):
b: KT

self.assertEqual(B.__bases__, (Generic, dict))
self.assertEqual(B.__orig_bases__, (A[KT],))
self.assertEqual(B.__mro__, (B, Generic, dict, object))
self.assertEqual(B.__parameters__, (KT,))
self.assertEqual(B.__total__, False)
self.assertEqual(B.__optional_keys__, frozenset(['b']))
self.assertEqual(B.__required_keys__, frozenset(['a']))

self.assertEqual(B[str].__parameters__, ())
self.assertEqual(B[str].__args__, (str,))
self.assertEqual(B[str].__origin__, B)

class C(B[int]):
c: int

self.assertEqual(C.__bases__, (Generic, dict))
self.assertEqual(C.__orig_bases__, (B[int],))
self.assertEqual(C.__mro__, (C, Generic, dict, object))
self.assertEqual(C.__parameters__, ())
self.assertEqual(C.__total__, True)
self.assertEqual(C.__optional_keys__, frozenset(['b']))
self.assertEqual(C.__required_keys__, frozenset(['a', 'c']))
assert C.__annotations__ == {
'a': T,
'b': KT,
'c': int,
}
with self.assertRaises(TypeError):
C[str]


class Point3D(Point2DGeneric[T], Generic[T, KT]):
c: KT

self.assertEqual(Point3D.__bases__, (Generic, dict))
self.assertEqual(Point3D.__orig_bases__, (Point2DGeneric[T], Generic[T, KT]))
self.assertEqual(Point3D.__mro__, (Point3D, Generic, dict, object))
self.assertEqual(Point3D.__parameters__, (T, KT))
self.assertEqual(Point3D.__total__, True)
self.assertEqual(Point3D.__optional_keys__, frozenset())
self.assertEqual(Point3D.__required_keys__, frozenset(['a', 'b', 'c']))
assert Point3D.__annotations__ == {
'a': T,
'b': T,
'c': KT,
}
self.assertEqual(Point3D[int, str].__origin__, Point3D)

with self.assertRaises(TypeError):
Point3D[int]

with self.assertRaises(TypeError):
class Point3D(Point2DGeneric[T], Generic[KT]):
c: KT

def test_implicit_any_inheritance(self):
class A(TypedDict, Generic[T]):
a: T

class B(A[KT], total=False):
b: KT

class WithImplicitAny(B):
c: int

self.assertEqual(WithImplicitAny.__bases__, (Generic, dict,))
self.assertEqual(WithImplicitAny.__mro__, (WithImplicitAny, Generic, dict, object))
# Consistent with GenericTests.test_implicit_any
self.assertEqual(WithImplicitAny.__parameters__, ())
self.assertEqual(WithImplicitAny.__total__, True)
self.assertEqual(WithImplicitAny.__optional_keys__, frozenset(['b']))
self.assertEqual(WithImplicitAny.__required_keys__, frozenset(['a', 'c']))
assert WithImplicitAny.__annotations__ == {
'a': T,
'b': KT,
'c': int,
}
with self.assertRaises(TypeError):
WithImplicitAny[str]


class AnnotatedTests(BaseTestCase):

Expand Down
81 changes: 50 additions & 31 deletions src/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,46 @@ def _is_callable_members_only(cls):
return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls))


def _maybe_adjust_parameters(cls):
"""Helper function used in Protocol.__init_subclass__ and _TypedDictMeta.__new__.
The contents of this function are very similar
to logic found in typing.Generic.__init_subclass__
on the CPython main branch.
"""
tvars = []
if '__orig_bases__' in cls.__dict__:
tvars = typing._collect_type_vars(cls.__orig_bases__)
# Look for Generic[T1, ..., Tn] or Protocol[T1, ..., Tn].
# If found, tvars must be a subset of it.
# If not found, tvars is it.
# Also check for and reject plain Generic,
# and reject multiple Generic[...] and/or Protocol[...].
gvars = None
for base in cls.__orig_bases__:
if (isinstance(base, typing._GenericAlias) and
base.__origin__ in (typing.Generic, Protocol)):
# for error messages
the_base = base.__origin__.__name__
if gvars is not None:
raise TypeError(
"Cannot inherit from Generic[...]"
" and/or Protocol[...] multiple types.")
gvars = base.__parameters__
if gvars is None:
gvars = tvars
else:
tvarset = set(tvars)
gvarset = set(gvars)
if not tvarset <= gvarset:
s_vars = ', '.join(str(t) for t in tvars if t not in gvarset)
s_args = ', '.join(str(g) for g in gvars)
raise TypeError(f"Some type variables ({s_vars}) are"
f" not listed in {the_base}[{s_args}]")
tvars = gvars
cls.__parameters__ = tuple(tvars)


# 3.8+
if hasattr(typing, 'Protocol'):
Protocol = typing.Protocol
Expand Down Expand Up @@ -477,43 +517,13 @@ def __class_getitem__(cls, params):
return typing._GenericAlias(cls, params)

def __init_subclass__(cls, *args, **kwargs):
tvars = []
if '__orig_bases__' in cls.__dict__:
error = typing.Generic in cls.__orig_bases__
else:
error = typing.Generic in cls.__bases__
if error:
raise TypeError("Cannot inherit from plain Generic")
if '__orig_bases__' in cls.__dict__:
tvars = typing._collect_type_vars(cls.__orig_bases__)
# Look for Generic[T1, ..., Tn] or Protocol[T1, ..., Tn].
# If found, tvars must be a subset of it.
# If not found, tvars is it.
# Also check for and reject plain Generic,
# and reject multiple Generic[...] and/or Protocol[...].
gvars = None
for base in cls.__orig_bases__:
if (isinstance(base, typing._GenericAlias) and
base.__origin__ in (typing.Generic, Protocol)):
# for error messages
the_base = base.__origin__.__name__
if gvars is not None:
raise TypeError(
"Cannot inherit from Generic[...]"
" and/or Protocol[...] multiple types.")
gvars = base.__parameters__
if gvars is None:
gvars = tvars
else:
tvarset = set(tvars)
gvarset = set(gvars)
if not tvarset <= gvarset:
s_vars = ', '.join(str(t) for t in tvars if t not in gvarset)
s_args = ', '.join(str(g) for g in gvars)
raise TypeError(f"Some type variables ({s_vars}) are"
f" not listed in {the_base}[{s_args}]")
tvars = gvars
cls.__parameters__ = tuple(tvars)
_maybe_adjust_parameters(cls)

# Determine if this is a protocol or a concrete subclass.
if not cls.__dict__.get('_is_protocol', None):
Expand Down Expand Up @@ -614,6 +624,7 @@ def __index__(self) -> int:
# keyword with old-style TypedDict(). See https://bugs.python.org/issue42059
# The standard library TypedDict below Python 3.11 does not store runtime
# information about optional and required keys when using Required or NotRequired.
# Generic TypedDicts are also impossible using typing.TypedDict on Python <3.11.
TypedDict = typing.TypedDict
_TypedDictMeta = typing._TypedDictMeta
is_typeddict = typing.is_typeddict
Expand Down Expand Up @@ -696,8 +707,16 @@ def __new__(cls, name, bases, ns, total=True):
# Subclasses and instances of TypedDict return actual dictionaries
# via _dict_new.
ns['__new__'] = _typeddict_new if name == 'TypedDict' else _dict_new
# Don't insert typing.Generic into __bases__ here,
# or Generic.__init_subclass__ will raise TypeError
# in the super().__new__() call.
# Instead, monkey-patch __bases__ onto the class after it's been created.
tp_dict = super().__new__(cls, name, (dict,), ns)

if any(issubclass(base, typing.Generic) for base in bases):
tp_dict.__bases__ = (typing.Generic, dict)
_maybe_adjust_parameters(tp_dict)

annotations = {}
own_annotations = ns.get('__annotations__', {})
msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type"
Expand Down

0 comments on commit 1baf0a5

Please sign in to comment.