Skip to content

Commit

Permalink
Support ParamSpec + functools.partial
Browse files Browse the repository at this point in the history
  • Loading branch information
sterliakov committed Jun 10, 2024
1 parent 63995e3 commit acb3548
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 6 deletions.
18 changes: 15 additions & 3 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2331,9 +2331,21 @@ def check_argument_count(
# Positional argument when expecting a keyword argument.
self.msg.too_many_positional_arguments(callee, context)
ok = False
elif callee.param_spec() is not None and not formal_to_actual[i]:
self.msg.too_few_arguments(callee, context, actual_names)
ok = False
elif callee.param_spec() is not None:
if (
not formal_to_actual[i]
and not callee.param_spec_parts_bound[kind == ArgKind.ARG_STAR2]
and callee.special_sig != "partial"
):
self.msg.too_few_arguments(callee, context, actual_names)
ok = False
elif (
formal_to_actual[i]
and kind == ArgKind.ARG_STAR
and callee.param_spec_parts_bound[0]
):
self.msg.too_many_arguments(callee, context)
ok = False
return ok

def check_for_extra_actual_arguments(
Expand Down
7 changes: 6 additions & 1 deletion mypy/plugins/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
)
for k in fn_type.arg_kinds
]
],
special_sig="partial",
)
if defaulted.line < 0:
# Make up a line number if we don't have one
Expand Down Expand Up @@ -208,6 +209,10 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
arg_kinds=partial_kinds,
arg_names=partial_names,
ret_type=ret_type,
param_spec_parts_bound=(
ArgKind.ARG_STAR in actual_arg_kinds,
ArgKind.ARG_STAR2 in actual_arg_kinds,
),
)

ret = ctx.api.named_generic_type("functools.partial", [ret_type])
Expand Down
13 changes: 11 additions & 2 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1773,8 +1773,8 @@ class CallableType(FunctionLike):
"implicit", # Was this type implicitly generated instead of explicitly
# specified by the user?
"special_sig", # Non-None for signatures that require special handling
# (currently only value is 'dict' for a signature similar to
# 'dict')
# (currently only values are 'dict' for a signature similar to
# 'dict' and 'partial' for a `functools.partial` evaluation)
"from_type_type", # Was this callable generated by analyzing Type[...]
# instantiation?
"bound_args", # Bound type args, mostly unused but may be useful for
Expand All @@ -1787,6 +1787,7 @@ class CallableType(FunctionLike):
# (this is used for error messages)
"imprecise_arg_kinds",
"unpack_kwargs", # Was an Unpack[...] with **kwargs used to define this callable?
"param_spec_parts_bound", # Hack for functools.partial: allow early binding
)

def __init__(
Expand All @@ -1813,6 +1814,7 @@ def __init__(
from_concatenate: bool = False,
imprecise_arg_kinds: bool = False,
unpack_kwargs: bool = False,
param_spec_parts_bound: tuple[bool, bool] = (False, False),
) -> None:
super().__init__(line, column)
assert len(arg_types) == len(arg_kinds) == len(arg_names)
Expand All @@ -1839,6 +1841,7 @@ def __init__(
self.from_type_type = from_type_type
self.from_concatenate = from_concatenate
self.imprecise_arg_kinds = imprecise_arg_kinds
self.param_spec_parts_bound = param_spec_parts_bound
if not bound_args:
bound_args = ()
self.bound_args = bound_args
Expand Down Expand Up @@ -1885,6 +1888,7 @@ def copy_modified(
from_concatenate: Bogus[bool] = _dummy,
imprecise_arg_kinds: Bogus[bool] = _dummy,
unpack_kwargs: Bogus[bool] = _dummy,
param_spec_parts_bound: Bogus[tuple[bool, bool]] = _dummy,
) -> CT:
modified = CallableType(
arg_types=arg_types if arg_types is not _dummy else self.arg_types,
Expand Down Expand Up @@ -1916,6 +1920,11 @@ def copy_modified(
else self.imprecise_arg_kinds
),
unpack_kwargs=unpack_kwargs if unpack_kwargs is not _dummy else self.unpack_kwargs,
param_spec_parts_bound=(
param_spec_parts_bound
if param_spec_parts_bound is not _dummy
else self.param_spec_parts_bound
),
)
# Optimization: Only NewTypes are supported as subtypes since
# the class is effectively final, so we can use a cast safely.
Expand Down
74 changes: 74 additions & 0 deletions test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -2315,3 +2315,77 @@ reveal_type(capture(fn)) # N: Revealed type is "Union[builtins.str, builtins.in
reveal_type(capture(err)) # N: Revealed type is "builtins.int"

[builtins fixtures/paramspec.pyi]

[case testBindPartial]
from functools import partial
from typing_extensions import ParamSpec
from typing import Callable, TypeVar

P = ParamSpec("P")
T = TypeVar("T")

def run(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, **kwargs)
return func2(*args)

def run2(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, *args)
return func2(**kwargs)

def run3(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, *args, **kwargs)
return func2()

def run4(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, *args, **kwargs)
return func2(**kwargs)

def run_bad(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, *args, **kwargs)
return func2(*args) # E: Too many arguments

def run_bad2(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, **kwargs)
return func2(**kwargs) # E: Too few arguments

def run_bad3(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, *args)
return func2() # E: Too few arguments

[builtins fixtures/paramspec.pyi]

[case testBindPartialConcatenate]
from functools import partial
from typing_extensions import Concatenate, ParamSpec
from typing import Callable, TypeVar

P = ParamSpec("P")
T = TypeVar("T")

def run(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, 1, **kwargs)
return func2(*args)

def run2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, **kwargs)
return func2(1, *args)

def run3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, 1, *args)
return func2(**kwargs)

def run_bad(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, *args) # E: Argument 1 has incompatible type "*P.args"; expected "int"
return func2(1, **kwargs) # E: Too many arguments

def run_bad2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, 1, *args)
return func2(1, **kwargs) # E: Too many arguments \
# E: Argument 1 has incompatible type "int"; expected "P.args"

def run_bad3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, 1, *args)
return func2(1, **kwargs) # E: Too many arguments \
# E: Argument 1 has incompatible type "int"; expected "P.args"

[builtins fixtures/paramspec.pyi]

0 comments on commit acb3548

Please sign in to comment.