Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ParamSpec mapping with functools.partial #17355

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
5cdb753
Reject ParamSpec-typed callables calls with insufficient arguments
sterliakov Jun 4, 2024
3b2297f
Reuse params preprocessing logic for generic functions
sterliakov Jun 4, 2024
a32ad3f
Only perform deep expansion on overloads when ParamSpec is present
sterliakov Jun 4, 2024
63995e3
Tidy up code a bit
sterliakov Jun 10, 2024
be2c49b
Merge remote-tracking branch 'upstream/master' into bugfix/st-paramsp…
sterliakov Aug 21, 2024
0a658a7
Always pick ParamSpec-containing overloads as plausible candidates
sterliakov Aug 21, 2024
63f9438
Remove parameters thaat are no longer used
sterliakov Aug 21, 2024
786fb55
Merge branch 'master' into bugfix/st-paramspec-missing-args
sterliakov Sep 12, 2024
1903402
Undo style change to make it easier to review, feel free to add in se…
hauntsaninja Sep 24, 2024
512a722
Support ParamSpec + functools.partial
sterliakov Jun 10, 2024
558a35f
Merge remote-tracking branch 'upstream/master' into bugfix/st-paramsp…
sterliakov Sep 25, 2024
01390c0
Fix lost error in test
sterliakov Sep 25, 2024
b23efe8
Merge remote-tracking branch 'upstream/master' into bugfix/st-paramsp…
sterliakov Oct 17, 2024
bf60ec1
Remove param_spec_args_bound - new version produces worse error messa…
sterliakov Oct 17, 2024
a660f1f
Add test scenario
sterliakov Oct 17, 2024
5d46d85
Fix typing, use `immutable` for storing binding information
sterliakov Oct 17, 2024
bbaf9de
Merge remote-tracking branch 'upstream/master' into bugfix/st-paramsp…
sterliakov Oct 24, 2024
4d55122
Fix duplicated testcase
sterliakov Oct 24, 2024
b077063
Merge remote-tracking branch 'upstream/master' into bugfix/st-paramsp…
sterliakov Oct 25, 2024
aa1391c
Replace naive check by arg kinds with more robust type-aware check
sterliakov Oct 25, 2024
31a7492
Deduplicate isinstance() checks
sterliakov Oct 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2377,7 +2377,11 @@ def check_argument_count(
# Positional argument when expecting a keyword argument.
self.msg.too_many_positional_arguments(callee, context)
ok = False
elif callee.param_spec() is not None and not formal_to_actual[i]:
elif (
callee.param_spec() is not None
and not formal_to_actual[i]
and callee.special_sig != "partial"
):
self.msg.too_few_arguments(callee, context, actual_names)
ok = False
return ok
Expand Down
47 changes: 43 additions & 4 deletions mypy/plugins/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@
import mypy.plugin
import mypy.semanal
from mypy.argmap import map_actuals_to_formals
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, CallExpr, FuncItem, Var
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, CallExpr, FuncItem, NameExpr, Var
from mypy.plugins.common import add_method_to_class
from mypy.typeops import get_all_type_vars
from mypy.types import (
AnyType,
CallableType,
Instance,
Overloaded,
ParamSpecFlavor,
ParamSpecType,
Type,
TypeOfAny,
TypeVarType,
Expand Down Expand Up @@ -202,6 +204,7 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
continue
can_infer_ids.update({tv.id for tv in get_all_type_vars(arg_type)})

# special_sig="partial" allows omission of args/kwargs typed with ParamSpec
defaulted = fn_type.copy_modified(
arg_kinds=[
(
Expand All @@ -218,6 +221,7 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
# Keep TypeVarTuple/ParamSpec to avoid spurious errors on empty args.
if tv.id in can_infer_ids or not isinstance(tv, TypeVarType)
],
special_sig="partial",
)
if defaulted.line < 0:
# Make up a line number if we don't have one
Expand Down Expand Up @@ -296,10 +300,19 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
arg_kinds=partial_kinds,
arg_names=partial_names,
ret_type=ret_type,
special_sig="partial",
)

ret = ctx.api.named_generic_type(PARTIAL, [ret_type])
ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied)
if partially_applied.param_spec():
assert ret.extra_attrs is not None # copy_with_extra_attr above ensures this
attrs = ret.extra_attrs.copy()
if ArgKind.ARG_STAR in actual_arg_kinds:
attrs.immutable.add("__mypy_partial_paramspec_args_bound")
if ArgKind.ARG_STAR2 in actual_arg_kinds:
attrs.immutable.add("__mypy_partial_paramspec_kwargs_bound")
ret.extra_attrs = attrs
return ret


Expand All @@ -314,7 +327,8 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
):
return ctx.default_return_type

partial_type = ctx.type.extra_attrs.attrs["__mypy_partial"]
extra_attrs = ctx.type.extra_attrs
partial_type = get_proper_type(extra_attrs.attrs["__mypy_partial"])
if len(ctx.arg_types) != 2: # *args, **kwargs
return ctx.default_return_type

Expand All @@ -332,11 +346,36 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
actual_arg_kinds.append(ctx.arg_kinds[i][j])
actual_arg_names.append(ctx.arg_names[i][j])

