diff --git a/mypy/constraints.py b/mypy/constraints.py index 58d0f4dbed29..7d782551b261 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -692,11 +692,8 @@ def visit_parameters(self, template: Parameters) -> list[Constraint]: return self.infer_against_any(template.arg_types, self.actual) if type_state.infer_polymorphic and isinstance(self.actual, Parameters): # For polymorphic inference we need to be able to infer secondary constraints - # in situations like [x: T] <: P <: [x: int]. Note we invert direction, since - # this function expects direction between callables. - return infer_callable_arguments_constraints( - template, self.actual, neg_op(self.direction) - ) + # in situations like [x: T] <: P <: [x: int]. + return infer_callable_arguments_constraints(template, self.actual, self.direction) raise RuntimeError("Parameters cannot be constrained to") # Non-leaf types @@ -1128,7 +1125,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: ) ) if param_spec_target is not None: - res.append(Constraint(param_spec, neg_op(self.direction), param_spec_target)) + res.append(Constraint(param_spec, self.direction, param_spec_target)) if extra_tvars: for c in res: c.extra_tvars += cactual.variables diff --git a/mypy/join.py b/mypy/join.py index e4429425d98a..2e2939f9fbc8 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -350,10 +350,13 @@ def visit_parameters(self, t: Parameters) -> ProperType: if isinstance(self.s, Parameters): if len(t.arg_types) != len(self.s.arg_types): return self.default(self.s) + from mypy.meet import meet_types + return t.copy_modified( - # Note that since during constraint inference we already treat whole ParamSpec as - # contravariant, we should join individual items, not meet them like for Callables - arg_types=[join_types(s_a, t_a) for s_a, t_a in zip(self.s.arg_types, t.arg_types)] + arg_types=[ + meet_types(s_a, t_a) for s_a, t_a in zip(self.s.arg_types, t.arg_types) + ], + arg_names=combine_arg_names(self.s, t), ) else: return self.default(self.s) @@ -754,7 +757,9 @@ def combine_similar_callables(t: CallableType, s: CallableType) -> CallableType: ) -def combine_arg_names(t: CallableType, s: CallableType) -> list[str | None]: +def combine_arg_names( + t: CallableType | Parameters, s: CallableType | Parameters +) -> list[str | None]: """Produces a list of argument names compatible with both callables. For example, suppose 't' and 's' have the following signatures: diff --git a/mypy/meet.py b/mypy/meet.py index 0fa500d32c30..e76274456f91 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -706,10 +706,10 @@ def visit_parameters(self, t: Parameters) -> ProperType: if isinstance(self.s, Parameters): if len(t.arg_types) != len(self.s.arg_types): return self.default(self.s) + from mypy.join import join_types + return t.copy_modified( - # Note that since during constraint inference we already treat whole ParamSpec as - # contravariant, we should meet individual items, not join them like for Callables - arg_types=[meet_types(s_a, t_a) for s_a, t_a in zip(self.s.arg_types, t.arg_types)] + arg_types=[join_types(s_a, t_a) for s_a, t_a in zip(self.s.arg_types, t.arg_types)] ) else: return self.default(self.s) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 638553883dd8..77947cb086ec 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -654,8 +654,6 @@ def visit_unpack_type(self, left: UnpackType) -> bool: def visit_parameters(self, left: Parameters) -> bool: if isinstance(self.right, Parameters): - # TODO: direction here should be opposite, this function expects - # order of callables, while parameters are contravariant. return are_parameters_compatible( left, self.right, diff --git a/mypy/types.py b/mypy/types.py index d0c19a08e60a..6b42211836c0 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1552,7 +1552,10 @@ class FormalArgument(NamedTuple): class Parameters(ProperType): """Type that represents the parameters to a function. - Used for ParamSpec analysis.""" + Used for ParamSpec analysis. Note that by convention we handle this + type as a Callable without return type, not as a "tuple with names", + so that it behaves contravariantly, in particular [x: int] <: [int]. + """ __slots__ = ( "arg_types", diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index 48fadbc96c90..db8c76fd21e9 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -1403,7 +1403,7 @@ def wrong_name_constructor(b: bool) -> SomeClass: func(SomeClass, constructor) reveal_type(func(SomeClass, wrong_constructor)) # N: Revealed type is "def (a: Never) -> __main__.SomeClass" reveal_type(func_regular(SomeClass, wrong_constructor)) # N: Revealed type is "def (Never) -> __main__.SomeClass" -func(SomeClass, wrong_name_constructor) # E: Argument 1 to "func" has incompatible type "Type[SomeClass]"; expected "Callable[[Never], SomeClass]" +reveal_type(func(SomeClass, wrong_name_constructor)) # N: Revealed type is "def (Never) -> __main__.SomeClass" [builtins fixtures/paramspec.pyi] [case testParamSpecInTypeAliasBasic] @@ -2059,3 +2059,30 @@ def test2(x: int, y: int) -> str: ... reveal_type(call(test1, 1)) # N: Revealed type is "builtins.str" reveal_type(call(test2, 1, 2)) # N: Revealed type is "builtins.str" [builtins fixtures/paramspec.pyi] + +[case testParamSpecCorrectParameterNameInference] +from typing import Callable, Protocol +from typing_extensions import ParamSpec, Concatenate + +def a(i: int) -> None: ... +def b(__i: int) -> None: ... + +class WithName(Protocol): + def __call__(self, i: int) -> None: ... +NoName = Callable[[int], None] + +def f1(__fn: WithName, i: int) -> None: ... +def f2(__fn: NoName, i: int) -> None: ... + +P = ParamSpec("P") +def d(f: Callable[P, None], fn: Callable[Concatenate[Callable[P, None], P], None]) -> Callable[P, None]: + def inner(*args: P.args, **kwargs: P.kwargs) -> None: + fn(f, *args, **kwargs) + return inner + +reveal_type(d(a, f1)) # N: Revealed type is "def (i: builtins.int)" +reveal_type(d(a, f2)) # N: Revealed type is "def (i: builtins.int)" +reveal_type(d(b, f1)) # E: Cannot infer type argument 1 of "d" \ + # N: Revealed type is "def (*Any, **Any)" +reveal_type(d(b, f2)) # N: Revealed type is "def (builtins.int)" +[builtins fixtures/paramspec.pyi]