From 5cdb753f74f5807887c40e8aee138291d1f5b920 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Tue, 4 Jun 2024 03:02:44 +0200 Subject: [PATCH 01/15] Reject ParamSpec-typed callables calls with insufficient arguments --- mypy/checkexpr.py | 10 +++- .../unit/check-parameter-specification.test | 51 +++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 0a4af069ea17..7d300f3a2bcb 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1745,7 +1745,11 @@ def check_callable_call( ) param_spec = callee.param_spec() - if param_spec is not None and arg_kinds == [ARG_STAR, ARG_STAR2]: + if ( + param_spec is not None + and arg_kinds == [ARG_STAR, ARG_STAR2] + and len(formal_to_actual) == 2 + ): arg1 = self.accept(args[0]) arg2 = self.accept(args[1]) if ( @@ -2351,6 +2355,10 @@ def check_argument_count( # Positional argument when expecting a keyword argument. self.msg.too_many_positional_arguments(callee, context) ok = False + elif callee.param_spec() is not None: + if not formal_to_actual[i]: + self.msg.too_few_arguments(callee, context, actual_names) + ok = False return ok def check_for_extra_actual_arguments( diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index cab7d2bf6819..63a5e9cb1777 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -2204,3 +2204,54 @@ parametrize(_test, Case(1, b=2), Case(3, b=4)) parametrize(_test, Case(1, 2), Case(3)) parametrize(_test, Case(1, 2), Case(3, b=4)) [builtins fixtures/paramspec.pyi] + +[case testRunParamSpecInsufficientArgs] +from typing_extensions import ParamSpec, Concatenate +from typing import Callable + +_P = ParamSpec("_P") + +def run(predicate: Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs) -> None: + predicate() # E: Too few arguments + predicate(*args) # E: Too few arguments + predicate(**kwargs) # E: Too few arguments + predicate(*args, **kwargs) + +[builtins fixtures/paramspec.pyi] + +[case testRunParamSpecConcatenateInsufficientArgs] +from typing_extensions import ParamSpec, Concatenate +from typing import Callable + +_P = ParamSpec("_P") + +def run(predicate: Callable[Concatenate[int, _P], str], *args: _P.args, **kwargs: _P.kwargs) -> None: + predicate() # E: Too few arguments + predicate(1) # E: Too few arguments + predicate(1, *args) # E: Too few arguments + predicate(1, *args) # E: Too few arguments + predicate(1, **kwargs) # E: Too few arguments + predicate(*args, **kwargs) # E: Argument 1 has incompatible type "*_P.args"; expected "int" + predicate(1, *args, **kwargs) + +[builtins fixtures/paramspec.pyi] + +[case testRunParamSpecConcatenateInsufficientArgsInDecorator] +from typing_extensions import ParamSpec, Concatenate +from typing import Callable + +P = ParamSpec("P") + +def decorator(fn: Callable[Concatenate[str, P], None]) -> Callable[P, None]: + def inner(*args: P.args, **kwargs: P.kwargs) -> None: + fn("value") # E: Too few arguments + fn("value", *args) # E: Too few arguments + fn("value", **kwargs) # E: Too few arguments + fn(*args, **kwargs) # E: Argument 1 has incompatible type "*P.args"; expected "str" + fn("value", *args, **kwargs) + return inner + +@decorator +def foo(s: str, s2: str) -> None: ... + +[builtins fixtures/paramspec.pyi] From 3b2297f1b370549695dae895045282831b901a64 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Tue, 4 Jun 2024 04:46:21 +0200 Subject: [PATCH 02/15] Reuse params preprocessing logic for generic functions --- mypy/checkexpr.py | 74 ++++++++++++------- .../unit/check-parameter-specification.test | 64 +++++++++++++++- 2 files changed, 109 insertions(+), 29 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 7d300f3a2bcb..fcf61d73579c 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1716,33 +1716,9 @@ def check_callable_call( callee = callee.copy_modified(ret_type=fresh_ret_type) if callee.is_generic(): - need_refresh = any( - isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables + callee, formal_to_actual = self.adjust_generic_callable_params_mapping( + callee, args, arg_kinds, arg_names, formal_to_actual, context ) - callee = freshen_function_type_vars(callee) - callee = self.infer_function_type_arguments_using_context(callee, context) - if need_refresh: - # Argument kinds etc. may have changed due to - # ParamSpec or TypeVarTuple variables being replaced with an arbitrary - # number of arguments; recalculate actual-to-formal map - formal_to_actual = map_actuals_to_formals( - arg_kinds, - arg_names, - callee.arg_kinds, - callee.arg_names, - lambda i: self.accept(args[i]), - ) - callee = self.infer_function_type_arguments( - callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context - ) - if need_refresh: - formal_to_actual = map_actuals_to_formals( - arg_kinds, - arg_names, - callee.arg_kinds, - callee.arg_names, - lambda i: self.accept(args[i]), - ) param_spec = callee.param_spec() if ( @@ -2633,7 +2609,7 @@ def check_overload_call( arg_types = self.infer_arg_types_in_empty_context(args) # Step 1: Filter call targets to remove ones where the argument counts don't match plausible_targets = self.plausible_overload_call_targets( - arg_types, arg_kinds, arg_names, callee + args, arg_types, arg_kinds, arg_names, callee, context ) # Step 2: If the arguments contain a union, we try performing union math first, @@ -2751,12 +2727,52 @@ def check_overload_call( self.chk.fail(message_registry.TOO_MANY_UNION_COMBINATIONS, context) return result + def adjust_generic_callable_params_mapping( + self, + callee: CallableType, + args: list[Expression], + arg_kinds: list[ArgKind], + arg_names: Sequence[str | None] | None, + formal_to_actual: list[list[int]], + context: Context, + ) -> tuple[CallableType, list[list[int]]]: + need_refresh = any( + isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables + ) + callee = freshen_function_type_vars(callee) + callee = self.infer_function_type_arguments_using_context(callee, context) + if need_refresh: + # Argument kinds etc. may have changed due to + # ParamSpec or TypeVarTuple variables being replaced with an arbitrary + # number of arguments; recalculate actual-to-formal map + formal_to_actual = map_actuals_to_formals( + arg_kinds, + arg_names, + callee.arg_kinds, + callee.arg_names, + lambda i: self.accept(args[i]), + ) + callee = self.infer_function_type_arguments( + callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context + ) + if need_refresh: + formal_to_actual = map_actuals_to_formals( + arg_kinds, + arg_names, + callee.arg_kinds, + callee.arg_names, + lambda i: self.accept(args[i]), + ) + return callee, formal_to_actual + def plausible_overload_call_targets( self, + args: list[Expression], arg_types: list[Type], arg_kinds: list[ArgKind], arg_names: Sequence[str | None] | None, overload: Overloaded, + context: Context, ) -> list[CallableType]: """Returns all overload call targets that having matching argument counts. @@ -2790,6 +2806,10 @@ def has_shape(typ: Type) -> bool: formal_to_actual = map_actuals_to_formals( arg_kinds, arg_names, typ.arg_kinds, typ.arg_names, lambda i: arg_types[i] ) + if typ.is_generic(): + typ, formal_to_actual = self.adjust_generic_callable_params_mapping( + typ, args, arg_kinds, arg_names, formal_to_actual, context + ) with self.msg.filter_errors(): if self.check_argument_count( diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index 63a5e9cb1777..fc12780aa89f 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -2211,12 +2211,22 @@ from typing import Callable _P = ParamSpec("_P") -def run(predicate: Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs) -> None: +def run(predicate: Callable[_P, None], *args: _P.args, **kwargs: _P.kwargs) -> None: # N: "run" defined here predicate() # E: Too few arguments predicate(*args) # E: Too few arguments predicate(**kwargs) # E: Too few arguments predicate(*args, **kwargs) +def fn() -> None: ... +def fn_args(x: int) -> None: ... +def fn_posonly(x: int, /) -> None: ... + +run(fn) +run(fn_args, 1) +run(fn_args, x=1) +run(fn_posonly, 1) +run(fn_posonly, x=1) # E: Unexpected keyword argument "x" for "run" + [builtins fixtures/paramspec.pyi] [case testRunParamSpecConcatenateInsufficientArgs] @@ -2225,7 +2235,7 @@ from typing import Callable _P = ParamSpec("_P") -def run(predicate: Callable[Concatenate[int, _P], str], *args: _P.args, **kwargs: _P.kwargs) -> None: +def run(predicate: Callable[Concatenate[int, _P], None], *args: _P.args, **kwargs: _P.kwargs) -> None: # N: "run" defined here predicate() # E: Too few arguments predicate(1) # E: Too few arguments predicate(1, *args) # E: Too few arguments @@ -2234,6 +2244,22 @@ def run(predicate: Callable[Concatenate[int, _P], str], *args: _P.args, **kwargs predicate(*args, **kwargs) # E: Argument 1 has incompatible type "*_P.args"; expected "int" predicate(1, *args, **kwargs) +def fn() -> None: ... +def fn_args(x: int, y: str) -> None: ... +def fn_posonly(x: int, /) -> None: ... +def fn_posonly_args(x: int, /, y: str) -> None: ... + +run(fn) # E: Argument 1 to "run" has incompatible type "Callable[[], None]"; expected "Callable[[int], None]" +run(fn_args, 1, 'a') # E: Too many arguments for "run" \ + # E: Argument 2 to "run" has incompatible type "int"; expected "str" +run(fn_args, y='a') +run(fn_args, 'a') +run(fn_posonly) +run(fn_posonly, x=1) # E: Unexpected keyword argument "x" for "run" +run(fn_posonly_args) # E: Missing positional argument "y" in call to "run" +run(fn_posonly_args, 'a') +run(fn_posonly_args, y='a') + [builtins fixtures/paramspec.pyi] [case testRunParamSpecConcatenateInsufficientArgsInDecorator] @@ -2255,3 +2281,37 @@ def decorator(fn: Callable[Concatenate[str, P], None]) -> Callable[P, None]: def foo(s: str, s2: str) -> None: ... [builtins fixtures/paramspec.pyi] + +[case testRunParamSpecOverload] +from typing_extensions import ParamSpec, Concatenate +from typing import Callable, overload, NoReturn, TypeVar, Union + +P = ParamSpec("P") +T = TypeVar("T") + +@overload +def capture( + sync_fn: Callable[P, NoReturn], + *args: P.args, + **kwargs: P.kwargs, +) -> int: ... +@overload +def capture( + sync_fn: Callable[P, T], + *args: P.args, + **kwargs: P.kwargs, +) -> Union[T, int]: ... +def capture( + sync_fn: Callable[P, T], + *args: P.args, + **kwargs: P.kwargs, +) -> Union[T, int]: + return sync_fn(*args, **kwargs) + +def fn() -> str: return '' +def err() -> NoReturn: ... + +reveal_type(capture(fn)) # N: Revealed type is "Union[builtins.str, builtins.int]" +reveal_type(capture(err)) # N: Revealed type is "builtins.int" + +[builtins fixtures/paramspec.pyi] From a32ad3f6618a93cd065429d9eabb6fa194204bd1 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Tue, 4 Jun 2024 05:52:36 +0200 Subject: [PATCH 03/15] Only perform deep expansion on overloads when ParamSpec is present --- mypy/checkexpr.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index fcf61d73579c..25d9c808e2d2 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2806,12 +2806,12 @@ def has_shape(typ: Type) -> bool: formal_to_actual = map_actuals_to_formals( arg_kinds, arg_names, typ.arg_kinds, typ.arg_names, lambda i: arg_types[i] ) - if typ.is_generic(): - typ, formal_to_actual = self.adjust_generic_callable_params_mapping( - typ, args, arg_kinds, arg_names, formal_to_actual, context - ) - with self.msg.filter_errors(): + if typ.is_generic() and typ.param_spec() is not None: + typ, formal_to_actual = self.adjust_generic_callable_params_mapping( + typ, args, arg_kinds, arg_names, formal_to_actual, context + ) + if self.check_argument_count( typ, arg_types, arg_kinds, arg_names, formal_to_actual, None ): From 63995e3178b2e8ce4f98ad52ea6848df173c33c8 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Mon, 10 Jun 2024 02:31:38 +0200 Subject: [PATCH 04/15] Tidy up code a bit --- mypy/checkexpr.py | 7 +++---- test-data/unit/check-parameter-specification.test | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 25d9c808e2d2..f3d21d24a428 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2331,10 +2331,9 @@ def check_argument_count( # Positional argument when expecting a keyword argument. self.msg.too_many_positional_arguments(callee, context) ok = False - elif callee.param_spec() is not None: - if not formal_to_actual[i]: - self.msg.too_few_arguments(callee, context, actual_names) - ok = False + elif callee.param_spec() is not None and not formal_to_actual[i]: + self.msg.too_few_arguments(callee, context, actual_names) + ok = False return ok def check_for_extra_actual_arguments( diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index fc12780aa89f..16dcff7f630c 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -2283,8 +2283,8 @@ def foo(s: str, s2: str) -> None: ... [builtins fixtures/paramspec.pyi] [case testRunParamSpecOverload] -from typing_extensions import ParamSpec, Concatenate -from typing import Callable, overload, NoReturn, TypeVar, Union +from typing_extensions import ParamSpec +from typing import Callable, NoReturn, TypeVar, Union, overload P = ParamSpec("P") T = TypeVar("T") From 0a658a724dc87ae80f414b8b1d219371afb09f69 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Wed, 21 Aug 2024 15:33:39 +0200 Subject: [PATCH 05/15] Always pick ParamSpec-containing overloads as plausible candidates --- mypy/checkexpr.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index d415195e433d..d9729384c465 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2786,9 +2786,9 @@ def plausible_overload_call_targets( ) -> list[CallableType]: """Returns all overload call targets that having matching argument counts. - If the given args contains a star-arg (*arg or **kwarg argument), this method - will ensure all star-arg overloads appear at the start of the list, instead - of their usual location. + If the given args contains a star-arg (*arg or **kwarg argument, including + ParamSpec), this method will ensure all star-arg overloads appear at the start + of the list, instead of their usual location. The only exception is if the starred argument is something like a Tuple or a NamedTuple, which has a definitive "shape". If so, we don't move the corresponding @@ -2817,12 +2817,12 @@ def has_shape(typ: Type) -> bool: arg_kinds, arg_names, typ.arg_kinds, typ.arg_names, lambda i: arg_types[i] ) with self.msg.filter_errors(): - if typ.is_generic() and typ.param_spec() is not None: - typ, formal_to_actual = self.adjust_generic_callable_params_mapping( - typ, args, arg_kinds, arg_names, formal_to_actual, context - ) - - if self.check_argument_count( + if typ.param_spec() is not None: + # ParamSpec can be expanded in a lot of different ways. We may try + # to expand it here instead, but picking an impossible overload + # is safe: it will be filtered out later. + star_matches.append(typ) + elif self.check_argument_count( typ, arg_types, arg_kinds, arg_names, formal_to_actual, None ): if args_have_var_arg and typ.is_var_arg: From 63f9438bfd08f6303e0b67878b46063c011b8162 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Wed, 21 Aug 2024 20:29:47 +0200 Subject: [PATCH 06/15] Remove parameters thaat are no longer used --- mypy/checkexpr.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index d9729384c465..000a17693e9d 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2619,7 +2619,7 @@ def check_overload_call( arg_types = self.infer_arg_types_in_empty_context(args) # Step 1: Filter call targets to remove ones where the argument counts don't match plausible_targets = self.plausible_overload_call_targets( - args, arg_types, arg_kinds, arg_names, callee, context + arg_types, arg_kinds, arg_names, callee ) # Step 2: If the arguments contain a union, we try performing union math first, @@ -2777,12 +2777,10 @@ def adjust_generic_callable_params_mapping( def plausible_overload_call_targets( self, - args: list[Expression], arg_types: list[Type], arg_kinds: list[ArgKind], arg_names: Sequence[str | None] | None, overload: Overloaded, - context: Context, ) -> list[CallableType]: """Returns all overload call targets that having matching argument counts. From 190340217871982ae263591f4a552a9f09ad7067 Mon Sep 17 00:00:00 2001 From: Shantanu Jain Date: Mon, 23 Sep 2024 23:07:28 -0700 Subject: [PATCH 07/15] Undo style change to make it easier to review, feel free to add in separate PR --- mypy/checkexpr.py | 66 +++++++++++++++++++---------------------------- 1 file changed, 26 insertions(+), 40 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 000a17693e9d..ad1850175aab 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1727,9 +1727,33 @@ def check_callable_call( callee = callee.copy_modified(ret_type=fresh_ret_type) if callee.is_generic(): - callee, formal_to_actual = self.adjust_generic_callable_params_mapping( - callee, args, arg_kinds, arg_names, formal_to_actual, context + need_refresh = any( + isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables ) + callee = freshen_function_type_vars(callee) + callee = self.infer_function_type_arguments_using_context(callee, context) + if need_refresh: + # Argument kinds etc. may have changed due to + # ParamSpec or TypeVarTuple variables being replaced with an arbitrary + # number of arguments; recalculate actual-to-formal map + formal_to_actual = map_actuals_to_formals( + arg_kinds, + arg_names, + callee.arg_kinds, + callee.arg_names, + lambda i: self.accept(args[i]), + ) + callee = self.infer_function_type_arguments( + callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context + ) + if need_refresh: + formal_to_actual = map_actuals_to_formals( + arg_kinds, + arg_names, + callee.arg_kinds, + callee.arg_names, + lambda i: self.accept(args[i]), + ) param_spec = callee.param_spec() if ( @@ -2737,44 +2761,6 @@ def check_overload_call( self.chk.fail(message_registry.TOO_MANY_UNION_COMBINATIONS, context) return result - def adjust_generic_callable_params_mapping( - self, - callee: CallableType, - args: list[Expression], - arg_kinds: list[ArgKind], - arg_names: Sequence[str | None] | None, - formal_to_actual: list[list[int]], - context: Context, - ) -> tuple[CallableType, list[list[int]]]: - need_refresh = any( - isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables - ) - callee = freshen_function_type_vars(callee) - callee = self.infer_function_type_arguments_using_context(callee, context) - if need_refresh: - # Argument kinds etc. may have changed due to - # ParamSpec or TypeVarTuple variables being replaced with an arbitrary - # number of arguments; recalculate actual-to-formal map - formal_to_actual = map_actuals_to_formals( - arg_kinds, - arg_names, - callee.arg_kinds, - callee.arg_names, - lambda i: self.accept(args[i]), - ) - callee = self.infer_function_type_arguments( - callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context - ) - if need_refresh: - formal_to_actual = map_actuals_to_formals( - arg_kinds, - arg_names, - callee.arg_kinds, - callee.arg_names, - lambda i: self.accept(args[i]), - ) - return callee, formal_to_actual - def plausible_overload_call_targets( self, arg_types: list[Type], From 512a722561231ce216a1a1bf4196bba14e4125f9 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Mon, 10 Jun 2024 02:33:08 +0200 Subject: [PATCH 08/15] Support ParamSpec + functools.partial --- mypy/checkexpr.py | 18 ++++- mypy/plugins/functools.py | 5 ++ mypy/types.py | 13 +++- .../unit/check-parameter-specification.test | 74 +++++++++++++++++++ 4 files changed, 105 insertions(+), 5 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index ad1850175aab..c06457edba38 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2366,9 +2366,21 @@ def check_argument_count( # Positional argument when expecting a keyword argument. self.msg.too_many_positional_arguments(callee, context) ok = False - elif callee.param_spec() is not None and not formal_to_actual[i]: - self.msg.too_few_arguments(callee, context, actual_names) - ok = False + elif callee.param_spec() is not None: + if ( + not formal_to_actual[i] + and not callee.param_spec_parts_bound[kind == ArgKind.ARG_STAR2] + and callee.special_sig != "partial" + ): + self.msg.too_few_arguments(callee, context, actual_names) + ok = False + elif ( + formal_to_actual[i] + and kind == ArgKind.ARG_STAR + and callee.param_spec_parts_bound[0] + ): + self.msg.too_many_arguments(callee, context) + ok = False return ok def check_for_extra_actual_arguments( diff --git a/mypy/plugins/functools.py b/mypy/plugins/functools.py index 6650af637519..f61e152ced30 100644 --- a/mypy/plugins/functools.py +++ b/mypy/plugins/functools.py @@ -161,6 +161,7 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type: for k in fn_type.arg_kinds ], ret_type=ret_type, + special_sig="partial", ) if defaulted.line < 0: # Make up a line number if we don't have one @@ -267,6 +268,10 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type: arg_kinds=partial_kinds, arg_names=partial_names, ret_type=ret_type, + param_spec_parts_bound=( + ArgKind.ARG_STAR in actual_arg_kinds, + ArgKind.ARG_STAR2 in actual_arg_kinds, + ), ) ret = ctx.api.named_generic_type(PARTIAL, [ret_type]) diff --git a/mypy/types.py b/mypy/types.py index 78244d0f9cf4..1fc3e884d13b 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1811,8 +1811,8 @@ class CallableType(FunctionLike): "implicit", # Was this type implicitly generated instead of explicitly # specified by the user? "special_sig", # Non-None for signatures that require special handling - # (currently only value is 'dict' for a signature similar to - # 'dict') + # (currently only values are 'dict' for a signature similar to + # 'dict' and 'partial' for a `functools.partial` evaluation) "from_type_type", # Was this callable generated by analyzing Type[...] # instantiation? "bound_args", # Bound type args, mostly unused but may be useful for @@ -1825,6 +1825,7 @@ class CallableType(FunctionLike): # (this is used for error messages) "imprecise_arg_kinds", "unpack_kwargs", # Was an Unpack[...] with **kwargs used to define this callable? + "param_spec_parts_bound", # Hack for functools.partial: allow early binding ) def __init__( @@ -1851,6 +1852,7 @@ def __init__( from_concatenate: bool = False, imprecise_arg_kinds: bool = False, unpack_kwargs: bool = False, + param_spec_parts_bound: tuple[bool, bool] = (False, False), ) -> None: super().__init__(line, column) assert len(arg_types) == len(arg_kinds) == len(arg_names) @@ -1877,6 +1879,7 @@ def __init__( self.from_type_type = from_type_type self.from_concatenate = from_concatenate self.imprecise_arg_kinds = imprecise_arg_kinds + self.param_spec_parts_bound = param_spec_parts_bound if not bound_args: bound_args = () self.bound_args = bound_args @@ -1923,6 +1926,7 @@ def copy_modified( from_concatenate: Bogus[bool] = _dummy, imprecise_arg_kinds: Bogus[bool] = _dummy, unpack_kwargs: Bogus[bool] = _dummy, + param_spec_parts_bound: Bogus[tuple[bool, bool]] = _dummy, ) -> CT: modified = CallableType( arg_types=arg_types if arg_types is not _dummy else self.arg_types, @@ -1954,6 +1958,11 @@ def copy_modified( else self.imprecise_arg_kinds ), unpack_kwargs=unpack_kwargs if unpack_kwargs is not _dummy else self.unpack_kwargs, + param_spec_parts_bound=( + param_spec_parts_bound + if param_spec_parts_bound is not _dummy + else self.param_spec_parts_bound + ), ) # Optimization: Only NewTypes are supported as subtypes since # the class is effectively final, so we can use a cast safely. diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index 703ccfce0060..20500f7f5f28 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -2294,3 +2294,77 @@ reveal_type(capture(fn)) # N: Revealed type is "Union[builtins.str, builtins.in reveal_type(capture(err)) # N: Revealed type is "builtins.int" [builtins fixtures/paramspec.pyi] + +[case testBindPartial] +from functools import partial +from typing_extensions import ParamSpec +from typing import Callable, TypeVar + +P = ParamSpec("P") +T = TypeVar("T") + +def run(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, **kwargs) + return func2(*args) + +def run2(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, *args) + return func2(**kwargs) + +def run3(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, *args, **kwargs) + return func2() + +def run4(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, *args, **kwargs) + return func2(**kwargs) + +def run_bad(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, *args, **kwargs) + return func2(*args) # E: Too many arguments + +def run_bad2(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, **kwargs) + return func2(**kwargs) # E: Too few arguments + +def run_bad3(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, *args) + return func2() # E: Too few arguments + +[builtins fixtures/paramspec.pyi] + +[case testBindPartialConcatenate] +from functools import partial +from typing_extensions import Concatenate, ParamSpec +from typing import Callable, TypeVar + +P = ParamSpec("P") +T = TypeVar("T") + +def run(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, 1, **kwargs) + return func2(*args) + +def run2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, **kwargs) + return func2(1, *args) + +def run3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, 1, *args) + return func2(**kwargs) + +def run_bad(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, *args) # E: Argument 1 has incompatible type "*P.args"; expected "int" + return func2(1, **kwargs) # E: Too many arguments + +def run_bad2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, 1, *args) + return func2(1, **kwargs) # E: Too many arguments \ + # E: Argument 1 has incompatible type "int"; expected "P.args" + +def run_bad3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, 1, *args) + return func2(1, **kwargs) # E: Too many arguments \ + # E: Argument 1 has incompatible type "int"; expected "P.args" + +[builtins fixtures/paramspec.pyi] From 01390c0ba2bd3d034c00e4e919c56a202577fe27 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Wed, 25 Sep 2024 18:45:54 +0200 Subject: [PATCH 09/15] Fix lost error in test --- test-data/unit/check-parameter-specification.test | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index caf4f312e5f8..2ba2e0a623c0 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -2364,7 +2364,8 @@ def run3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwar def run_bad(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: func2 = partial(func, *args) # E: Argument 1 has incompatible type "*P.args"; expected "int" - return func2(1, **kwargs) # E: Too many arguments + return func2(1, **kwargs) # E: Too many arguments \ + # E: Argument 1 has incompatible type "int"; expected "P.args" def run_bad2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: func2 = partial(func, 1, *args) @@ -2376,4 +2377,6 @@ def run_bad3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P. return func2(1, **kwargs) # E: Too many arguments \ # E: Argument 1 has incompatible type "int"; expected "P.args" + + [builtins fixtures/paramspec.pyi] From bf60ec1792edaf450f1277871f9434b5d5c074c8 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Thu, 17 Oct 2024 23:09:39 +0200 Subject: [PATCH 10/15] Remove param_spec_args_bound - new version produces worse error messages, but is much less intrusive --- mypy/checkexpr.py | 22 ++++--------- mypy/plugins/functools.py | 32 +++++++++++++++---- mypy/types.py | 9 ------ .../unit/check-parameter-specification.test | 17 +++++----- 4 files changed, 41 insertions(+), 39 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 1427f0c9b27d..73fdbfeb7aba 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2377,21 +2377,13 @@ def check_argument_count( # Positional argument when expecting a keyword argument. self.msg.too_many_positional_arguments(callee, context) ok = False - elif callee.param_spec() is not None: - if ( - not formal_to_actual[i] - and not callee.param_spec_parts_bound[kind == ArgKind.ARG_STAR2] - and callee.special_sig != "partial" - ): - self.msg.too_few_arguments(callee, context, actual_names) - ok = False - elif ( - formal_to_actual[i] - and kind == ArgKind.ARG_STAR - and callee.param_spec_parts_bound[0] - ): - self.msg.too_many_arguments(callee, context) - ok = False + elif ( + callee.param_spec() is not None + and not formal_to_actual[i] + and callee.special_sig != "partial" + ): + self.msg.too_few_arguments(callee, context, actual_names) + ok = False return ok def check_for_extra_actual_arguments( diff --git a/mypy/plugins/functools.py b/mypy/plugins/functools.py index 1c6d10d22efb..ca0c226fa6cb 100644 --- a/mypy/plugins/functools.py +++ b/mypy/plugins/functools.py @@ -202,6 +202,7 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) - continue can_infer_ids.update({tv.id for tv in get_all_type_vars(arg_type)}) + # special_sig="partial" allows omission of args/kwargs typed with ParamSpec defaulted = fn_type.copy_modified( arg_kinds=[ ( @@ -297,14 +298,17 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) - arg_kinds=partial_kinds, arg_names=partial_names, ret_type=ret_type, - param_spec_parts_bound=( - ArgKind.ARG_STAR in actual_arg_kinds, - ArgKind.ARG_STAR2 in actual_arg_kinds, - ), + special_sig="partial", ) ret = ctx.api.named_generic_type(PARTIAL, [ret_type]) ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied) + if partially_applied.param_spec(): + ret = ret.copy_with_extra_attr( + "__mypy_partial_paramspec_args_bound", ArgKind.ARG_STAR in actual_arg_kinds + ).copy_with_extra_attr( + "__mypy_partial_paramspec_kwargs_bound", ArgKind.ARG_STAR2 in actual_arg_kinds + ) return ret @@ -319,7 +323,8 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type: ): return ctx.default_return_type - partial_type = ctx.type.extra_attrs.attrs["__mypy_partial"] + extra_attrs = ctx.type.extra_attrs.attrs + partial_type = extra_attrs["__mypy_partial"] if len(ctx.arg_types) != 2: # *args, **kwargs return ctx.default_return_type @@ -337,11 +342,24 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type: actual_arg_kinds.append(ctx.arg_kinds[i][j]) actual_arg_names.append(ctx.arg_names[i][j]) - result = ctx.api.expr_checker.check_call( + result, _ = ctx.api.expr_checker.check_call( callee=partial_type, args=actual_args, arg_kinds=actual_arg_kinds, arg_names=actual_arg_names, context=ctx.context, ) - return result[0] + args_bound = extra_attrs.get("__mypy_partial_paramspec_args_bound") + kwargs_bound = extra_attrs.get("__mypy_partial_paramspec_kwargs_bound") + if args_bound is None or kwargs_bound is None: + return result + # ensure *args: P.args + if not args_bound and ArgKind.ARG_STAR not in actual_arg_kinds: + ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names) + elif args_bound and ArgKind.ARG_STAR in actual_arg_kinds: + ctx.api.expr_checker.msg.too_many_arguments(partial_type, ctx.context) + # ensure **kwargs: P.kwargs + if not kwargs_bound and ArgKind.ARG_STAR2 not in actual_arg_kinds: + ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names) + + return result diff --git a/mypy/types.py b/mypy/types.py index 02c5dc477edf..0b010ca9d593 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1841,7 +1841,6 @@ class CallableType(FunctionLike): # (this is used for error messages) "imprecise_arg_kinds", "unpack_kwargs", # Was an Unpack[...] with **kwargs used to define this callable? - "param_spec_parts_bound", # Hack for functools.partial: allow early binding ) def __init__( @@ -1868,7 +1867,6 @@ def __init__( from_concatenate: bool = False, imprecise_arg_kinds: bool = False, unpack_kwargs: bool = False, - param_spec_parts_bound: tuple[bool, bool] = (False, False), ) -> None: super().__init__(line, column) assert len(arg_types) == len(arg_kinds) == len(arg_names) @@ -1895,7 +1893,6 @@ def __init__( self.from_type_type = from_type_type self.from_concatenate = from_concatenate self.imprecise_arg_kinds = imprecise_arg_kinds - self.param_spec_parts_bound = param_spec_parts_bound if not bound_args: bound_args = () self.bound_args = bound_args @@ -1942,7 +1939,6 @@ def copy_modified( from_concatenate: Bogus[bool] = _dummy, imprecise_arg_kinds: Bogus[bool] = _dummy, unpack_kwargs: Bogus[bool] = _dummy, - param_spec_parts_bound: Bogus[tuple[bool, bool]] = _dummy, ) -> CT: modified = CallableType( arg_types=arg_types if arg_types is not _dummy else self.arg_types, @@ -1974,11 +1970,6 @@ def copy_modified( else self.imprecise_arg_kinds ), unpack_kwargs=unpack_kwargs if unpack_kwargs is not _dummy else self.unpack_kwargs, - param_spec_parts_bound=( - param_spec_parts_bound - if param_spec_parts_bound is not _dummy - else self.param_spec_parts_bound - ), ) # Optimization: Only NewTypes are supported as subtypes since # the class is effectively final, so we can use a cast safely. diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index 2ba2e0a623c0..0569a6dae194 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -2356,27 +2356,28 @@ def run(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwarg def run2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: func2 = partial(func, **kwargs) + p = [""] + func2(1, *p) # E: Argument 2 has incompatible type "*List[str]"; expected "P.args" + func2(1, 2, *p) # E: Argument 2 has incompatible type "int"; expected "P.args" \ + # E: Argument 3 has incompatible type "*List[str]"; expected "P.args" return func2(1, *args) def run3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: func2 = partial(func, 1, *args) + d = {"":""} + func2(**d) # E: Argument 1 has incompatible type "**Dict[str, str]"; expected "P.kwargs" return func2(**kwargs) def run_bad(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: func2 = partial(func, *args) # E: Argument 1 has incompatible type "*P.args"; expected "int" - return func2(1, **kwargs) # E: Too many arguments \ - # E: Argument 1 has incompatible type "int"; expected "P.args" + return func2(1, **kwargs) # E: Argument 1 has incompatible type "int"; expected "P.args" def run_bad2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: func2 = partial(func, 1, *args) - return func2(1, **kwargs) # E: Too many arguments \ - # E: Argument 1 has incompatible type "int"; expected "P.args" + return func2(1, **kwargs) # E: Argument 1 has incompatible type "int"; expected "P.args" def run_bad3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: func2 = partial(func, 1, *args) - return func2(1, **kwargs) # E: Too many arguments \ - # E: Argument 1 has incompatible type "int"; expected "P.args" - - + return func2(1, **kwargs) # E: Argument 1 has incompatible type "int"; expected "P.args" [builtins fixtures/paramspec.pyi] From a660f1f45de17cd60a6c0ba08b255f1a45147102 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Thu, 17 Oct 2024 23:12:48 +0200 Subject: [PATCH 11/15] Add test scenario --- test-data/unit/check-parameter-specification.test | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index 0569a6dae194..4bc2bbc25997 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -2360,6 +2360,8 @@ def run2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwar func2(1, *p) # E: Argument 2 has incompatible type "*List[str]"; expected "P.args" func2(1, 2, *p) # E: Argument 2 has incompatible type "int"; expected "P.args" \ # E: Argument 3 has incompatible type "*List[str]"; expected "P.args" + func2(1, *args, *p) # E: Argument 3 has incompatible type "*List[str]"; expected "P.args" + func2(1, *p, *args) # E: Argument 2 has incompatible type "*List[str]"; expected "P.args" return func2(1, *args) def run3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: @@ -2380,4 +2382,5 @@ def run_bad3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P. func2 = partial(func, 1, *args) return func2(1, **kwargs) # E: Argument 1 has incompatible type "int"; expected "P.args" + [builtins fixtures/paramspec.pyi] From 5d46d85862af146f00bf4bee5a44bee402d5a40f Mon Sep 17 00:00:00 2001 From: STerliakov Date: Thu, 17 Oct 2024 23:38:48 +0200 Subject: [PATCH 12/15] Fix typing, use `immutable` for storing binding information --- mypy/plugins/functools.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/mypy/plugins/functools.py b/mypy/plugins/functools.py index ca0c226fa6cb..a18f38bf6918 100644 --- a/mypy/plugins/functools.py +++ b/mypy/plugins/functools.py @@ -304,11 +304,13 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) - ret = ctx.api.named_generic_type(PARTIAL, [ret_type]) ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied) if partially_applied.param_spec(): - ret = ret.copy_with_extra_attr( - "__mypy_partial_paramspec_args_bound", ArgKind.ARG_STAR in actual_arg_kinds - ).copy_with_extra_attr( - "__mypy_partial_paramspec_kwargs_bound", ArgKind.ARG_STAR2 in actual_arg_kinds - ) + assert ret.extra_attrs is not None # copy_with_extra_attr above ensures this + attrs = ret.extra_attrs.copy() + if ArgKind.ARG_STAR in actual_arg_kinds: + attrs.immutable.add("__mypy_partial_paramspec_args_bound") + if ArgKind.ARG_STAR2 in actual_arg_kinds: + attrs.immutable.add("__mypy_partial_paramspec_kwargs_bound") + ret.extra_attrs = attrs return ret @@ -323,8 +325,8 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type: ): return ctx.default_return_type - extra_attrs = ctx.type.extra_attrs.attrs - partial_type = extra_attrs["__mypy_partial"] + extra_attrs = ctx.type.extra_attrs + partial_type = get_proper_type(extra_attrs.attrs["__mypy_partial"]) if len(ctx.arg_types) != 2: # *args, **kwargs return ctx.default_return_type @@ -349,10 +351,11 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type: arg_names=actual_arg_names, context=ctx.context, ) - args_bound = extra_attrs.get("__mypy_partial_paramspec_args_bound") - kwargs_bound = extra_attrs.get("__mypy_partial_paramspec_kwargs_bound") - if args_bound is None or kwargs_bound is None: + if not isinstance(partial_type, CallableType) or partial_type.param_spec() is None: return result + + args_bound = "__mypy_partial_paramspec_args_bound" in extra_attrs.immutable + kwargs_bound = "__mypy_partial_paramspec_kwargs_bound" in extra_attrs.immutable # ensure *args: P.args if not args_bound and ArgKind.ARG_STAR not in actual_arg_kinds: ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names) From 4d55122c0651069c78662ab7046f3cb291f141c8 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Thu, 24 Oct 2024 20:07:38 +0200 Subject: [PATCH 13/15] Fix duplicated testcase --- .../unit/check-parameter-specification.test | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index 4bc2bbc25997..f5b77f181186 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -2370,17 +2370,38 @@ def run3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwar func2(**d) # E: Argument 1 has incompatible type "**Dict[str, str]"; expected "P.kwargs" return func2(**kwargs) +def run4(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, 1) + return func2(*args, **kwargs) + +def run5(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, 1, *args, **kwargs) + func2() + return func2(**kwargs) + def run_bad(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: func2 = partial(func, *args) # E: Argument 1 has incompatible type "*P.args"; expected "int" return func2(1, **kwargs) # E: Argument 1 has incompatible type "int"; expected "P.args" def run_bad2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: func2 = partial(func, 1, *args) + func2() # E: Too few arguments + func2(*args, **kwargs) # E: Too many arguments return func2(1, **kwargs) # E: Argument 1 has incompatible type "int"; expected "P.args" def run_bad3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: - func2 = partial(func, 1, *args) - return func2(1, **kwargs) # E: Argument 1 has incompatible type "int"; expected "P.args" - + func2 = partial(func, 1, **kwargs) + func2() # E: Too few arguments + return func2(1, *args) # E: Argument 1 has incompatible type "int"; expected "P.args" + +def run_bad4(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, 1) + func2() # E: Too few arguments + func2(*args) # E: Too few arguments + func2(1, *args) # E: Too few arguments \ + # E: Argument 1 has incompatible type "int"; expected "P.args" + func2(1, **kwargs) # E: Too few arguments \ + # E: Argument 1 has incompatible type "int"; expected "P.args" + return func2(**kwargs) # E: Too few arguments [builtins fixtures/paramspec.pyi] From aa1391ca9890d0d0ad9a01af74f7946a147e2777 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Fri, 25 Oct 2024 17:22:11 +0200 Subject: [PATCH 14/15] Replace naive check by arg kinds with more robust type-aware check --- mypy/plugins/functools.py | 28 +++++++++++++++---- .../unit/check-parameter-specification.test | 28 +++++++++++++++++-- 2 files changed, 48 insertions(+), 8 deletions(-) diff --git a/mypy/plugins/functools.py b/mypy/plugins/functools.py index a18f38bf6918..045578e76387 100644 --- a/mypy/plugins/functools.py +++ b/mypy/plugins/functools.py @@ -8,7 +8,7 @@ import mypy.plugin import mypy.semanal from mypy.argmap import map_actuals_to_formals -from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, CallExpr, FuncItem, Var +from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, CallExpr, FuncItem, NameExpr, Var from mypy.plugins.common import add_method_to_class from mypy.typeops import get_all_type_vars from mypy.types import ( @@ -16,6 +16,8 @@ CallableType, Instance, Overloaded, + ParamSpecFlavor, + ParamSpecType, Type, TypeOfAny, TypeVarType, @@ -344,7 +346,7 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type: actual_arg_kinds.append(ctx.arg_kinds[i][j]) actual_arg_names.append(ctx.arg_names[i][j]) - result, _ = ctx.api.expr_checker.check_call( + result, inf = ctx.api.expr_checker.check_call( callee=partial_type, args=actual_args, arg_kinds=actual_arg_kinds, @@ -356,13 +358,29 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type: args_bound = "__mypy_partial_paramspec_args_bound" in extra_attrs.immutable kwargs_bound = "__mypy_partial_paramspec_kwargs_bound" in extra_attrs.immutable + # ensure *args: P.args - if not args_bound and ArgKind.ARG_STAR not in actual_arg_kinds: + args_passed = any( + isinstance(arg, NameExpr) + and isinstance(arg.node, Var) + and isinstance(arg.node.type, ParamSpecType) + and arg.node.type.flavor == ParamSpecFlavor.ARGS + for arg in actual_args + ) + if not args_bound and not args_passed: ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names) - elif args_bound and ArgKind.ARG_STAR in actual_arg_kinds: + elif args_bound and args_passed: ctx.api.expr_checker.msg.too_many_arguments(partial_type, ctx.context) + # ensure **kwargs: P.kwargs - if not kwargs_bound and ArgKind.ARG_STAR2 not in actual_arg_kinds: + kwargs_passed = any( + isinstance(arg, NameExpr) + and isinstance(arg.node, Var) + and isinstance(arg.node.type, ParamSpecType) + and arg.node.type.flavor == ParamSpecFlavor.KWARGS + for arg in actual_args + ) + if not kwargs_bound and not kwargs_passed: ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names) return result diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index 059a1efa5428..674e3894940b 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -2392,8 +2392,10 @@ def run(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwarg def run2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: func2 = partial(func, **kwargs) p = [""] - func2(1, *p) # E: Argument 2 has incompatible type "*List[str]"; expected "P.args" - func2(1, 2, *p) # E: Argument 2 has incompatible type "int"; expected "P.args" \ + func2(1, *p) # E: Too few arguments \ + # E: Argument 2 has incompatible type "*List[str]"; expected "P.args" + func2(1, 2, *p) # E: Too few arguments \ + # E: Argument 2 has incompatible type "int"; expected "P.args" \ # E: Argument 3 has incompatible type "*List[str]"; expected "P.args" func2(1, *args, *p) # E: Argument 3 has incompatible type "*List[str]"; expected "P.args" func2(1, *p, *args) # E: Argument 2 has incompatible type "*List[str]"; expected "P.args" @@ -2402,7 +2404,8 @@ def run2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwar def run3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: func2 = partial(func, 1, *args) d = {"":""} - func2(**d) # E: Argument 1 has incompatible type "**Dict[str, str]"; expected "P.kwargs" + func2(**d) # E: Too few arguments \ + # E: Argument 1 has incompatible type "**Dict[str, str]"; expected "P.kwargs" return func2(**kwargs) def run4(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T: @@ -2440,3 +2443,22 @@ def run_bad4(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P. return func2(**kwargs) # E: Too few arguments [builtins fixtures/paramspec.pyi] + +[case testOtherVarArgs] +from functools import partial +from typing_extensions import Concatenate, ParamSpec +from typing import Callable, TypeVar, Tuple + +P = ParamSpec("P") +T = TypeVar("T") + +def run(func: Callable[Concatenate[int, str, P], T], *args: P.args, **kwargs: P.kwargs) -> T: + func2 = partial(func, **kwargs) + args_prefix: Tuple[int, str] = (1, 'a') + func2(*args_prefix) # E: Too few arguments + func2(*args, *args_prefix) # E: Argument 1 has incompatible type "*P.args"; expected "int" \ + # E: Argument 1 has incompatible type "*P.args"; expected "str" \ + # E: Argument 2 has incompatible type "*Tuple[int, str]"; expected "P.args" + return func2(*args_prefix, *args) + +[builtins fixtures/paramspec.pyi] From 31a74927b7beab3003f8702f844c8727b905f8a5 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Fri, 25 Oct 2024 17:29:37 +0200 Subject: [PATCH 15/15] Deduplicate isinstance() checks --- mypy/plugins/functools.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/mypy/plugins/functools.py b/mypy/plugins/functools.py index 045578e76387..6a063174bfcb 100644 --- a/mypy/plugins/functools.py +++ b/mypy/plugins/functools.py @@ -346,7 +346,7 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type: actual_arg_kinds.append(ctx.arg_kinds[i][j]) actual_arg_names.append(ctx.arg_names[i][j]) - result, inf = ctx.api.expr_checker.check_call( + result, _ = ctx.api.expr_checker.check_call( callee=partial_type, args=actual_args, arg_kinds=actual_arg_kinds, @@ -359,27 +359,22 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type: args_bound = "__mypy_partial_paramspec_args_bound" in extra_attrs.immutable kwargs_bound = "__mypy_partial_paramspec_kwargs_bound" in extra_attrs.immutable - # ensure *args: P.args - args_passed = any( - isinstance(arg, NameExpr) + passed_paramspec_parts = [ + arg.node.type + for arg in actual_args + if isinstance(arg, NameExpr) and isinstance(arg.node, Var) and isinstance(arg.node.type, ParamSpecType) - and arg.node.type.flavor == ParamSpecFlavor.ARGS - for arg in actual_args - ) + ] + # ensure *args: P.args + args_passed = any(part.flavor == ParamSpecFlavor.ARGS for part in passed_paramspec_parts) if not args_bound and not args_passed: ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names) elif args_bound and args_passed: ctx.api.expr_checker.msg.too_many_arguments(partial_type, ctx.context) # ensure **kwargs: P.kwargs - kwargs_passed = any( - isinstance(arg, NameExpr) - and isinstance(arg.node, Var) - and isinstance(arg.node.type, ParamSpecType) - and arg.node.type.flavor == ParamSpecFlavor.KWARGS - for arg in actual_args - ) + kwargs_passed = any(part.flavor == ParamSpecFlavor.KWARGS for part in passed_paramspec_parts) if not kwargs_bound and not kwargs_passed: ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names)