From 7c28357e412cef215a7d66c0ef69b568b316678b Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Thu, 26 May 2022 04:24:20 +0100 Subject: [PATCH] Add a backport of generic `NamedTuple`s (#44) Co-authored-by: Jelle Zijlstra --- CHANGELOG.md | 6 + README.md | 3 + src/test_typing_extensions.py | 306 +++++++++++++++++++++++++++++++++- src/typing_extensions.py | 90 ++++++++++ 4 files changed, 403 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aa66e55c..b6721cd2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +# Unreleased + +- Add `typing_extensions.NamedTuple`, allowing for generic `NamedTuple`s on + Python <3.11 (backport from python/cpython#92027, by Serhiy Storchaka). Patch + by Alex Waygood (@AlexWaygood). + # Release 4.2.0 (April 17, 2022) - Re-export `typing.Unpack` and `typing.TypeVarTuple` on Python 3.11. diff --git a/README.md b/README.md index 55a23185..79112d1c 100644 --- a/README.md +++ b/README.md @@ -96,6 +96,7 @@ This module currently contains the following: - `Counter` - `DefaultDict` - `Deque` + - `NamedTuple` - `NewType` - `NoReturn` - `overload` @@ -121,6 +122,8 @@ Certain objects were changed after they were added to `typing`, and introspectable at runtime. In order to access overloads with `typing_extensions.get_overloads()`, you must use `@typing_extensions.overload`. +- `NamedTuple` was changed in Python 3.11 to allow for multiple inheritance + with `typing.Generic`. There are a few types whose interface was modified between different versions of typing. For example, `typing.Sequence` was modified to diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index 7f14f3f9..407a4860 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -5,6 +5,7 @@ import collections from collections import defaultdict import collections.abc +import copy from functools import lru_cache import inspect import pickle @@ -17,7 +18,7 @@ from typing import TypeVar, Optional, Union, Any, AnyStr from typing import T, KT, VT # Not in __all__. from typing import Tuple, List, Dict, Iterable, Iterator, Callable -from typing import Generic, NamedTuple +from typing import Generic from typing import no_type_check import typing_extensions from typing_extensions import NoReturn, ClassVar, Final, IntVar, Literal, Type, NewType, TypedDict, Self @@ -27,10 +28,12 @@ from typing_extensions import TypeVarTuple, Unpack, dataclass_transform, reveal_type, Never, assert_never, LiteralString from typing_extensions import assert_type, get_type_hints, get_origin, get_args from typing_extensions import clear_overloads, get_overloads, overload +from typing_extensions import NamedTuple # Flags used to mark tests that only apply after a specific # version of the typing module. TYPING_3_8_0 = sys.version_info[:3] >= (3, 8, 0) +TYPING_3_9_0 = sys.version_info[:3] >= (3, 9, 0) TYPING_3_10_0 = sys.version_info[:3] >= (3, 10, 0) # 3.11 makes runtime type checks (_type_check) more lenient. @@ -2874,7 +2877,7 @@ def test_typing_extensions_defers_when_possible(self): if sys.version_info < (3, 10): exclude |= {'get_args', 'get_origin'} if sys.version_info < (3, 11): - exclude.add('final') + exclude |= {'final', 'NamedTuple'} for item in typing_extensions.__all__: if item not in exclude and hasattr(typing, item): self.assertIs( @@ -2892,6 +2895,305 @@ def test_typing_extensions_compiles_with_opt(self): self.fail('Module does not compile with optimize=2 (-OO flag).') +class CoolEmployee(NamedTuple): + name: str + cool: int + + +class CoolEmployeeWithDefault(NamedTuple): + name: str + cool: int = 0 + + +class XMeth(NamedTuple): + x: int + + def double(self): + return 2 * self.x + + +class XRepr(NamedTuple): + x: int + y: int = 1 + + def __str__(self): + return f'{self.x} -> {self.y}' + + def __add__(self, other): + return 0 + + +@skipIf(TYPING_3_11_0, "These invariants should all be tested upstream on 3.11+") +class NamedTupleTests(BaseTestCase): + class NestedEmployee(NamedTuple): + name: str + cool: int + + def test_basics(self): + Emp = NamedTuple('Emp', [('name', str), ('id', int)]) + self.assertIsSubclass(Emp, tuple) + joe = Emp('Joe', 42) + jim = Emp(name='Jim', id=1) + self.assertIsInstance(joe, Emp) + self.assertIsInstance(joe, tuple) + self.assertEqual(joe.name, 'Joe') + self.assertEqual(joe.id, 42) + self.assertEqual(jim.name, 'Jim') + self.assertEqual(jim.id, 1) + self.assertEqual(Emp.__name__, 'Emp') + self.assertEqual(Emp._fields, ('name', 'id')) + self.assertEqual(Emp.__annotations__, + collections.OrderedDict([('name', str), ('id', int)])) + + def test_annotation_usage(self): + tim = CoolEmployee('Tim', 9000) + self.assertIsInstance(tim, CoolEmployee) + self.assertIsInstance(tim, tuple) + self.assertEqual(tim.name, 'Tim') + self.assertEqual(tim.cool, 9000) + self.assertEqual(CoolEmployee.__name__, 'CoolEmployee') + self.assertEqual(CoolEmployee._fields, ('name', 'cool')) + self.assertEqual(CoolEmployee.__annotations__, + collections.OrderedDict(name=str, cool=int)) + + def test_annotation_usage_with_default(self): + jelle = CoolEmployeeWithDefault('Jelle') + self.assertIsInstance(jelle, CoolEmployeeWithDefault) + self.assertIsInstance(jelle, tuple) + self.assertEqual(jelle.name, 'Jelle') + self.assertEqual(jelle.cool, 0) + cooler_employee = CoolEmployeeWithDefault('Sjoerd', 1) + self.assertEqual(cooler_employee.cool, 1) + + self.assertEqual(CoolEmployeeWithDefault.__name__, 'CoolEmployeeWithDefault') + self.assertEqual(CoolEmployeeWithDefault._fields, ('name', 'cool')) + self.assertEqual(CoolEmployeeWithDefault.__annotations__, + dict(name=str, cool=int)) + + with self.assertRaisesRegex( + TypeError, + 'Non-default namedtuple field y cannot follow default field x' + ): + class NonDefaultAfterDefault(NamedTuple): + x: int = 3 + y: int + + @skipUnless( + ( + TYPING_3_8_0 + or hasattr(CoolEmployeeWithDefault, '_field_defaults') + ), + '"_field_defaults" attribute was added in a micro version of 3.7' + ) + def test_field_defaults(self): + self.assertEqual(CoolEmployeeWithDefault._field_defaults, dict(cool=0)) + + def test_annotation_usage_with_methods(self): + self.assertEqual(XMeth(1).double(), 2) + self.assertEqual(XMeth(42).x, XMeth(42)[0]) + self.assertEqual(str(XRepr(42)), '42 -> 1') + self.assertEqual(XRepr(1, 2) + XRepr(3), 0) + + bad_overwrite_error_message = 'Cannot overwrite NamedTuple attribute' + + with self.assertRaisesRegex(AttributeError, bad_overwrite_error_message): + class XMethBad(NamedTuple): + x: int + def _fields(self): + return 'no chance for this' + + with self.assertRaisesRegex(AttributeError, bad_overwrite_error_message): + class XMethBad2(NamedTuple): + x: int + def _source(self): + return 'no chance for this as well' + + def test_multiple_inheritance(self): + class A: + pass + with self.assertRaisesRegex( + TypeError, + 'can only inherit from a NamedTuple type and Generic' + ): + class X(NamedTuple, A): + x: int + + with self.assertRaisesRegex( + TypeError, + 'can only inherit from a NamedTuple type and Generic' + ): + class X(NamedTuple, tuple): + x: int + + with self.assertRaisesRegex(TypeError, 'duplicate base class'): + class X(NamedTuple, NamedTuple): + x: int + + class A(NamedTuple): + x: int + with self.assertRaisesRegex( + TypeError, + 'can only inherit from a NamedTuple type and Generic' + ): + class X(NamedTuple, A): + y: str + + def test_generic(self): + class X(NamedTuple, Generic[T]): + x: T + self.assertEqual(X.__bases__, (tuple, Generic)) + self.assertEqual(X.__orig_bases__, (NamedTuple, Generic[T])) + self.assertEqual(X.__mro__, (X, tuple, Generic, object)) + + class Y(Generic[T], NamedTuple): + x: T + self.assertEqual(Y.__bases__, (Generic, tuple)) + self.assertEqual(Y.__orig_bases__, (Generic[T], NamedTuple)) + self.assertEqual(Y.__mro__, (Y, Generic, tuple, object)) + + for G in X, Y: + with self.subTest(type=G): + self.assertEqual(G.__parameters__, (T,)) + A = G[int] + self.assertIs(A.__origin__, G) + self.assertEqual(A.__args__, (int,)) + self.assertEqual(A.__parameters__, ()) + + a = A(3) + self.assertIs(type(a), G) + self.assertEqual(a.x, 3) + + with self.assertRaisesRegex(TypeError, 'Too many parameters'): + G[int, str] + + @skipUnless(TYPING_3_9_0, "tuple.__class_getitem__ was added in 3.9") + def test_non_generic_subscript_py39_plus(self): + # For backward compatibility, subscription works + # on arbitrary NamedTuple types. + class Group(NamedTuple): + key: T + group: list[T] + A = Group[int] + self.assertEqual(A.__origin__, Group) + self.assertEqual(A.__parameters__, ()) + self.assertEqual(A.__args__, (int,)) + a = A(1, [2]) + self.assertIs(type(a), Group) + self.assertEqual(a, (1, [2])) + + @skipIf(TYPING_3_9_0, "Test isn't relevant to 3.9+") + def test_non_generic_subscript_error_message_py38_minus(self): + class Group(NamedTuple): + key: T + group: List[T] + + with self.assertRaisesRegex(TypeError, 'not subscriptable'): + Group[int] + + for attr in ('__args__', '__origin__', '__parameters__'): + with self.subTest(attr=attr): + self.assertFalse(hasattr(Group, attr)) + + def test_namedtuple_keyword_usage(self): + LocalEmployee = NamedTuple("LocalEmployee", name=str, age=int) + nick = LocalEmployee('Nick', 25) + self.assertIsInstance(nick, tuple) + self.assertEqual(nick.name, 'Nick') + self.assertEqual(LocalEmployee.__name__, 'LocalEmployee') + self.assertEqual(LocalEmployee._fields, ('name', 'age')) + self.assertEqual(LocalEmployee.__annotations__, dict(name=str, age=int)) + with self.assertRaisesRegex( + TypeError, + 'Either list of fields or keywords can be provided to NamedTuple, not both' + ): + NamedTuple('Name', [('x', int)], y=str) + + def test_namedtuple_special_keyword_names(self): + NT = NamedTuple("NT", cls=type, self=object, typename=str, fields=list) + self.assertEqual(NT.__name__, 'NT') + self.assertEqual(NT._fields, ('cls', 'self', 'typename', 'fields')) + a = NT(cls=str, self=42, typename='foo', fields=[('bar', tuple)]) + self.assertEqual(a.cls, str) + self.assertEqual(a.self, 42) + self.assertEqual(a.typename, 'foo') + self.assertEqual(a.fields, [('bar', tuple)]) + + def test_empty_namedtuple(self): + NT = NamedTuple('NT') + + class CNT(NamedTuple): + pass # empty body + + for struct in [NT, CNT]: + with self.subTest(struct=struct): + self.assertEqual(struct._fields, ()) + self.assertEqual(struct.__annotations__, {}) + self.assertIsInstance(struct(), struct) + # Attribute was added in a micro version of 3.7 + # and is tested more fully elsewhere + if hasattr(struct, "_field_defaults"): + self.assertEqual(struct._field_defaults, {}) + + def test_namedtuple_errors(self): + with self.assertRaises(TypeError): + NamedTuple.__new__() + with self.assertRaises(TypeError): + NamedTuple() + with self.assertRaises(TypeError): + NamedTuple('Emp', [('name', str)], None) + with self.assertRaisesRegex(ValueError, 'cannot start with an underscore'): + NamedTuple('Emp', [('_name', str)]) + with self.assertRaises(TypeError): + NamedTuple(typename='Emp', name=str, id=int) + + def test_copy_and_pickle(self): + global Emp # pickle wants to reference the class by name + Emp = NamedTuple('Emp', [('name', str), ('cool', int)]) + for cls in Emp, CoolEmployee, self.NestedEmployee: + with self.subTest(cls=cls): + jane = cls('jane', 37) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + z = pickle.dumps(jane, proto) + jane2 = pickle.loads(z) + self.assertEqual(jane2, jane) + self.assertIsInstance(jane2, cls) + + jane2 = copy.copy(jane) + self.assertEqual(jane2, jane) + self.assertIsInstance(jane2, cls) + + jane2 = copy.deepcopy(jane) + self.assertEqual(jane2, jane) + self.assertIsInstance(jane2, cls) + + def test_docstring(self): + self.assertEqual(NamedTuple.__doc__, typing.NamedTuple.__doc__) + self.assertIsInstance(NamedTuple.__doc__, str) + + @skipUnless(TYPING_3_8_0, "NamedTuple had a bad signature on <=3.7") + def test_signature_is_same_as_typing_NamedTuple(self): + self.assertEqual(inspect.signature(NamedTuple), inspect.signature(typing.NamedTuple)) + + @skipIf(TYPING_3_8_0, "tests are only relevant to <=3.7") + def test_signature_on_37(self): + self.assertIsInstance(inspect.signature(NamedTuple), inspect.Signature) + self.assertFalse(hasattr(NamedTuple, "__text_signature__")) + + @skipUnless(TYPING_3_9_0, "NamedTuple was a class on 3.8 and lower") + def test_same_as_typing_NamedTuple_39_plus(self): + self.assertEqual( + set(dir(NamedTuple)), + set(dir(typing.NamedTuple)) | {"__text_signature__"} + ) + self.assertIs(type(NamedTuple), type(typing.NamedTuple)) + + @skipIf(TYPING_3_9_0, "tests are only relevant to <=3.8") + def test_same_as_typing_NamedTuple_38_minus(self): + self.assertEqual( + self.NestedEmployee.__annotations__, + self.NestedEmployee._field_types + ) + if __name__ == '__main__': main() diff --git a/src/typing_extensions.py b/src/typing_extensions.py index dc038819..3b9a39cf 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -37,6 +37,7 @@ 'Counter', 'Deque', 'DefaultDict', + 'NamedTuple', 'OrderedDict', 'TypedDict', @@ -1958,3 +1959,92 @@ def decorator(cls_or_fn): if not hasattr(typing, "TypeVarTuple"): typing._collect_type_vars = _collect_type_vars typing._check_generic = _check_generic + + +# Backport typing.NamedTuple as it exists in Python 3.11. +# In 3.11, the ability to define generic `NamedTuple`s was supported. +# This was explicitly disallowed in 3.9-3.10, and only half-worked in <=3.8. +if sys.version_info >= (3, 11): + NamedTuple = typing.NamedTuple +else: + def _caller(): + try: + return sys._getframe(2).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): # For platforms without _getframe() + return None + + def _make_nmtuple(name, types, module, defaults=()): + fields = [n for n, t in types] + annotations = {n: typing._type_check(t, f"field {n} annotation must be a type") + for n, t in types} + nm_tpl = collections.namedtuple(name, fields, + defaults=defaults, module=module) + nm_tpl.__annotations__ = nm_tpl.__new__.__annotations__ = annotations + # The `_field_types` attribute was removed in 3.9; + # in earlier versions, it is the same as the `__annotations__` attribute + if sys.version_info < (3, 9): + nm_tpl._field_types = annotations + return nm_tpl + + _prohibited_namedtuple_fields = typing._prohibited + _special_namedtuple_fields = frozenset({'__module__', '__name__', '__annotations__'}) + + class _NamedTupleMeta(type): + def __new__(cls, typename, bases, ns): + assert _NamedTuple in bases + for base in bases: + if base is not _NamedTuple and base is not typing.Generic: + raise TypeError( + 'can only inherit from a NamedTuple type and Generic') + bases = tuple(tuple if base is _NamedTuple else base for base in bases) + types = ns.get('__annotations__', {}) + default_names = [] + for field_name in types: + if field_name in ns: + default_names.append(field_name) + elif default_names: + raise TypeError(f"Non-default namedtuple field {field_name} " + f"cannot follow default field" + f"{'s' if len(default_names) > 1 else ''} " + f"{', '.join(default_names)}") + nm_tpl = _make_nmtuple( + typename, types.items(), + defaults=[ns[n] for n in default_names], + module=ns['__module__'] + ) + nm_tpl.__bases__ = bases + if typing.Generic in bases: + class_getitem = typing.Generic.__class_getitem__.__func__ + nm_tpl.__class_getitem__ = classmethod(class_getitem) + # update from user namespace without overriding special namedtuple attributes + for key in ns: + if key in _prohibited_namedtuple_fields: + raise AttributeError("Cannot overwrite NamedTuple attribute " + key) + elif key not in _special_namedtuple_fields and key not in nm_tpl._fields: + setattr(nm_tpl, key, ns[key]) + if typing.Generic in bases: + nm_tpl.__init_subclass__() + return nm_tpl + + def NamedTuple(__typename, __fields=None, **kwargs): + if __fields is None: + __fields = kwargs.items() + elif kwargs: + raise TypeError("Either list of fields or keywords" + " can be provided to NamedTuple, not both") + return _make_nmtuple(__typename, __fields, module=_caller()) + + NamedTuple.__doc__ = typing.NamedTuple.__doc__ + _NamedTuple = type.__new__(_NamedTupleMeta, 'NamedTuple', (), {}) + + # On 3.8+, alter the signature so that it matches typing.NamedTuple. + # The signature of typing.NamedTuple on >=3.8 is invalid syntax in Python 3.7, + # so just leave the signature as it is on 3.7. + if sys.version_info >= (3, 8): + NamedTuple.__text_signature__ = '(typename, fields=None, /, **kwargs)' + + def _namedtuple_mro_entries(bases): + assert NamedTuple in bases + return (_NamedTuple,) + + NamedTuple.__mro_entries__ = _namedtuple_mro_entries