diff --git a/stdlib/functools.pyi b/stdlib/functools.pyi index 8adc3d82292e..828a9d241ab0 100644 --- a/stdlib/functools.pyi +++ b/stdlib/functools.pyi @@ -3,7 +3,7 @@ import types from _typeshed import SupportsAllComparisons, SupportsItems from collections.abc import Callable, Hashable, Iterable, Sequence, Sized from typing import Any, Generic, NamedTuple, TypeVar, overload -from typing_extensions import Literal, ParamSpec, Self, TypeAlias, TypedDict, final +from typing_extensions import Concatenate, Literal, ParamSpec, Self, TypeAlias, TypedDict, final if sys.version_info >= (3, 9): from types import GenericAlias @@ -30,6 +30,8 @@ if sys.version_info >= (3, 9): _T = TypeVar("_T") _S = TypeVar("_S") +_P = ParamSpec("_P") + _PWrapped = ParamSpec("_PWrapped") _RWrapped = TypeVar("_RWrapped") _PWrapper = ParamSpec("_PWrapper") @@ -52,25 +54,35 @@ if sys.version_info >= (3, 9): typed: bool @final -class _lru_cache_wrapper(Generic[_T]): - __wrapped__: Callable[..., _T] - def __call__(self, *args: Hashable, **kwargs: Hashable) -> _T: ... +class _lru_cache_wrapper(Generic[_P, _T]): + __wrapped__: Callable[_P, _T] + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: ... def cache_info(self) -> _CacheInfo: ... def cache_clear(self) -> None: ... if sys.version_info >= (3, 9): def cache_parameters(self) -> _CacheParameters: ... - def __copy__(self) -> _lru_cache_wrapper[_T]: ... - def __deepcopy__(self, __memo: Any) -> _lru_cache_wrapper[_T]: ... + def __copy__(self) -> _lru_cache_wrapper[_P, _T]: ... + def __deepcopy__(self, __memo: Any) -> _lru_cache_wrapper[_P, _T]: ... + if sys.version_info >= (3, 8): + @overload + def __get__(self, __instance: None, __owner: type[_S] | None = ...) -> _lru_cache_wrapper[_P, _T]: ... + @overload + def __get__(self, __instance: _S, __owner: type[_S] | None = ...) -> Callable[Concatenate[_S, _P], _T]: ... + else: + @overload + def __get__(self, __instance: None, __owner: type[_S] | None) -> _lru_cache_wrapper[_P, _T]: ... + @overload + def __get__(self, __instance: _S, __owner: type[_S] | None) -> Callable[Concatenate[_S, _P], _T]: ... if sys.version_info >= (3, 8): @overload - def lru_cache(maxsize: int | None = 128, typed: bool = False) -> Callable[[Callable[..., _T]], _lru_cache_wrapper[_T]]: ... + def lru_cache(maxsize: int | None = 128, typed: bool = False) -> Callable[[Callable[_P, _T]], _lru_cache_wrapper[_P, _T]]: ... @overload - def lru_cache(maxsize: Callable[..., _T], typed: bool = False) -> _lru_cache_wrapper[_T]: ... + def lru_cache(maxsize: Callable[_P, _T], typed: bool = False) -> _lru_cache_wrapper[_P, _T]: ... else: - def lru_cache(maxsize: int | None = 128, typed: bool = False) -> Callable[[Callable[..., _T]], _lru_cache_wrapper[_T]]: ... + def lru_cache(maxsize: int | None = 128, typed: bool = False) -> Callable[[Callable[_P, _T]], _lru_cache_wrapper[_P, _T]]: ... if sys.version_info >= (3, 12): WRAPPER_ASSIGNMENTS: tuple[ @@ -208,7 +220,7 @@ if sys.version_info >= (3, 8): def __class_getitem__(cls, item: Any) -> GenericAlias: ... if sys.version_info >= (3, 9): - def cache(__user_function: Callable[..., _T]) -> _lru_cache_wrapper[_T]: ... + def cache(__user_function: Callable[_P, _T]) -> _lru_cache_wrapper[_P, _T]: ... def _make_key( args: tuple[Hashable, ...],