From acb3548db5b4d44dbdc6c9bafdf7247fb6b28395 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Mon, 10 Jun 2024 02:33:08 +0200 Subject: [PATCH] Support ParamSpec + functools.partial --- mypy/checkexpr.py | 18 ++++- mypy/plugins/functools.py | 7 +- mypy/types.py | 13 +++- .../unit/check-parameter-specification.test | 74 +++++++++++++++++++ 4 files changed, 106 insertions(+), 6 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index f3d21d24a4287..33509300d430a 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2331,9 +2331,21 @@ def check_argument_count( # Positional argument when expecting a keyword argument. self.msg.too_many_positional_arguments(callee, context) ok = False - elif callee.param_spec() is not None and not formal_to_actual[i]: - self.msg.too_few_arguments(callee, context, actual_names) - ok = False + elif callee.param_spec() is not None: + if ( + not formal_to_actual[i] + and not callee.param_spec_parts_bound[kind == ArgKind.ARG_STAR2] + and callee.special_sig != "partial" + ): + self.msg.too_few_arguments(callee, context, actual_names) + ok = False + elif ( + formal_to_actual[i] + and kind == ArgKind.ARG_STAR + and callee.param_spec_parts_bound[0] + ): + self.msg.too_many_arguments(callee, context) + ok = False return ok def check_for_extra_actual_arguments( diff --git a/mypy/plugins/functools.py b/mypy/plugins/functools.py index 81a3b4d96ef3b..7ff30ab2d4f4b 100644 --- a/mypy/plugins/functools.py +++ b/mypy/plugins/functools.py @@ -140,7 +140,8 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type: else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k) ) for k in fn_type.arg_kinds - ] + ], + special_sig="partial", ) if defaulted.line < 0: # Make up a line number if we don't have one @@ -208,6 +209,10 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type: arg_kinds=partial_kinds, arg_names=partial_names, ret_type=ret_type, + param_spec_parts_bound=( + ArgKind.ARG_STAR in actual_arg_kinds, + ArgKind.ARG_STAR2 in actual_arg_kinds, + ), ) ret = ctx.api.named_generic_type("functools.partial", [ret_type]) diff --git a/mypy/types.py b/mypy/types.py index 2cacc3e440850..a138a337f2fea 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1773,8 +1773,8 @@ class CallableType(FunctionLike): "implicit", # Was this type implicitly generated instead of explicitly # specified by the user? "special_sig", # Non-None for signatures that require special handling - # (currently only value is 'dict' for a signature similar to - # 'dict') + # (currently only values are 'dict' for a signature similar to + # 'dict' and 'partial' for a `functools.partial` evaluation) "from_type_type", # Was this callable generated by analyzing Type[...] # instantiation? "bound_args", # Bound type args, mostly unused but may be useful for @@ -1787,6 +1787,7 @@ class CallableType(FunctionLike): # (this is used for error messages) "imprecise_arg_kinds", "unpack_kwargs", # Was an Unpack[...] with **kwargs used to define this callable? + "param_spec_parts_bound", # Hack for functools.partial: allow early binding ) def __init__( @@ -1813,6 +1814,7 @@ def __init__( from_concatenate: bool = False, imprecise_arg_kinds: bool = False, unpack_kwargs: bool = False, + param_spec_parts_bound: tuple[bool, bool] = (False, False), ) -> None: super().__init__(line, column) assert len(arg_types) == len(arg_kinds) == len(arg_names) @@ -1839,6 +1841,7 @@ def __init__( self.from_type_type = from_type_type self.from_concatenate = from_concatenate self.imprecise_arg_kinds = imprecise_arg_kinds + self.param_spec_parts_bound = param_spec_parts_bound if not bound_args: bound_args = () self.bound_args = bound_args @@ -1885,6 +1888,7 @@ def copy_modified( from_concatenate: Bogus[bool] = _dummy, imprecise_arg_kinds: Bogus[bool] = _dummy, unpack_kwargs: Bogus[bool] = _dummy, + param_spec_parts_bound: Bogus[tuple[bool, bool]] = _dummy, ) -> CT: modified = CallableType( arg_types=arg_types if arg_types is not _dummy else self.arg_types, @@ -1916,6 +1920,11 @@ def copy_modified( else self.imprecise_arg_kinds ), unpack_kwargs=unpack_kwargs if unpack_kwargs is not _dummy else self.unpack_kwargs, + param_spec_parts_bound=( + param_spec_parts_bound + if param_spec_parts_bound is not _dummy + else self.param_spec_parts_bound + ), ) # Optimization: Only NewTypes are supported as subtypes since # the class is effectively final, so we can use a cast safely. diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index 16dcff7f630c2..1ea60b700cf1e 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -2315,3 +2315,77 @@ reveal_type(capture(fn)) # N: Revealed type is "Union[builtins.str, builtins.in reveal_type(capture(err)) # N: Revealed type is "builtins.int" [builtins fixtures/paramspec.pyi] + +[case testBindPartial] +from functools import partial +from typing_extensions import ParamSpec +from typing import Callable, TypeVar + +P = ParamSpec("P") +T = TypeVar("T") + +def run(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, **kwargs) + return func2(*args) + +def run2(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, *args) + return func2(**kwargs) + +def run3(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, *args, **kwargs) + return func2() + +def run4(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, *args, **kwargs) + return func2(**kwargs) + +def run_bad(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, *args, **kwargs) + return func2(*args) # E: Too many arguments + +def run_bad2(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, **kwargs) + return func2(**kwargs) # E: Too few arguments + +def run_bad3(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, *args) + return func2() # E: Too few arguments + +[builtins fixtures/paramspec.pyi] + +[case testBindPartialConcatenate] +from functools import partial +from typing_extensions import Concatenate, ParamSpec +from typing import Callable, TypeVar + +P = ParamSpec("P") +T = TypeVar("T") + +def run(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, 1, **kwargs) + return func2(*args) + +def run2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, **kwargs) + return func2(1, *args) + +def run3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, 1, *args) + return func2(**kwargs) + +def run_bad(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, *args) # E: Argument 1 has incompatible type "*P.args"; expected "int" + return func2(1, **kwargs) # E: Too many arguments + +def run_bad2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, 1, *args) + return func2(1, **kwargs) # E: Too many arguments \ + # E: Argument 1 has incompatible type "int"; expected "P.args" + +def run_bad3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, 1, *args) + return func2(1, **kwargs) # E: Too many arguments \ + # E: Argument 1 has incompatible type "int"; expected "P.args" + +[builtins fixtures/paramspec.pyi]