result = ctx.api.expr_checker.check_call(
result, _ = ctx.api.expr_checker.check_call(
callee=partial_type,
args=actual_args,
arg_kinds=actual_arg_kinds,
arg_names=actual_arg_names,
context=ctx.context,
)
return result[0]
if not isinstance(partial_type, CallableType) or partial_type.param_spec() is None:
return result

args_bound = "__mypy_partial_paramspec_args_bound" in extra_attrs.immutable
kwargs_bound = "__mypy_partial_paramspec_kwargs_bound" in extra_attrs.immutable

passed_paramspec_parts = [
arg.node.type
for arg in actual_args
if isinstance(arg, NameExpr)
and isinstance(arg.node, Var)
and isinstance(arg.node.type, ParamSpecType)
]
# ensure *args: P.args
args_passed = any(part.flavor == ParamSpecFlavor.ARGS for part in passed_paramspec_parts)
if not args_bound and not args_passed:
ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names)
elif args_bound and args_passed:
ctx.api.expr_checker.msg.too_many_arguments(partial_type, ctx.context)

# ensure **kwargs: P.kwargs
kwargs_passed = any(part.flavor == ParamSpecFlavor.KWARGS for part in passed_paramspec_parts)
if not kwargs_bound and not kwargs_passed:
ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names)

return result
4 changes: 2 additions & 2 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1827,8 +1827,8 @@ class CallableType(FunctionLike):
"implicit", # Was this type implicitly generated instead of explicitly
# specified by the user?
"special_sig", # Non-None for signatures that require special handling
# (currently only value is 'dict' for a signature similar to
# 'dict')
# (currently only values are 'dict' for a signature similar to
# 'dict' and 'partial' for a `functools.partial` evaluation)
"from_type_type", # Was this callable generated by analyzing Type[...]
# instantiation?
"bound_args", # Bound type args, mostly unused but may be useful for
Expand Down
124 changes: 124 additions & 0 deletions test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -2338,3 +2338,127 @@ reveal_type(handle_reversed(Child())) # N: Revealed type is "builtins.str"
reveal_type(handle_reversed(NotChild())) # N: Revealed type is "builtins.str"

[builtins fixtures/paramspec.pyi]

[case testBindPartial]
from functools import partial
from typing_extensions import ParamSpec
from typing import Callable, TypeVar

P = ParamSpec("P")
T = TypeVar("T")

def run(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, **kwargs)
return func2(*args)

def run2(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, *args)
return func2(**kwargs)

def run3(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, *args, **kwargs)
return func2()

def run4(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, *args, **kwargs)
return func2(**kwargs)

def run_bad(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, *args, **kwargs)
return func2(*args) # E: Too many arguments

def run_bad2(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, **kwargs)
return func2(**kwargs) # E: Too few arguments

def run_bad3(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, *args)
return func2() # E: Too few arguments

[builtins fixtures/paramspec.pyi]

[case testBindPartialConcatenate]
from functools import partial
from typing_extensions import Concatenate, ParamSpec
from typing import Callable, TypeVar

P = ParamSpec("P")
T = TypeVar("T")

def run(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, 1, **kwargs)
return func2(*args)

def run2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, **kwargs)
p = [""]
func2(1, *p) # E: Too few arguments \
# E: Argument 2 has incompatible type "*List[str]"; expected "P.args"
func2(1, 2, *p) # E: Too few arguments \
# E: Argument 2 has incompatible type "int"; expected "P.args" \
# E: Argument 3 has incompatible type "*List[str]"; expected "P.args"
func2(1, *args, *p) # E: Argument 3 has incompatible type "*List[str]"; expected "P.args"
func2(1, *p, *args) # E: Argument 2 has incompatible type "*List[str]"; expected "P.args"
return func2(1, *args)

def run3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, 1, *args)
d = {"":""}
func2(**d) # E: Too few arguments \
# E: Argument 1 has incompatible type "**Dict[str, str]"; expected "P.kwargs"
return func2(**kwargs)

def run4(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, 1)
return func2(*args, **kwargs)

def run5(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, 1, *args, **kwargs)
func2()
return func2(**kwargs)

def run_bad(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, *args) # E: Argument 1 has incompatible type "*P.args"; expected "int"
return func2(1, **kwargs) # E: Argument 1 has incompatible type "int"; expected "P.args"

def run_bad2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, 1, *args)
func2() # E: Too few arguments
func2(*args, **kwargs) # E: Too many arguments
return func2(1, **kwargs) # E: Argument 1 has incompatible type "int"; expected "P.args"

def run_bad3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, 1, **kwargs)
func2() # E: Too few arguments
return func2(1, *args) # E: Argument 1 has incompatible type "int"; expected "P.args"

def run_bad4(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, 1)
func2() # E: Too few arguments
func2(*args) # E: Too few arguments
func2(1, *args) # E: Too few arguments \
# E: Argument 1 has incompatible type "int"; expected "P.args"
func2(1, **kwargs) # E: Too few arguments \
# E: Argument 1 has incompatible type "int"; expected "P.args"
return func2(**kwargs) # E: Too few arguments

[builtins fixtures/paramspec.pyi]

[case testOtherVarArgs]
from functools import partial
from typing_extensions import Concatenate, ParamSpec
from typing import Callable, TypeVar, Tuple

P = ParamSpec("P")
T = TypeVar("T")

def run(func: Callable[Concatenate[int, str, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, **kwargs)
args_prefix: Tuple[int, str] = (1, 'a')
func2(*args_prefix) # E: Too few arguments
func2(*args, *args_prefix) # E: Argument 1 has incompatible type "*P.args"; expected "int" \
# E: Argument 1 has incompatible type "*P.args"; expected "str" \
# E: Argument 2 has incompatible type "*Tuple[int, str]"; expected "P.args"
return func2(*args_prefix, *args)

[builtins fixtures/paramspec.pyi]
Loading