From 2ee377d1d396cc421cc7299a192d345d020de666 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sat, 5 Mar 2022 21:24:34 -0800 Subject: [PATCH 01/30] initial --- Lib/functools.py | 32 ++++++++++++++++++++++++++++++++ Lib/typing.py | 3 +++ 2 files changed, 35 insertions(+) diff --git a/Lib/functools.py b/Lib/functools.py index cd5666dfa71fd0..451a7443803cb3 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -653,6 +653,38 @@ def cache(user_function, /): return lru_cache(maxsize=None)(user_function) +################################################################################ +### Function variant registry +################################################################################ + +# {key: (max lineno, [variant])} +_variant_registry = {} + + +def register_variant(key, variant, *, test) -> None: + """Register a function variant.""" + _variant_registry.setdefault(key, []).append(variant) + + +def get_variants(key): + """Get all function variants for the given key.""" + return _variant_registry.get(key, []) + + +def get_key_for_callable(func): + """Return a key for the given callable. + + This key can be used to register the callable in the variant registry + with register_variant() or to get variants for this callable with get_variants(). + + If no key can be created, return None. + """ + try: + return f"{func.__module__}.{func.__qualname__}" + except AttributeError: + return None + + ################################################################################ ### singledispatch() - single-dispatch generic function decorator ################################################################################ diff --git a/Lib/typing.py b/Lib/typing.py index 6e0c68c842420b..a138b2f3be9b19 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -2184,6 +2184,9 @@ def utf8(value: str) -> bytes: ... def utf8(value): # implementation goes here """ + key = functools.get_key_for_callable(func) + if key is not None: + functools.register_variant(key, func) return _overload_dummy From 831b5650936ea83d3531e6f92bc9719161e8e875 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 6 Mar 2022 18:16:03 -0800 Subject: [PATCH 02/30] Implementation, tests, and docs --- Doc/library/functools.rst | 42 ++++++++++++ Doc/library/typing.rst | 3 + Lib/functools.py | 16 ++++- Lib/test/test_functools.py | 64 +++++++++++++++++++ Lib/test/test_typing.py | 40 +++++++++++- Lib/typing.py | 25 ++++++++ .../2022-03-06-18-15-32.bpo-45100.B_lHu0.rst | 2 + 7 files changed, 188 insertions(+), 4 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2022-03-06-18-15-32.bpo-45100.B_lHu0.rst diff --git a/Doc/library/functools.rst b/Doc/library/functools.rst index c78818bfab1a51..9b05becf2fe4d7 100644 --- a/Doc/library/functools.rst +++ b/Doc/library/functools.rst @@ -535,6 +535,8 @@ The :mod:`functools` module defines the following functions: .. versionchanged:: 3.7 The :func:`register` attribute now supports using type annotations. + .. versionchanged:: 3.11 + Implementation functions are now registered using :func:`register_variant`. .. class:: singledispatchmethod(func) @@ -587,6 +589,9 @@ The :mod:`functools` module defines the following functions: .. versionadded:: 3.8 + .. versionchanged:: 3.11 + Implementation functions are now registered using :func:`register_variant`. + .. function:: update_wrapper(wrapper, wrapped, assigned=WRAPPER_ASSIGNMENTS, updated=WRAPPER_UPDATES) @@ -664,6 +669,43 @@ The :mod:`functools` module defines the following functions: would have been ``'wrapper'``, and the docstring of the original :func:`example` would have been lost. +.. function:: get_variants(key) + + Return all registered function variants for this key. Function variants are + objects that represent some subset of the functionality of a function, for + example overloads decorated with :func:`typing.overload` or :func:`singledispatch` + implementation functions. + + Variants are registered by calling :func:`register_variant`. + The *key* argument is a string that uniquely identifies the function and its + variants. It should be the result of a call to :func:`get_key_for_callable`. + + .. versionadded: 3.11 + +.. function:: register_variant(key, variant) + + Register a function variant that can later be retrieved using + :func:`get_variants`. The key should be the result of a call to + :func:`get_key_for_callable`. + + .. versionadded: 3.11 + +.. function:: clear_variants(key=None) + + Clear all registered variants with the given *key*. If *key* is None, clear + all variants. + + .. versionadded: 3.11 + +.. function:: get_key_for_callable(func) + + Return a string key that can be used with :func:`get_variants` and + :func:`register_variant`. *func* must be a :class:`function`, + :class:`classmethod`, :class:`staticmethod`, or similar callable. + If no key can be computed, the function returns None. + + .. versionadded: 3.11 + .. _partial-objects: diff --git a/Doc/library/typing.rst b/Doc/library/typing.rst index bfcbeb8c7e6808..0dc196054d1905 100644 --- a/Doc/library/typing.rst +++ b/Doc/library/typing.rst @@ -2106,6 +2106,9 @@ Functions and decorators See :pep:`484` for details and comparison with other typing semantics. + .. versionchanged:: 3.11 + Overloaded functions are now registered using :func:`functools.register_variant`. + .. decorator:: final A decorator to indicate to type checkers that the decorated method diff --git a/Lib/functools.py b/Lib/functools.py index 451a7443803cb3..a8f8754bd8227c 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -661,7 +661,7 @@ def cache(user_function, /): _variant_registry = {} -def register_variant(key, variant, *, test) -> None: +def register_variant(key, variant): """Register a function variant.""" _variant_registry.setdefault(key, []).append(variant) @@ -671,6 +671,14 @@ def get_variants(key): return _variant_registry.get(key, []) +def clear_variants(key=None): + """Clear all variants for the given key (or all keys).""" + if key is None: + _variant_registry.clear() + else: + _variant_registry.pop(key, None) + + def get_key_for_callable(func): """Return a key for the given callable. @@ -679,6 +687,8 @@ def get_key_for_callable(func): If no key can be created, return None. """ + # classmethod and staticmethod + func = getattr(func, "__func__", func) try: return f"{func.__module__}.{func.__qualname__}" except AttributeError: @@ -922,6 +932,10 @@ def register(cls, func=None): f"{cls!r} is not a class." ) + key = get_key_for_callable(func) + if key is not None: + register_variant(key, func) + if _is_union_type(cls): from typing import get_args diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index abbd50a47f395f..7c7ef74ed5f065 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -1890,6 +1890,70 @@ def cached_staticmeth(x, y): return 3 * x + y +class MethodHolder: + @classmethod + def clsmethod(cls): ... + @staticmethod + def stmethod(): ... + def method(self): ... + + +class TestVariantRegistry(unittest.TestCase): + def test_get_key_for_callable(self): + self.assertEqual(functools.get_key_for_callable(len), + "builtins.len") + self.assertEqual(functools.get_key_for_callable(py_cached_func), + f"{__name__}.py_cached_func") + self.assertEqual(functools.get_key_for_callable(MethodHolder.clsmethod), + f"{__name__}.MethodHolder.clsmethod") + self.assertEqual(functools.get_key_for_callable(MethodHolder.stmethod), + f"{__name__}.MethodHolder.stmethod") + self.assertEqual(functools.get_key_for_callable(MethodHolder.method), + f"{__name__}.MethodHolder.method") + + def test_get_variants(self): + key1 = "key1" + key2 = "key2" + obj1 = object() + obj2 = object() + self.assertEqual(functools.get_variants(key1), []) + self.assertEqual(functools.get_variants(key2), []) + + functools.register_variant(key1, obj1) + self.assertEqual(functools.get_variants(key1), [obj1]) + self.assertEqual(functools.get_variants(key2), []) + + functools.register_variant(key1, obj2) + self.assertEqual(functools.get_variants(key1), [obj1, obj2]) + self.assertEqual(functools.get_variants(key2), []) + + def test_clear_variants(self): + key1 = "key1" + key2 = "key2" + obj1 = object() + + functools.register_variant(key1, obj1) + self.assertEqual(functools.get_variants(key1), [obj1]) + self.assertEqual(functools.get_variants(key2), []) + + functools.clear_variants(key2) + self.assertEqual(functools.get_variants(key1), [obj1]) + self.assertEqual(functools.get_variants(key2), []) + + functools.clear_variants(key1) + self.assertEqual(functools.get_variants(key1), []) + self.assertEqual(functools.get_variants(key2), []) + + functools.register_variant(key1, obj1) + functools.register_variant(key2, obj1) + self.assertEqual(functools.get_variants(key1), [obj1]) + self.assertEqual(functools.get_variants(key2), [obj1]) + + functools.clear_variants() + self.assertEqual(functools.get_variants(key1), []) + self.assertEqual(functools.get_variants(key2), []) + + class TestSingleDispatch(unittest.TestCase): def test_simple_overloads(self): @functools.singledispatch diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index bd9920436223ce..f10a9740d11fb3 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -1,6 +1,7 @@ import contextlib import collections from functools import lru_cache +import functools import inspect import pickle import re @@ -9,7 +10,7 @@ from unittest import TestCase, main, skipUnless, skip from copy import copy, deepcopy -from typing import Any, NoReturn, Never, assert_never +from typing import Any, NoReturn, Never, assert_never, overload from typing import TypeVar, AnyStr from typing import T, KT, VT # Not in __all__. from typing import Union, Optional, Literal @@ -2930,8 +2931,8 @@ def fun(x: a): def test_forward_repr(self): self.assertEqual(repr(List['int']), "typing.List[ForwardRef('int')]") - self.assertEqual(repr(List[ForwardRef('int', module='mod')]), - "typing.List[ForwardRef('int', module='mod')]") + self.assertEqual(repr(List[ForwardRef('int', module='mod')]), + "typing.List[ForwardRef('int', module='mod')]") def test_union_forward(self): @@ -3303,6 +3304,39 @@ def blah(): blah() + def test_get_variants(self): + def blah(): + pass + + overload1 = blah + overload(blah) + + def blah(): + pass + + overload2 = blah + overload(blah) + + def blah(): + pass + + key = functools.get_key_for_callable(blah) + self.assertEqual(functools.get_variants(key), [overload1, overload2]) + + def test_get_variants_repeated(self): + for _ in range(2): + def blah(): + pass + + overload_func = blah + overload(blah) + + def blah(): + pass + + key = functools.get_key_for_callable(blah) + self.assertEqual(functools.get_variants(key), [overload_func]) + # Definitions needed for features introduced in Python 3.6 diff --git a/Lib/typing.py b/Lib/typing.py index a138b2f3be9b19..dafa31a37f1466 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -2183,13 +2183,38 @@ def utf8(value: bytes) -> bytes: ... def utf8(value: str) -> bytes: ... def utf8(value): # implementation goes here + + Each overload is registered with functools.register_variant and can be + retrieved using functools.get_variants. """ key = functools.get_key_for_callable(func) if key is not None: + + # If we are registering a variant with a lineno below or equal to that of the + # most recent existing variant, we're probably re-creating overloads for a + # function that already exists. In that case, we clear the existing variants + # to avoid leaking memory. + firstlineno = _get_firstlineno(func) + if firstlineno is not None: + existing = functools.get_variants(key) + if existing: + existing_lineno = _get_firstlineno(existing[-1]) + if existing_lineno is not None and firstlineno <= existing_lineno: + functools.clear_variants(key) + functools.register_variant(key, func) return _overload_dummy +def _get_firstlineno(func): + # staticmethod, classmethod + if hasattr(func, "__func__"): + func = func.__func__ + if not hasattr(func, '__code__'): + return None + return func.__code__.co_firstlineno + + def final(f): """A decorator to indicate final methods and final classes. diff --git a/Misc/NEWS.d/next/Library/2022-03-06-18-15-32.bpo-45100.B_lHu0.rst b/Misc/NEWS.d/next/Library/2022-03-06-18-15-32.bpo-45100.B_lHu0.rst new file mode 100644 index 00000000000000..5bf0d7fdbb1927 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2022-03-06-18-15-32.bpo-45100.B_lHu0.rst @@ -0,0 +1,2 @@ +Add a mechanism to register function variants, such as overloads and +singledispatch implementation functions. Patch by Jelle Zijlstra. From f03f8a91599901999e1e6c98b36b0fd77c0a3a34 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 6 Mar 2022 18:18:38 -0800 Subject: [PATCH 03/30] fix versionadded --- Doc/library/functools.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Doc/library/functools.rst b/Doc/library/functools.rst index 9b05becf2fe4d7..41338bef5a1e89 100644 --- a/Doc/library/functools.rst +++ b/Doc/library/functools.rst @@ -680,7 +680,7 @@ The :mod:`functools` module defines the following functions: The *key* argument is a string that uniquely identifies the function and its variants. It should be the result of a call to :func:`get_key_for_callable`. - .. versionadded: 3.11 + .. versionadded:: 3.11 .. function:: register_variant(key, variant) @@ -688,14 +688,14 @@ The :mod:`functools` module defines the following functions: :func:`get_variants`. The key should be the result of a call to :func:`get_key_for_callable`. - .. versionadded: 3.11 + .. versionadded:: 3.11 .. function:: clear_variants(key=None) Clear all registered variants with the given *key*. If *key* is None, clear all variants. - .. versionadded: 3.11 + .. versionadded:: 3.11 .. function:: get_key_for_callable(func) @@ -704,7 +704,7 @@ The :mod:`functools` module defines the following functions: :class:`classmethod`, :class:`staticmethod`, or similar callable. If no key can be computed, the function returns None. - .. versionadded: 3.11 + .. versionadded:: 3.11 .. _partial-objects: From 7a5b0d1ce6167bf6585cb68b91fa8b001bb133b8 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Mon, 7 Mar 2022 20:41:21 -0800 Subject: [PATCH 04/30] make get_key_for_callable private --- Lib/functools.py | 30 +++++++++++------- Lib/test/test_functools.py | 64 +++++++++++++++++++------------------- Lib/test/test_typing.py | 6 ++-- Lib/typing.py | 28 +++++++++-------- 4 files changed, 67 insertions(+), 61 deletions(-) diff --git a/Lib/functools.py b/Lib/functools.py index a8f8754bd8227c..f244fe041a9392 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -657,35 +657,39 @@ def cache(user_function, /): ### Function variant registry ################################################################################ -# {key: (max lineno, [variant])} +# {key: [variant]} _variant_registry = {} -def register_variant(key, variant): +def register_variant(func, variant): """Register a function variant.""" + key = _get_key_for_callable(func) _variant_registry.setdefault(key, []).append(variant) -def get_variants(key): - """Get all function variants for the given key.""" +def get_variants(func): + """Get all function variants for the given function.""" + key = _get_key_for_callable(func) return _variant_registry.get(key, []) -def clear_variants(key=None): - """Clear all variants for the given key (or all keys).""" - if key is None: +def clear_variants(func=None): + """Clear all variants for the given function (or all functions).""" + if func is None: _variant_registry.clear() else: + key = _get_key_for_callable(func) _variant_registry.pop(key, None) -def get_key_for_callable(func): +def _get_key_for_callable(func): """Return a key for the given callable. This key can be used to register the callable in the variant registry with register_variant() or to get variants for this callable with get_variants(). - If no key can be created, return None. + If no key can be created (because the object is not of a supported type), raise + AttributeError. """ # classmethod and staticmethod func = getattr(func, "__func__", func) @@ -856,6 +860,7 @@ def singledispatch(func): registry = {} dispatch_cache = weakref.WeakKeyDictionary() cache_token = None + outer_func = func def dispatch(cls): """generic_func.dispatch(cls) -> @@ -932,9 +937,10 @@ def register(cls, func=None): f"{cls!r} is not a class." ) - key = get_key_for_callable(func) - if key is not None: - register_variant(key, func) + try: + register_variant(outer_func, func) + except AttributeError: + pass if _is_union_type(cls): from typing import get_args diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 7c7ef74ed5f065..d4598df358ffb4 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -1900,58 +1900,58 @@ def method(self): ... class TestVariantRegistry(unittest.TestCase): def test_get_key_for_callable(self): - self.assertEqual(functools.get_key_for_callable(len), + self.assertEqual(functools._get_key_for_callable(len), "builtins.len") - self.assertEqual(functools.get_key_for_callable(py_cached_func), + self.assertEqual(functools._get_key_for_callable(py_cached_func), f"{__name__}.py_cached_func") - self.assertEqual(functools.get_key_for_callable(MethodHolder.clsmethod), + self.assertEqual(functools._get_key_for_callable(MethodHolder.clsmethod), f"{__name__}.MethodHolder.clsmethod") - self.assertEqual(functools.get_key_for_callable(MethodHolder.stmethod), + self.assertEqual(functools._get_key_for_callable(MethodHolder.stmethod), f"{__name__}.MethodHolder.stmethod") - self.assertEqual(functools.get_key_for_callable(MethodHolder.method), + self.assertEqual(functools._get_key_for_callable(MethodHolder.method), f"{__name__}.MethodHolder.method") def test_get_variants(self): - key1 = "key1" - key2 = "key2" + def func1(): pass + def func2(): pass obj1 = object() obj2 = object() - self.assertEqual(functools.get_variants(key1), []) - self.assertEqual(functools.get_variants(key2), []) + self.assertEqual(functools.get_variants(func1), []) + self.assertEqual(functools.get_variants(func2), []) - functools.register_variant(key1, obj1) - self.assertEqual(functools.get_variants(key1), [obj1]) - self.assertEqual(functools.get_variants(key2), []) + functools.register_variant(func1, obj1) + self.assertEqual(functools.get_variants(func1), [obj1]) + self.assertEqual(functools.get_variants(func2), []) - functools.register_variant(key1, obj2) - self.assertEqual(functools.get_variants(key1), [obj1, obj2]) - self.assertEqual(functools.get_variants(key2), []) + functools.register_variant(func1, obj2) + self.assertEqual(functools.get_variants(func1), [obj1, obj2]) + self.assertEqual(functools.get_variants(func2), []) def test_clear_variants(self): - key1 = "key1" - key2 = "key2" + def func1(): pass + def func2(): pass obj1 = object() - functools.register_variant(key1, obj1) - self.assertEqual(functools.get_variants(key1), [obj1]) - self.assertEqual(functools.get_variants(key2), []) + functools.register_variant(func1, obj1) + self.assertEqual(functools.get_variants(func1), [obj1]) + self.assertEqual(functools.get_variants(func2), []) - functools.clear_variants(key2) - self.assertEqual(functools.get_variants(key1), [obj1]) - self.assertEqual(functools.get_variants(key2), []) + functools.clear_variants(func2) + self.assertEqual(functools.get_variants(func1), [obj1]) + self.assertEqual(functools.get_variants(func2), []) - functools.clear_variants(key1) - self.assertEqual(functools.get_variants(key1), []) - self.assertEqual(functools.get_variants(key2), []) + functools.clear_variants(func1) + self.assertEqual(functools.get_variants(func1), []) + self.assertEqual(functools.get_variants(func2), []) - functools.register_variant(key1, obj1) - functools.register_variant(key2, obj1) - self.assertEqual(functools.get_variants(key1), [obj1]) - self.assertEqual(functools.get_variants(key2), [obj1]) + functools.register_variant(func1, obj1) + functools.register_variant(func2, obj1) + self.assertEqual(functools.get_variants(func1), [obj1]) + self.assertEqual(functools.get_variants(func2), [obj1]) functools.clear_variants() - self.assertEqual(functools.get_variants(key1), []) - self.assertEqual(functools.get_variants(key2), []) + self.assertEqual(functools.get_variants(func1), []) + self.assertEqual(functools.get_variants(func2), []) class TestSingleDispatch(unittest.TestCase): diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 0dd6fe013d64fc..c600c10ebdebf9 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -3790,8 +3790,7 @@ def blah(): def blah(): pass - key = functools.get_key_for_callable(blah) - self.assertEqual(functools.get_variants(key), [overload1, overload2]) + self.assertEqual(functools.get_variants(blah), [overload1, overload2]) def test_get_variants_repeated(self): for _ in range(2): @@ -3804,8 +3803,7 @@ def blah(): def blah(): pass - key = functools.get_key_for_callable(blah) - self.assertEqual(functools.get_variants(key), [overload_func]) + self.assertEqual(functools.get_variants(blah), [overload_func]) # Definitions needed for features introduced in Python 3.6 diff --git a/Lib/typing.py b/Lib/typing.py index 1afa9482ecb1be..2e8741eb92ca62 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -2386,22 +2386,24 @@ def utf8(value): Each overload is registered with functools.register_variant and can be retrieved using functools.get_variants. """ - key = functools.get_key_for_callable(func) - if key is not None: - - # If we are registering a variant with a lineno below or equal to that of the - # most recent existing variant, we're probably re-creating overloads for a - # function that already exists. In that case, we clear the existing variants - # to avoid leaking memory. - firstlineno = _get_firstlineno(func) - if firstlineno is not None: - existing = functools.get_variants(key) - if existing: + try: + existing = functools.get_variants(func) + except AttributeError: + # Not a normal function; ignore. + pass + else: + if existing: + # If we are registering a variant with a lineno below or equal to that of the + # most recent existing variant, we're probably re-creating overloads for a + # function that already exists. In that case, we clear the existing variants + # to avoid leaking memory. + firstlineno = _get_firstlineno(func) + if firstlineno is not None: existing_lineno = _get_firstlineno(existing[-1]) if existing_lineno is not None and firstlineno <= existing_lineno: - functools.clear_variants(key) + functools.clear_variants(func) - functools.register_variant(key, func) + functools.register_variant(func, func) return _overload_dummy From 6998255e7de8856a98fce81d160d70c5b3d9809c Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Tue, 8 Mar 2022 20:54:52 -0800 Subject: [PATCH 05/30] doc updates; remove unnecessary try-except --- Doc/library/functools.rst | 26 +++++++------------------- Lib/functools.py | 5 +---- 2 files changed, 8 insertions(+), 23 deletions(-) diff --git a/Doc/library/functools.rst b/Doc/library/functools.rst index 41338bef5a1e89..8e78f4bee11680 100644 --- a/Doc/library/functools.rst +++ b/Doc/library/functools.rst @@ -669,43 +669,31 @@ The :mod:`functools` module defines the following functions: would have been ``'wrapper'``, and the docstring of the original :func:`example` would have been lost. -.. function:: get_variants(key) +.. function:: get_variants(func) - Return all registered function variants for this key. Function variants are + Return all registered function variants for this function. Function variants are objects that represent some subset of the functionality of a function, for example overloads decorated with :func:`typing.overload` or :func:`singledispatch` implementation functions. Variants are registered by calling :func:`register_variant`. - The *key* argument is a string that uniquely identifies the function and its - variants. It should be the result of a call to :func:`get_key_for_callable`. .. versionadded:: 3.11 -.. function:: register_variant(key, variant) +.. function:: register_variant(func, variant) - Register a function variant that can later be retrieved using - :func:`get_variants`. The key should be the result of a call to - :func:`get_key_for_callable`. + Register *variant* for function *func* that can later be retrieved using + :func:`get_variants`. .. versionadded:: 3.11 -.. function:: clear_variants(key=None) +.. function:: clear_variants(func=None) - Clear all registered variants with the given *key*. If *key* is None, clear + Clear all registered variants for the given *func*. If *func* is None, clear all variants. .. versionadded:: 3.11 -.. function:: get_key_for_callable(func) - - Return a string key that can be used with :func:`get_variants` and - :func:`register_variant`. *func* must be a :class:`function`, - :class:`classmethod`, :class:`staticmethod`, or similar callable. - If no key can be computed, the function returns None. - - .. versionadded:: 3.11 - .. _partial-objects: diff --git a/Lib/functools.py b/Lib/functools.py index f244fe041a9392..980b541b699e64 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -693,10 +693,7 @@ def _get_key_for_callable(func): """ # classmethod and staticmethod func = getattr(func, "__func__", func) - try: - return f"{func.__module__}.{func.__qualname__}" - except AttributeError: - return None + return f"{func.__module__}.{func.__qualname__}" ################################################################################ From f52b75732829d85695a795ba3b85ad5db678a608 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 27 Mar 2022 16:03:29 -0700 Subject: [PATCH 06/30] rename method --- Lib/test/test_typing.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 0642e3ed0a9957..484a781216e4b9 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -3820,7 +3820,9 @@ def blah(): blah() - def test_get_variants(self): + def test_variant_registry(self): + # Test the interaction with the variants registry in + # the functools module. def blah(): pass @@ -3838,7 +3840,7 @@ def blah(): self.assertEqual(functools.get_variants(blah), [overload1, overload2]) - def test_get_variants_repeated(self): + def test_variant_registry_repeated(self): for _ in range(2): def blah(): pass From fc6a92579b30aa21e586f0423630b7c233f6804c Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 27 Mar 2022 16:15:14 -0700 Subject: [PATCH 07/30] Don't store singledispatch in the registry --- Doc/library/functools.rst | 2 +- Lib/functools.py | 26 ++++++++++++++++++-------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/Doc/library/functools.rst b/Doc/library/functools.rst index 8e78f4bee11680..b5345fd1617fa8 100644 --- a/Doc/library/functools.rst +++ b/Doc/library/functools.rst @@ -536,7 +536,7 @@ The :mod:`functools` module defines the following functions: The :func:`register` attribute now supports using type annotations. .. versionchanged:: 3.11 - Implementation functions are now registered using :func:`register_variant`. + Implementation functions can now be retrieved using :func:`get_variants`. .. class:: singledispatchmethod(func) diff --git a/Lib/functools.py b/Lib/functools.py index 980b541b699e64..f2fb47df9c25ce 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -16,7 +16,8 @@ from abc import get_cache_token from collections import namedtuple -# import types, weakref # Deferred to single_dispatch() +import types +# import weakref # Deferred to single_dispatch() from reprlib import recursive_repr from _thread import RLock from types import GenericAlias @@ -670,7 +671,21 @@ def register_variant(func, variant): def get_variants(func): """Get all function variants for the given function.""" key = _get_key_for_callable(func) - return _variant_registry.get(key, []) + variants = list(_variant_registry.get(key, [])) + + # We directly retrieve variants from the singledispatch + # and singledispatchmethod registries. + if isinstance(func, singledispatchmethod): + variants += func.dispatcher.registry.values() + else: + try: + registry = func.registry + except AttributeError: + pass + else: + if isinstance(registry, types.MappingProxyType): + variants += registry.values() + return variants def clear_variants(func=None): @@ -852,7 +867,7 @@ def singledispatch(func): # There are many programs that use functools without singledispatch, so we # trade-off making singledispatch marginally slower for the benefit of # making start-up of such applications slightly faster. - import types, weakref + import weakref registry = {} dispatch_cache = weakref.WeakKeyDictionary() @@ -934,11 +949,6 @@ def register(cls, func=None): f"{cls!r} is not a class." ) - try: - register_variant(outer_func, func) - except AttributeError: - pass - if _is_union_type(cls): from typing import get_args From b524244098f56fac7f40b870566ea1bf65121265 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 27 Mar 2022 16:25:35 -0700 Subject: [PATCH 08/30] more tests --- Lib/test/test_functools.py | 46 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index d4598df358ffb4..e8d64f5bd38bda 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -1953,6 +1953,52 @@ def func2(): pass self.assertEqual(functools.get_variants(func1), []) self.assertEqual(functools.get_variants(func2), []) + def test_singledispatch_interaction(self): + @functools.singledispatch + def func(obj): + return "base" + original_func = func.registry[object] + self.assertEqual(functools.get_variants(func), [original_func]) + @func.register(int) + def func_int(obj): + return "int" + self.assertEqual(functools.get_variants(func), [original_func, func_int]) + + def weird_func(): pass + weird_func.registry = 42 + # shouldn't crash if the registry attribute exists but is not + # a mapping proxy + self.assertEqual(functools.get_variants(weird_func), []) + + def test_both_singledispatch_and_overload(self): + from typing import overload + def complex_func(arg: str) -> int: ... + str_overload = complex_func + overload(complex_func) + def complex_func(arg: int) -> str: ... + int_overload = complex_func + overload(complex_func) + @functools.singledispatch + def complex_func(arg: object): + raise NotImplementedError + @complex_func.register + def str_variant(arg: str) -> int: + return int(arg) + @complex_func.register + def int_variant(arg: int) -> str: + return str(arg) + + self.assertEqual( + functools.get_variants(complex_func), + [ + str_overload, + int_overload, + complex_func.registry[object], + str_variant, + int_variant, + ] + ) + class TestSingleDispatch(unittest.TestCase): def test_simple_overloads(self): From e95558e4917f7f1f638106490c808e235f0aca5d Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 27 Mar 2022 16:31:46 -0700 Subject: [PATCH 09/30] and another --- Lib/functools.py | 1 + Lib/test/test_functools.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/Lib/functools.py b/Lib/functools.py index f2fb47df9c25ce..0719dfa394d038 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -1007,6 +1007,7 @@ def _method(*args, **kwargs): _method.__isabstractmethod__ = self.__isabstractmethod__ _method.register = self.register + _method.registry = self.dispatcher.registry update_wrapper(_method, self.func) return _method diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index e8d64f5bd38bda..952d425f577f29 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -1970,6 +1970,26 @@ def weird_func(): pass # a mapping proxy self.assertEqual(functools.get_variants(weird_func), []) + def test_singledispatchmethod_interaction(self): + class A: + @functools.singledispatchmethod + def t(self, arg): + self.arg = "base" + @t.register(int) + def int_t(self, arg): + self.arg = "int" + @t.register(str) + def str_t(self, arg): + self.arg = "str" + expected = [ + A.t.registry[object], + A.int_t, + A.str_t, + ] + self.assertEqual(functools.get_variants(A.t), expected) + method_object = A.__dict__["t"] # bypass the descriptor + self.assertEqual(functools.get_variants(method_object), expected) + def test_both_singledispatch_and_overload(self): from typing import overload def complex_func(arg: str) -> int: ... From 31fd72d66b8c60af24b267d829cfd0e2be868137 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 27 Mar 2022 16:35:00 -0700 Subject: [PATCH 10/30] fix line length in new tests --- Lib/test/test_functools.py | 75 ++++++++++++++++++++++++++++---------- 1 file changed, 56 insertions(+), 19 deletions(-) diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 952d425f577f29..823e806ea20cf0 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -1900,20 +1900,34 @@ def method(self): ... class TestVariantRegistry(unittest.TestCase): def test_get_key_for_callable(self): - self.assertEqual(functools._get_key_for_callable(len), - "builtins.len") - self.assertEqual(functools._get_key_for_callable(py_cached_func), - f"{__name__}.py_cached_func") - self.assertEqual(functools._get_key_for_callable(MethodHolder.clsmethod), - f"{__name__}.MethodHolder.clsmethod") - self.assertEqual(functools._get_key_for_callable(MethodHolder.stmethod), - f"{__name__}.MethodHolder.stmethod") - self.assertEqual(functools._get_key_for_callable(MethodHolder.method), - f"{__name__}.MethodHolder.method") + self.assertEqual( + functools._get_key_for_callable(len), + "builtins.len", + ) + self.assertEqual( + functools._get_key_for_callable(py_cached_func), + f"{__name__}.py_cached_func", + ) + self.assertEqual( + functools._get_key_for_callable(MethodHolder.clsmethod), + f"{__name__}.MethodHolder.clsmethod", + ) + self.assertEqual( + functools._get_key_for_callable(MethodHolder.stmethod), + f"{__name__}.MethodHolder.stmethod", + ) + self.assertEqual( + functools._get_key_for_callable(MethodHolder.method), + f"{__name__}.MethodHolder.method", + ) def test_get_variants(self): - def func1(): pass - def func2(): pass + def func1(): + pass + + def func2(): + pass + obj1 = object() obj2 = object() self.assertEqual(functools.get_variants(func1), []) @@ -1928,8 +1942,12 @@ def func2(): pass self.assertEqual(functools.get_variants(func2), []) def test_clear_variants(self): - def func1(): pass - def func2(): pass + def func1(): + pass + + def func2(): + pass + obj1 = object() functools.register_variant(func1, obj1) @@ -1957,14 +1975,21 @@ def test_singledispatch_interaction(self): @functools.singledispatch def func(obj): return "base" + original_func = func.registry[object] self.assertEqual(functools.get_variants(func), [original_func]) + @func.register(int) def func_int(obj): return "int" - self.assertEqual(functools.get_variants(func), [original_func, func_int]) - def weird_func(): pass + self.assertEqual( + functools.get_variants(func), [original_func, func_int] + ) + + def weird_func(): + pass + weird_func.registry = 42 # shouldn't crash if the registry attribute exists but is not # a mapping proxy @@ -1975,12 +2000,15 @@ class A: @functools.singledispatchmethod def t(self, arg): self.arg = "base" + @t.register(int) def int_t(self, arg): self.arg = "int" + @t.register(str) def str_t(self, arg): self.arg = "str" + expected = [ A.t.registry[object], A.int_t, @@ -1992,18 +2020,27 @@ def str_t(self, arg): def test_both_singledispatch_and_overload(self): from typing import overload - def complex_func(arg: str) -> int: ... + + def complex_func(arg: str) -> int: + ... + str_overload = complex_func overload(complex_func) - def complex_func(arg: int) -> str: ... + + def complex_func(arg: int) -> str: + ... + int_overload = complex_func overload(complex_func) + @functools.singledispatch def complex_func(arg: object): raise NotImplementedError + @complex_func.register def str_variant(arg: str) -> int: return int(arg) + @complex_func.register def int_variant(arg: int) -> str: return str(arg) @@ -2016,7 +2053,7 @@ def int_variant(arg: int) -> str: complex_func.registry[object], str_variant, int_variant, - ] + ], ) From 7041ad393187912e897f5c2856d1ba2103aed83a Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 27 Mar 2022 16:38:47 -0700 Subject: [PATCH 11/30] Update Doc/library/functools.rst --- Doc/library/functools.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Doc/library/functools.rst b/Doc/library/functools.rst index b5345fd1617fa8..a85b6870db1ad5 100644 --- a/Doc/library/functools.rst +++ b/Doc/library/functools.rst @@ -590,7 +590,7 @@ The :mod:`functools` module defines the following functions: .. versionadded:: 3.8 .. versionchanged:: 3.11 - Implementation functions are now registered using :func:`register_variant`. + Implementation functions can now be retrieved using :func:`get_variants`. .. function:: update_wrapper(wrapper, wrapped, assigned=WRAPPER_ASSIGNMENTS, updated=WRAPPER_UPDATES) From e26b0db293b4c744b98751f11ddfc9f6b77af36d Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 27 Mar 2022 16:39:13 -0700 Subject: [PATCH 12/30] Update Doc/library/typing.rst --- Doc/library/typing.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Doc/library/typing.rst b/Doc/library/typing.rst index 510733c7708796..03479a64671b3c 100644 --- a/Doc/library/typing.rst +++ b/Doc/library/typing.rst @@ -2255,7 +2255,7 @@ Functions and decorators See :pep:`484` for details and comparison with other typing semantics. .. versionchanged:: 3.11 - Overloaded functions are now registered using :func:`functools.register_variant`. + Overloaded functions can now be retrieved using :func:`functools.get_variants`. .. decorator:: final From 1bf89fb67bd5ea2a1a06587767d333f05ced74f4 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sat, 2 Apr 2022 12:27:24 -0700 Subject: [PATCH 13/30] only for overload --- Doc/library/functools.rst | 30 +------ Doc/library/typing.rst | 25 +++++- Lib/functools.py | 57 ------------- Lib/test/test_functools.py | 159 ------------------------------------- Lib/test/test_typing.py | 53 +++++++++---- Lib/typing.py | 46 +++++++++-- 6 files changed, 103 insertions(+), 267 deletions(-) diff --git a/Doc/library/functools.rst b/Doc/library/functools.rst index a85b6870db1ad5..9f68eeedeb25f1 100644 --- a/Doc/library/functools.rst +++ b/Doc/library/functools.rst @@ -535,9 +535,6 @@ The :mod:`functools` module defines the following functions: .. versionchanged:: 3.7 The :func:`register` attribute now supports using type annotations. - .. versionchanged:: 3.11 - Implementation functions can now be retrieved using :func:`get_variants`. - .. class:: singledispatchmethod(func) Transform a method into a :term:`single-dispatch int: - ... - - str_overload = complex_func - overload(complex_func) - - def complex_func(arg: int) -> str: - ... - - int_overload = complex_func - overload(complex_func) - - @functools.singledispatch - def complex_func(arg: object): - raise NotImplementedError - - @complex_func.register - def str_variant(arg: str) -> int: - return int(arg) - - @complex_func.register - def int_variant(arg: int) -> str: - return str(arg) - - self.assertEqual( - functools.get_variants(complex_func), - [ - str_overload, - int_overload, - complex_func.registry[object], - str_variant, - int_variant, - ], - ) - - class TestSingleDispatch(unittest.TestCase): def test_simple_overloads(self): @functools.singledispatch diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 484a781216e4b9..2d7af4980c9683 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -1,7 +1,6 @@ import contextlib import collections from functools import lru_cache -import functools import inspect import pickle import re @@ -10,7 +9,7 @@ from unittest import TestCase, main, skipUnless, skip from copy import copy, deepcopy -from typing import Any, NoReturn, Never, assert_never, overload +from typing import Any, NoReturn, Never, assert_never, overload, get_overloads, clear_overloads from typing import TypeVar, TypeVarTuple, Unpack, AnyStr from typing import T, KT, VT # Not in __all__. from typing import Union, Optional, Literal @@ -3820,9 +3819,7 @@ def blah(): blah() - def test_variant_registry(self): - # Test the interaction with the variants registry in - # the functools module. + def set_up_overloads(self): def blah(): pass @@ -3838,20 +3835,48 @@ def blah(): def blah(): pass - self.assertEqual(functools.get_variants(blah), [overload1, overload2]) + return blah, [overload1, overload2] + + def test_overload_registry(self): + impl, overloads = self.set_up_overloads() + + self.assertEqual(list(get_overloads(impl)), overloads) + clear_overloads(blah) + self.assertEqual(get_overloads(blah), []) + + impl, overloads = self.set_up_overloads() + + self.assertEqual(list(get_overloads(impl)), overloads) + clear_overloads() + self.assertEqual(get_overloads(blah), []) def test_variant_registry_repeated(self): for _ in range(2): - def blah(): - pass + impl, overloads = self.set_up_overloads() - overload_func = blah - overload(blah) + self.assertEqual(list(get_overloads(impl)), overloads) - def blah(): - pass - - self.assertEqual(functools.get_variants(blah), [overload_func]) + def test_get_key_for_callable(self): + self.assertEqual( + typing._get_key_for_callable(len), + "builtins.len", + ) + self.assertEqual( + typing._get_key_for_callable(py_cached_func), + f"{__name__}.py_cached_func", + ) + self.assertEqual( + typing._get_key_for_callable(MethodHolder.clsmethod), + f"{__name__}.MethodHolder.clsmethod", + ) + self.assertEqual( + typing._get_key_for_callable(MethodHolder.stmethod), + f"{__name__}.MethodHolder.stmethod", + ) + self.assertEqual( + typing._get_key_for_callable(MethodHolder.method), + f"{__name__}.MethodHolder.method", + ) # Definitions needed for features introduced in Python 3.6 diff --git a/Lib/typing.py b/Lib/typing.py index 056ecc948d48ae..3d58243063a495 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -2413,11 +2413,11 @@ def utf8(value: str) -> bytes: ... def utf8(value): # implementation goes here - Each overload is registered with functools.register_variant and can be - retrieved using functools.get_variants. + Overload definition can be retrieved at runtime using the + get_overloads() function. """ try: - existing = functools.get_variants(func) + existing = get_overloads(func) except AttributeError: # Not a normal function; ignore. pass @@ -2431,21 +2431,53 @@ def utf8(value): if firstlineno is not None: existing_lineno = _get_firstlineno(existing[-1]) if existing_lineno is not None and firstlineno <= existing_lineno: - functools.clear_variants(func) + clear_overloads(func) - functools.register_variant(func, func) + key = _get_key_for_callable(func) + _overload_registry.setdefault(key, []).append(func) return _overload_dummy def _get_firstlineno(func): # staticmethod, classmethod - if hasattr(func, "__func__"): - func = func.__func__ + func = getattr(func, "__func__", func) if not hasattr(func, '__code__'): return None return func.__code__.co_firstlineno +# {key: [overload]} +_overload_registry = {} + + +def get_overloads(func): + """Return all defined overloads for *func* as a sequence.""" + key = _get_key_for_callable(func) + return _overload_registry.get(key, []) + + +def clear_overloads(func=None): + """Clear all overloads for the given function (or all functions).""" + if func is None: + _overload_registry.clear() + else: + key = _get_key_for_callable(func) + _overload_registry.pop(key, None) + + +def _get_key_for_callable(func): + """Return a key for the given callable. + + This is used as a key in the overload registry. + + If no key can be created (because the object is not of a supported type), raise + AttributeError. + """ + # classmethod and staticmethod + func = getattr(func, "__func__", func) + return f"{func.__module__}.{func.__qualname__}" + + def final(f): """A decorator to indicate final methods and final classes. From dfdbdc7cf0bb200294a167baef9ccaca3222a887 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sat, 2 Apr 2022 12:36:00 -0700 Subject: [PATCH 14/30] fix tests --- Lib/test/test_typing.py | 23 ++++++++++++++++++----- Lib/typing.py | 2 ++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 2d7af4980c9683..f0d277bebdd03a 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -3794,6 +3794,19 @@ def test_or(self): self.assertEqual("x" | X, Union["x", X]) +@lru_cache() +def cached_func(x, y): + return 3 * x + y + + +class MethodHolder: + @classmethod + def clsmethod(cls): ... + @staticmethod + def stmethod(): ... + def method(self): ... + + class OverloadTests(BaseTestCase): def test_overload_fails(self): @@ -3841,14 +3854,14 @@ def test_overload_registry(self): impl, overloads = self.set_up_overloads() self.assertEqual(list(get_overloads(impl)), overloads) - clear_overloads(blah) - self.assertEqual(get_overloads(blah), []) + clear_overloads(impl) + self.assertEqual(get_overloads(impl), []) impl, overloads = self.set_up_overloads() self.assertEqual(list(get_overloads(impl)), overloads) clear_overloads() - self.assertEqual(get_overloads(blah), []) + self.assertEqual(get_overloads(impl), []) def test_variant_registry_repeated(self): for _ in range(2): @@ -3862,8 +3875,8 @@ def test_get_key_for_callable(self): "builtins.len", ) self.assertEqual( - typing._get_key_for_callable(py_cached_func), - f"{__name__}.py_cached_func", + typing._get_key_for_callable(cached_func), + f"{__name__}.cached_func", ) self.assertEqual( typing._get_key_for_callable(MethodHolder.clsmethod), diff --git a/Lib/typing.py b/Lib/typing.py index 3d58243063a495..b3fcc24be4dc87 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -121,9 +121,11 @@ def _idfunc(_, x): 'assert_type', 'assert_never', 'cast', + 'clear_overloads', 'final', 'get_args', 'get_origin', + 'get_overloads', 'get_type_hints', 'is_typeddict', 'Never', From e16c8d090c7fa9d4170c6e8ca2d4f136c8f2d62a Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sat, 2 Apr 2022 12:37:54 -0700 Subject: [PATCH 15/30] undo stray changes, fix NEWS entry --- Doc/library/functools.rst | 4 +--- Lib/functools.py | 7 ++----- Lib/test/test_functools.py | 8 -------- .../next/Library/2022-03-06-18-15-32.bpo-45100.B_lHu0.rst | 4 ++-- 4 files changed, 5 insertions(+), 18 deletions(-) diff --git a/Doc/library/functools.rst b/Doc/library/functools.rst index 89d3aea20c2eb9..e23946a0a45e75 100644 --- a/Doc/library/functools.rst +++ b/Doc/library/functools.rst @@ -535,6 +535,7 @@ The :mod:`functools` module defines the following functions: .. versionchanged:: 3.7 The :func:`register` attribute now supports using type annotations. + .. class:: singledispatchmethod(func) Transform a method into a :term:`single-dispatch @@ -950,7 +948,6 @@ def _method(*args, **kwargs): _method.__isabstractmethod__ = self.__isabstractmethod__ _method.register = self.register - _method.registry = self.dispatcher.registry update_wrapper(_method, self.func) return _method diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 23666697c79095..abbd50a47f395f 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -1890,14 +1890,6 @@ def cached_staticmeth(x, y): return 3 * x + y -class MethodHolder: - @classmethod - def clsmethod(cls): ... - @staticmethod - def stmethod(): ... - def method(self): ... - - class TestSingleDispatch(unittest.TestCase): def test_simple_overloads(self): @functools.singledispatch diff --git a/Misc/NEWS.d/next/Library/2022-03-06-18-15-32.bpo-45100.B_lHu0.rst b/Misc/NEWS.d/next/Library/2022-03-06-18-15-32.bpo-45100.B_lHu0.rst index 5bf0d7fdbb1927..d644557545366d 100644 --- a/Misc/NEWS.d/next/Library/2022-03-06-18-15-32.bpo-45100.B_lHu0.rst +++ b/Misc/NEWS.d/next/Library/2022-03-06-18-15-32.bpo-45100.B_lHu0.rst @@ -1,2 +1,2 @@ -Add a mechanism to register function variants, such as overloads and -singledispatch implementation functions. Patch by Jelle Zijlstra. +Add :func:`typing.get_overloads` and :func:`typing.clear_overloads`. +Patch by Jelle Zijlstra. From b3d222790790eaa733814542df327900dbe050b8 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sat, 2 Apr 2022 12:38:45 -0700 Subject: [PATCH 16/30] remove extra import --- Lib/test/test_typing.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index f0d277bebdd03a..2ee6262cf5b92f 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -3810,8 +3810,6 @@ def method(self): ... class OverloadTests(BaseTestCase): def test_overload_fails(self): - from typing import overload - with self.assertRaises(RuntimeError): @overload @@ -3821,8 +3819,6 @@ def blah(): blah() def test_overload_succeeds(self): - from typing import overload - @overload def blah(): pass From 9727eee249d98af5daf9d3b4102b2916d5bebf4f Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sat, 2 Apr 2022 16:45:09 -0700 Subject: [PATCH 17/30] Apply suggestions from code review Co-authored-by: Alex Waygood --- Doc/library/typing.rst | 4 ++-- Lib/test/test_typing.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Doc/library/typing.rst b/Doc/library/typing.rst index 8417cd2af535cc..35bf0b5ee90190 100644 --- a/Doc/library/typing.rst +++ b/Doc/library/typing.rst @@ -2260,10 +2260,10 @@ Functions and decorators .. function:: get_overloads(func) - Return a sequence of :func:`overload`-decorated definitions for *func*. *func* is + Return a sequence of :func:`@overload `-decorated definitions for *func*. *func* is the function object for the implementation of the overloaded function. For example, given the definition of ``process`` in the documentation for - :func:`overload`, ``get_overloads(process)`` will return a sequence of three + :func:`@overload `, ``get_overloads(process)`` will return a sequence of three function objects for the three defined overloads. This function can be used for introspecting an overloaded function at runtime. diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 2ee6262cf5b92f..453a32a9612b65 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -3859,7 +3859,7 @@ def test_overload_registry(self): clear_overloads() self.assertEqual(get_overloads(impl), []) - def test_variant_registry_repeated(self): + def test_overload_registry_repeated(self): for _ in range(2): impl, overloads = self.set_up_overloads() From 2e374b8c62c05f29ba3cd7481d217ddf240fd7e6 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sat, 2 Apr 2022 18:07:16 -0700 Subject: [PATCH 18/30] Apply suggestions from code review Co-authored-by: Guido van Rossum --- Doc/library/typing.rst | 2 +- Lib/test/test_typing.py | 3 ++- Lib/typing.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/Doc/library/typing.rst b/Doc/library/typing.rst index 35bf0b5ee90190..dae85f490cfa28 100644 --- a/Doc/library/typing.rst +++ b/Doc/library/typing.rst @@ -2255,7 +2255,7 @@ Functions and decorators See :pep:`484` for details and comparison with other typing semantics. .. versionchanged:: 3.11 - Overloaded functions can now be retrieved at runtime :func:`get_overloads`. + Overloaded functions can now be introspected at runtime using :func:`get_overloads`. .. function:: get_overloads(func) diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 453a32a9612b65..5c034c59961c96 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -9,7 +9,8 @@ from unittest import TestCase, main, skipUnless, skip from copy import copy, deepcopy -from typing import Any, NoReturn, Never, assert_never, overload, get_overloads, clear_overloads +from typing import Any, NoReturn, Never, assert_never +from typing import overload, get_overloads, clear_overloads from typing import TypeVar, TypeVarTuple, Unpack, AnyStr from typing import T, KT, VT # Not in __all__. from typing import Union, Optional, Literal diff --git a/Lib/typing.py b/Lib/typing.py index b3fcc24be4dc87..962e6e8046943d 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -2415,7 +2415,7 @@ def utf8(value: str) -> bytes: ... def utf8(value): # implementation goes here - Overload definition can be retrieved at runtime using the + The overloads for a function can be retrieved at runtime using the get_overloads() function. """ try: From ff03b12a038129effdb763aab5c7accc73eb2375 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sat, 2 Apr 2022 18:11:21 -0700 Subject: [PATCH 19/30] Guido's feedback --- Lib/typing.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/Lib/typing.py b/Lib/typing.py index 962e6e8046943d..a765c7fa33cf1f 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -2419,11 +2419,13 @@ def utf8(value): get_overloads() function. """ try: - existing = get_overloads(func) - except AttributeError: + key = _get_key_for_callable(func) + except TypeError: # Not a normal function; ignore. pass else: + # We're inlining get_overloads() here to avoid computing the key twice. + existing = _overload_registry.setdefault(key, []) if existing: # If we are registering a variant with a lineno below or equal to that of the # most recent existing variant, we're probably re-creating overloads for a @@ -2433,10 +2435,9 @@ def utf8(value): if firstlineno is not None: existing_lineno = _get_firstlineno(existing[-1]) if existing_lineno is not None and firstlineno <= existing_lineno: - clear_overloads(func) + existing.clear() - key = _get_key_for_callable(func) - _overload_registry.setdefault(key, []).append(func) + existing.append(func) return _overload_dummy @@ -2477,7 +2478,10 @@ def _get_key_for_callable(func): """ # classmethod and staticmethod func = getattr(func, "__func__", func) - return f"{func.__module__}.{func.__qualname__}" + try: + return f"{func.__module__}.{func.__qualname__}" + except AttributeError: + raise TypeError(f"Cannot create key for {func!r}") from None def final(f): From 17f071093a477a14d27848729a82f8821fd4eae4 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 3 Apr 2022 14:14:56 -0700 Subject: [PATCH 20/30] Optimizations suggested by Guido and Alex --- Lib/test/test_typing.py | 10 +++++----- Lib/typing.py | 11 +++++++---- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 5c034c59961c96..9afaa06f32218d 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -3869,23 +3869,23 @@ def test_overload_registry_repeated(self): def test_get_key_for_callable(self): self.assertEqual( typing._get_key_for_callable(len), - "builtins.len", + ("builtins", "len"), ) self.assertEqual( typing._get_key_for_callable(cached_func), - f"{__name__}.cached_func", + (__name__, "cached_func"), ) self.assertEqual( typing._get_key_for_callable(MethodHolder.clsmethod), - f"{__name__}.MethodHolder.clsmethod", + (__name__, "MethodHolder.clsmethod"), ) self.assertEqual( typing._get_key_for_callable(MethodHolder.stmethod), - f"{__name__}.MethodHolder.stmethod", + (__name__, "MethodHolder.stmethod"), ) self.assertEqual( typing._get_key_for_callable(MethodHolder.method), - f"{__name__}.MethodHolder.method", + (__name__, "MethodHolder.method"), ) diff --git a/Lib/typing.py b/Lib/typing.py index a765c7fa33cf1f..fcb206d07d15e6 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -2418,9 +2418,12 @@ def utf8(value): The overloads for a function can be retrieved at runtime using the get_overloads() function. """ + # Inline version of _get_key_for_callable + # classmethod and staticmethod + func = getattr(func, "__func__", func) try: - key = _get_key_for_callable(func) - except TypeError: + key = (func.__module__, func.__qualname__) + except AttributeError: # Not a normal function; ignore. pass else: @@ -2450,7 +2453,7 @@ def _get_firstlineno(func): # {key: [overload]} -_overload_registry = {} +_overload_registry: dict[tuple[str, str], list[Any]] = {} def get_overloads(func): @@ -2479,7 +2482,7 @@ def _get_key_for_callable(func): # classmethod and staticmethod func = getattr(func, "__func__", func) try: - return f"{func.__module__}.{func.__qualname__}" + return (func.__module__, func.__qualname__) except AttributeError: raise TypeError(f"Cannot create key for {func!r}") from None From 2346970658f8d8ef20c00d0ede3ea8032787b0de Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 3 Apr 2022 15:13:41 -0700 Subject: [PATCH 21/30] inline _get_firstlineno, store outer objects for classmethod/staticmethod --- Lib/typing.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/Lib/typing.py b/Lib/typing.py index fcb206d07d15e6..189c2746852282 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -2420,9 +2420,9 @@ def utf8(value): """ # Inline version of _get_key_for_callable # classmethod and staticmethod - func = getattr(func, "__func__", func) + inner_func = getattr(func, "__func__", func) try: - key = (func.__module__, func.__qualname__) + key = (inner_func.__module__, inner_func.__qualname__) except AttributeError: # Not a normal function; ignore. pass @@ -2434,24 +2434,26 @@ def utf8(value): # most recent existing variant, we're probably re-creating overloads for a # function that already exists. In that case, we clear the existing variants # to avoid leaking memory. - firstlineno = _get_firstlineno(func) - if firstlineno is not None: - existing_lineno = _get_firstlineno(existing[-1]) - if existing_lineno is not None and firstlineno <= existing_lineno: - existing.clear() + try: + firstlineno = inner_func.__code__.co_firstlineno + except AttributeError: + pass + else: + existing_func = existing[-1] + # classmethod and staticmethod + existing_func = getattr(existing_func, "__func__", existing_func) + try: + existing_lineno = existing_func.__code__.co_firstlineno + except AttributeError: + pass + else: + if firstlineno <= existing_lineno: + existing.clear() existing.append(func) return _overload_dummy -def _get_firstlineno(func): - # staticmethod, classmethod - func = getattr(func, "__func__", func) - if not hasattr(func, '__code__'): - return None - return func.__code__.co_firstlineno - - # {key: [overload]} _overload_registry: dict[tuple[str, str], list[Any]] = {} From f2053a065424876b97fc43d5854cb5b1a962727f Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 3 Apr 2022 15:41:25 -0700 Subject: [PATCH 22/30] use defaultdict --- Lib/typing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Lib/typing.py b/Lib/typing.py index 189c2746852282..3814368523267e 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -2428,7 +2428,7 @@ def utf8(value): pass else: # We're inlining get_overloads() here to avoid computing the key twice. - existing = _overload_registry.setdefault(key, []) + existing = _overload_registry[key] if existing: # If we are registering a variant with a lineno below or equal to that of the # most recent existing variant, we're probably re-creating overloads for a @@ -2455,7 +2455,7 @@ def utf8(value): # {key: [overload]} -_overload_registry: dict[tuple[str, str], list[Any]] = {} +_overload_registry = collections.defaultdict(list) def get_overloads(func): From b6131ad8dad0f307cbf235ff1ad9e76a0f06b489 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 3 Apr 2022 19:42:25 -0700 Subject: [PATCH 23/30] another optimization --- Lib/typing.py | 47 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 41 insertions(+), 6 deletions(-) diff --git a/Lib/typing.py b/Lib/typing.py index 3814368523267e..c3558120439436 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -2429,11 +2429,15 @@ def utf8(value): else: # We're inlining get_overloads() here to avoid computing the key twice. existing = _overload_registry[key] - if existing: - # If we are registering a variant with a lineno below or equal to that of the - # most recent existing variant, we're probably re-creating overloads for a - # function that already exists. In that case, we clear the existing variants - # to avoid leaking memory. + if existing and len(existing) > 20: + # If we are registering an overload with a lineno below or equal to that of the + # most recent existing overload, we're probably re-creating overloads for a + # function that already exists. This can happen if the module is being reloaded + # or if the overloads are being created in a nested function. + # In that case, we clear the existing overloads + # to avoid leaking memory. But we do this only if there are already a lot of + # overloads, so we don't have to figure out the linenos in the common case + # where there are only a few overloads. try: firstlineno = inner_func.__code__.co_firstlineno except AttributeError: @@ -2461,7 +2465,38 @@ def utf8(value): def get_overloads(func): """Return all defined overloads for *func* as a sequence.""" key = _get_key_for_callable(func) - return _overload_registry.get(key, []) + overloads = _overload_registry.get(key, []) + + # We clear out overloads that have higher linenos than a later + # overload, because they're probably the result of a recreation of + # the function (see the comments in @overload). But we have to do it + # here too because @overload only clears out overloads when there are + # many. + final_overloads = [] + last_lineno = -1 + should_clear = False + for overloaded_func in overloads: + lineno = _get_firstlineno(overloaded_func) + if lineno is not None: + # If the same function is registered multiple times on the same line, + # we skip the duplicates. + if lineno <= last_lineno: + final_overloads.clear() + should_clear = True + last_lineno = lineno + final_overloads.append(overloaded_func) + if should_clear: + _overload_registry[key] = final_overloads + return final_overloads + + +def _get_firstlineno(func): + # staticmethod, classmethod + func = getattr(func, "__func__", func) + if not hasattr(func, '__code__'): + return None + return func.__code__.co_firstlineno + def clear_overloads(func=None): From 506bd66604b16ea4ec5f4038a356c7b1b327d03d Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Thu, 7 Apr 2022 10:25:24 -0700 Subject: [PATCH 24/30] Update Lib/typing.py Co-authored-by: Ken Jin --- Lib/typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/typing.py b/Lib/typing.py index c3558120439436..ff55f6c0011098 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -2465,7 +2465,7 @@ def utf8(value): def get_overloads(func): """Return all defined overloads for *func* as a sequence.""" key = _get_key_for_callable(func) - overloads = _overload_registry.get(key, []) + overloads = _overload_registry[key] # We clear out overloads that have higher linenos than a later # overload, because they're probably the result of a recreation of From 103bfd4f3eafdfac9b03960e3a425836f769443b Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Fri, 8 Apr 2022 18:51:12 -0700 Subject: [PATCH 25/30] Simpler implementation (thanks Guido) --- Lib/test/test_typing.py | 37 +++++--------- Lib/typing.py | 109 ++++++++-------------------------------- 2 files changed, 35 insertions(+), 111 deletions(-) diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index d376ed3d2a9c2c..67c38155e447d6 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -1,5 +1,6 @@ import contextlib import collections +from collections import defaultdict from functools import lru_cache import inspect import pickle @@ -7,6 +8,7 @@ import sys import warnings from unittest import TestCase, main, skipUnless, skip +from unittest.mock import patch from copy import copy, deepcopy from typing import Any, NoReturn, Never, assert_never @@ -3903,17 +3905,26 @@ def blah(): return blah, [overload1, overload2] + # Make sure we don't clear the global overload registry + @patch("typing._overload_registry", + defaultdict(lambda: defaultdict(dict))) def test_overload_registry(self): - impl, overloads = self.set_up_overloads() + self.assertEqual(len(typing._overload_registry), 0) + impl, overloads = self.set_up_overloads() + self.assertNotEqual(typing._overload_registry, {}) self.assertEqual(list(get_overloads(impl)), overloads) + clear_overloads(impl) + self.assertEqual(typing._overload_registry, {}) self.assertEqual(get_overloads(impl), []) impl, overloads = self.set_up_overloads() - + self.assertNotEqual(typing._overload_registry, {}) self.assertEqual(list(get_overloads(impl)), overloads) + clear_overloads() + self.assertEqual(typing._overload_registry, {}) self.assertEqual(get_overloads(impl), []) def test_overload_registry_repeated(self): @@ -3922,28 +3933,6 @@ def test_overload_registry_repeated(self): self.assertEqual(list(get_overloads(impl)), overloads) - def test_get_key_for_callable(self): - self.assertEqual( - typing._get_key_for_callable(len), - ("builtins", "len"), - ) - self.assertEqual( - typing._get_key_for_callable(cached_func), - (__name__, "cached_func"), - ) - self.assertEqual( - typing._get_key_for_callable(MethodHolder.clsmethod), - (__name__, "MethodHolder.clsmethod"), - ) - self.assertEqual( - typing._get_key_for_callable(MethodHolder.stmethod), - (__name__, "MethodHolder.stmethod"), - ) - self.assertEqual( - typing._get_key_for_callable(MethodHolder.method), - (__name__, "MethodHolder.method"), - ) - # Definitions needed for features introduced in Python 3.6 diff --git a/Lib/typing.py b/Lib/typing.py index 1e6271ded49552..b3288ae0f9c977 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -21,6 +21,7 @@ from abc import abstractmethod, ABCMeta import collections +from collections import defaultdict import collections.abc import contextlib import functools @@ -2431,6 +2432,10 @@ def _overload_dummy(*args, **kwds): "by an implementation that is not @overload-ed.") +# {module: {qualname: {firstlineno: func}}} +_overload_registry = defaultdict(lambda: defaultdict(dict)) + + def overload(func): """Decorator for overloaded functions/methods. @@ -2460,85 +2465,26 @@ def utf8(value): The overloads for a function can be retrieved at runtime using the get_overloads() function. """ - # Inline version of _get_key_for_callable # classmethod and staticmethod - inner_func = getattr(func, "__func__", func) + f = getattr(func, "__func__", func) try: - key = (inner_func.__module__, inner_func.__qualname__) + _overload_registry[f.__module__][f.__qualname__][f.__code__.co_firstlineno] = func except AttributeError: # Not a normal function; ignore. pass - else: - # We're inlining get_overloads() here to avoid computing the key twice. - existing = _overload_registry[key] - if existing and len(existing) > 20: - # If we are registering an overload with a lineno below or equal to that of the - # most recent existing overload, we're probably re-creating overloads for a - # function that already exists. This can happen if the module is being reloaded - # or if the overloads are being created in a nested function. - # In that case, we clear the existing overloads - # to avoid leaking memory. But we do this only if there are already a lot of - # overloads, so we don't have to figure out the linenos in the common case - # where there are only a few overloads. - try: - firstlineno = inner_func.__code__.co_firstlineno - except AttributeError: - pass - else: - existing_func = existing[-1] - # classmethod and staticmethod - existing_func = getattr(existing_func, "__func__", existing_func) - try: - existing_lineno = existing_func.__code__.co_firstlineno - except AttributeError: - pass - else: - if firstlineno <= existing_lineno: - existing.clear() - - existing.append(func) return _overload_dummy -# {key: [overload]} -_overload_registry = collections.defaultdict(list) - - def get_overloads(func): """Return all defined overloads for *func* as a sequence.""" - key = _get_key_for_callable(func) - overloads = _overload_registry[key] - - # We clear out overloads that have higher linenos than a later - # overload, because they're probably the result of a recreation of - # the function (see the comments in @overload). But we have to do it - # here too because @overload only clears out overloads when there are - # many. - final_overloads = [] - last_lineno = -1 - should_clear = False - for overloaded_func in overloads: - lineno = _get_firstlineno(overloaded_func) - if lineno is not None: - # If the same function is registered multiple times on the same line, - # we skip the duplicates. - if lineno <= last_lineno: - final_overloads.clear() - should_clear = True - last_lineno = lineno - final_overloads.append(overloaded_func) - if should_clear: - _overload_registry[key] = final_overloads - return final_overloads - - -def _get_firstlineno(func): - # staticmethod, classmethod - func = getattr(func, "__func__", func) - if not hasattr(func, '__code__'): - return None - return func.__code__.co_firstlineno - + # classmethod and staticmethod + f = getattr(func, "__func__", func) + if f.__module__ not in _overload_registry: + return [] + mod_dict = _overload_registry[f.__module__] + if f.__qualname__ not in mod_dict: + return [] + return list(mod_dict[f.__qualname__].values()) def clear_overloads(func=None): @@ -2546,24 +2492,13 @@ def clear_overloads(func=None): if func is None: _overload_registry.clear() else: - key = _get_key_for_callable(func) - _overload_registry.pop(key, None) - - -def _get_key_for_callable(func): - """Return a key for the given callable. - - This is used as a key in the overload registry. - - If no key can be created (because the object is not of a supported type), raise - AttributeError. - """ - # classmethod and staticmethod - func = getattr(func, "__func__", func) - try: - return (func.__module__, func.__qualname__) - except AttributeError: - raise TypeError(f"Cannot create key for {func!r}") from None + f = getattr(func, "__func__", func) + if f.__module__ not in _overload_registry: + return + mod_dict = _overload_registry[f.__module__] + mod_dict.pop(f.__qualname__, None) + if not mod_dict: + del _overload_registry[f.__module__] def final(f): From d453f7f78fb8663f6a2fca304cb31c554860e9ba Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Fri, 8 Apr 2022 22:05:45 -0700 Subject: [PATCH 26/30] More comments and tests --- Lib/test/test_typing.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 67c38155e447d6..62f4b1a88e9c6b 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -3909,13 +3909,16 @@ def blah(): @patch("typing._overload_registry", defaultdict(lambda: defaultdict(dict))) def test_overload_registry(self): - self.assertEqual(len(typing._overload_registry), 0) + # The registry starts out empty + self.assertEqual(typing._overload_registry, {}) impl, overloads = self.set_up_overloads() self.assertNotEqual(typing._overload_registry, {}) self.assertEqual(list(get_overloads(impl)), overloads) clear_overloads(impl) + # Clearing overloads for this function should also + # clear out entries for the module. self.assertEqual(typing._overload_registry, {}) self.assertEqual(get_overloads(impl), []) @@ -3923,10 +3926,27 @@ def test_overload_registry(self): self.assertNotEqual(typing._overload_registry, {}) self.assertEqual(list(get_overloads(impl)), overloads) + # Make sure that after we clear all overloads, the registry is + # completely empty. clear_overloads() self.assertEqual(typing._overload_registry, {}) self.assertEqual(get_overloads(impl), []) + # If we create another function, its overloads won't be cleared + # if we call clear_overloads(impl) + impl, overloads = self.set_up_overloads() + self.assertEqual(list(get_overloads(impl)), overloads) + + def some_other_func(): pass + overload(some_other_func) + other_overload = some_other_func + def some_other_func(): pass + self.assertEqual(list(get_overloads(some_other_func)), [other_overload]) + + clear_overloads(impl) + self.assertEqual(list(get_overloads(impl)), []) + self.assertEqual(list(get_overloads(some_other_func)), [other_overload]) + def test_overload_registry_repeated(self): for _ in range(2): impl, overloads = self.set_up_overloads() From ea62287a38d4e7f49525461a531d8d19bcb4724e Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Wed, 13 Apr 2022 19:08:15 -0700 Subject: [PATCH 27/30] simplify clear_overloads --- Doc/library/typing.rst | 24 +++++++++++++----------- Lib/test/test_typing.py | 29 +++++------------------------ Lib/typing.py | 15 +++------------ 3 files changed, 21 insertions(+), 47 deletions(-) diff --git a/Doc/library/typing.rst b/Doc/library/typing.rst index effb02b5b36dca..d62756f5d0a57e 100644 --- a/Doc/library/typing.rst +++ b/Doc/library/typing.rst @@ -2408,27 +2408,29 @@ Functions and decorators See :pep:`484` for details and comparison with other typing semantics. .. versionchanged:: 3.11 - Overloaded functions can now be introspected at runtime using :func:`get_overloads`. + Overloaded functions can now be introspected at runtime using + :func:`get_overloads`. .. function:: get_overloads(func) - Return a sequence of :func:`@overload `-decorated definitions for *func*. *func* is - the function object for the implementation of the overloaded function. - For example, given the definition of ``process`` in the documentation for - :func:`@overload `, ``get_overloads(process)`` will return a sequence of three - function objects for the three defined overloads. + Return a sequence of :func:`@overload `-decorated definitions for + *func*. *func* is the function object for the implementation of the + overloaded function. For example, given the definition of ``process`` in + the documentation for :func:`@overload `, + ``get_overloads(process)`` will return a sequence of three function objects + for the three defined overloads. - This function can be used for introspecting an overloaded function at runtime. + This function can be used for introspecting an overloaded function at + runtime. .. versionadded:: 3.11 -.. function:: clear_overloads(func=None) +.. function:: clear_overloads() - Clear all registered overloads for the given *func*. If *func* is None, clear - all overloads stored in the internal registry. This can be used to reclaim the - memory used by the registry. + Clear all registered overloads in the internal registry. This can be used + to reclaim the memory used by the registry. .. versionadded:: 3.11 diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 1e9ad294afe9a7..a51acc9a298286 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -3956,36 +3956,17 @@ def test_overload_registry(self): self.assertNotEqual(typing._overload_registry, {}) self.assertEqual(list(get_overloads(impl)), overloads) - clear_overloads(impl) - # Clearing overloads for this function should also - # clear out entries for the module. - self.assertEqual(typing._overload_registry, {}) - self.assertEqual(get_overloads(impl), []) - - impl, overloads = self.set_up_overloads() - self.assertNotEqual(typing._overload_registry, {}) - self.assertEqual(list(get_overloads(impl)), overloads) - - # Make sure that after we clear all overloads, the registry is - # completely empty. - clear_overloads() - self.assertEqual(typing._overload_registry, {}) - self.assertEqual(get_overloads(impl), []) - - # If we create another function, its overloads won't be cleared - # if we call clear_overloads(impl) - impl, overloads = self.set_up_overloads() - self.assertEqual(list(get_overloads(impl)), overloads) - def some_other_func(): pass overload(some_other_func) other_overload = some_other_func def some_other_func(): pass self.assertEqual(list(get_overloads(some_other_func)), [other_overload]) - clear_overloads(impl) - self.assertEqual(list(get_overloads(impl)), []) - self.assertEqual(list(get_overloads(some_other_func)), [other_overload]) + # Make sure that after we clear all overloads, the registry is + # completely empty. + clear_overloads() + self.assertEqual(typing._overload_registry, {}) + self.assertEqual(get_overloads(impl), []) def test_overload_registry_repeated(self): for _ in range(2): diff --git a/Lib/typing.py b/Lib/typing.py index 762a0706ec9b67..179d25ba6242ec 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -2494,18 +2494,9 @@ def get_overloads(func): return list(mod_dict[f.__qualname__].values()) -def clear_overloads(func=None): - """Clear all overloads for the given function (or all functions).""" - if func is None: - _overload_registry.clear() - else: - f = getattr(func, "__func__", func) - if f.__module__ not in _overload_registry: - return - mod_dict = _overload_registry[f.__module__] - mod_dict.pop(f.__qualname__, None) - if not mod_dict: - del _overload_registry[f.__module__] +def clear_overloads(): + """Clear all overloads in the registry.""" + _overload_registry.clear() def final(f): From 905253c720ae246ad6d4b51468fec79ef803f59d Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Thu, 14 Apr 2022 09:41:13 -0700 Subject: [PATCH 28/30] use partial --- Lib/typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/typing.py b/Lib/typing.py index 179d25ba6242ec..b36dd8c615eb81 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -2440,7 +2440,7 @@ def _overload_dummy(*args, **kwds): # {module: {qualname: {firstlineno: func}}} -_overload_registry = defaultdict(lambda: defaultdict(dict)) +_overload_registry = defaultdict(functools.partial(defaultdict, dict)) def overload(func): From debbf8acd09000f91ef04b64b21475cda094f68a Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Thu, 14 Apr 2022 09:43:32 -0700 Subject: [PATCH 29/30] add test --- Lib/test/test_typing.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index a51acc9a298286..93796d173e51ae 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -3968,6 +3968,11 @@ def some_other_func(): pass self.assertEqual(typing._overload_registry, {}) self.assertEqual(get_overloads(impl), []) + # Querying a function with no overloads shouldn't change the registry. + def the_only_one(): pass + self.assertEqual(get_overloads(the_only_one), []) + self.assertEqual(typing._overload_registry, {}) + def test_overload_registry_repeated(self): for _ in range(2): impl, overloads = self.set_up_overloads() From 754c134dbbedbabb35612212aec5309c1898fe7e Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Thu, 14 Apr 2022 09:50:23 -0700 Subject: [PATCH 30/30] docs changes (thanks Alex) --- Doc/library/typing.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Doc/library/typing.rst b/Doc/library/typing.rst index d62756f5d0a57e..6b2a0934171a29 100644 --- a/Doc/library/typing.rst +++ b/Doc/library/typing.rst @@ -2419,9 +2419,10 @@ Functions and decorators overloaded function. For example, given the definition of ``process`` in the documentation for :func:`@overload `, ``get_overloads(process)`` will return a sequence of three function objects - for the three defined overloads. + for the three defined overloads. If called on a function with no overloads, + ``get_overloads`` returns an empty sequence. - This function can be used for introspecting an overloaded function at + ``get_overloads`` can be used for introspecting an overloaded function at runtime. .. versionadded:: 3.11