Skip to content

Commit

Permalink
Reconsider constraints involving parameter specifications (#15272)
Browse files Browse the repository at this point in the history
- Fixes #15037
- Fixes #15065
- Fixes #15073
- Fixes #15388
- Fixes #15086

Yet another part of #14903 that's
finally been extracted!
  • Loading branch information
A5rocks authored Aug 9, 2023
1 parent 5617cdd commit 2aaeda4
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 24 deletions.
129 changes: 106 additions & 23 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,19 @@ def __repr__(self) -> str:
op_str = "<:"
if self.op == SUPERTYPE_OF:
op_str = ":>"
return f"{self.type_var} {op_str} {self.target}"
return f"{self.origin_type_var} {op_str} {self.target}"

def __hash__(self) -> int:
return hash((self.type_var, self.op, self.target))
return hash((self.origin_type_var, self.op, self.target))

def __eq__(self, other: object) -> bool:
if not isinstance(other, Constraint):
return False
return (self.type_var, self.op, self.target) == (other.type_var, other.op, other.target)
return (self.origin_type_var, self.op, self.target) == (
other.origin_type_var,
other.op,
other.target,
)


def infer_constraints_for_callable(
Expand Down Expand Up @@ -698,25 +702,54 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
)
elif isinstance(tvar, ParamSpecType) and isinstance(mapped_arg, ParamSpecType):
suffix = get_proper_type(instance_arg)
prefix = mapped_arg.prefix
length = len(prefix.arg_types)

if isinstance(suffix, CallableType):
prefix = mapped_arg.prefix
from_concat = bool(prefix.arg_types) or suffix.from_concatenate
suffix = suffix.copy_modified(from_concatenate=from_concat)

if isinstance(suffix, (Parameters, CallableType)):
# no such thing as variance for ParamSpecs
# TODO: is there a case I am missing?
# TODO: constraints between prefixes
prefix = mapped_arg.prefix
suffix = suffix.copy_modified(
suffix.arg_types[len(prefix.arg_types) :],
suffix.arg_kinds[len(prefix.arg_kinds) :],
suffix.arg_names[len(prefix.arg_names) :],
length = min(length, len(suffix.arg_types))

constrained_to = suffix.copy_modified(
suffix.arg_types[length:],
suffix.arg_kinds[length:],
suffix.arg_names[length:],
)
constrained_from = mapped_arg.copy_modified(
prefix=prefix.copy_modified(
prefix.arg_types[length:],
prefix.arg_kinds[length:],
prefix.arg_names[length:],
)
)
res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix))

res.append(Constraint(constrained_from, SUPERTYPE_OF, constrained_to))
res.append(Constraint(constrained_from, SUBTYPE_OF, constrained_to))
elif isinstance(suffix, ParamSpecType):
res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix))
suffix_prefix = suffix.prefix
length = min(length, len(suffix_prefix.arg_types))

constrained = suffix.copy_modified(
prefix=suffix_prefix.copy_modified(
suffix_prefix.arg_types[length:],
suffix_prefix.arg_kinds[length:],
suffix_prefix.arg_names[length:],
)
)
constrained_from = mapped_arg.copy_modified(
prefix=prefix.copy_modified(
prefix.arg_types[length:],
prefix.arg_kinds[length:],
prefix.arg_names[length:],
)
)

res.append(Constraint(constrained_from, SUPERTYPE_OF, constrained))
res.append(Constraint(constrained_from, SUBTYPE_OF, constrained))
else:
# This case should have been handled above.
assert not isinstance(tvar, TypeVarTupleType)
Expand Down Expand Up @@ -768,26 +801,56 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
template_arg, ParamSpecType
):
suffix = get_proper_type(mapped_arg)
prefix = template_arg.prefix
length = len(prefix.arg_types)

if isinstance(suffix, CallableType):
prefix = template_arg.prefix
from_concat = bool(prefix.arg_types) or suffix.from_concatenate
suffix = suffix.copy_modified(from_concatenate=from_concat)

# TODO: this is almost a copy-paste of code above: make this into a function
if isinstance(suffix, (Parameters, CallableType)):
# no such thing as variance for ParamSpecs
# TODO: is there a case I am missing?
# TODO: constraints between prefixes
prefix = template_arg.prefix
length = min(length, len(suffix.arg_types))

suffix = suffix.copy_modified(
suffix.arg_types[len(prefix.arg_types) :],
suffix.arg_kinds[len(prefix.arg_kinds) :],
suffix.arg_names[len(prefix.arg_names) :],
constrained_to = suffix.copy_modified(
suffix.arg_types[length:],
suffix.arg_kinds[length:],
suffix.arg_names[length:],
)
res.append(Constraint(template_arg, SUPERTYPE_OF, suffix))
constrained_from = template_arg.copy_modified(
prefix=prefix.copy_modified(
prefix.arg_types[length:],
prefix.arg_kinds[length:],
prefix.arg_names[length:],
)
)

res.append(Constraint(constrained_from, SUPERTYPE_OF, constrained_to))
res.append(Constraint(constrained_from, SUBTYPE_OF, constrained_to))
elif isinstance(suffix, ParamSpecType):
res.append(Constraint(template_arg, SUPERTYPE_OF, suffix))
suffix_prefix = suffix.prefix
length = min(length, len(suffix_prefix.arg_types))

constrained = suffix.copy_modified(
prefix=suffix_prefix.copy_modified(
suffix_prefix.arg_types[length:],
suffix_prefix.arg_kinds[length:],
suffix_prefix.arg_names[length:],
)
)
constrained_from = template_arg.copy_modified(
prefix=prefix.copy_modified(
prefix.arg_types[length:],
prefix.arg_kinds[length:],
prefix.arg_names[length:],
)
)

res.append(Constraint(constrained_from, SUPERTYPE_OF, constrained))
res.append(Constraint(constrained_from, SUBTYPE_OF, constrained))
else:
# This case should have been handled above.
assert not isinstance(tvar, TypeVarTupleType)
Expand Down Expand Up @@ -954,9 +1017,19 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
prefix_len = len(prefix.arg_types)
cactual_ps = cactual.param_spec()

cactual_prefix: Parameters | CallableType
if cactual_ps:
cactual_prefix = cactual_ps.prefix
else:
cactual_prefix = cactual

max_prefix_len = len(
[k for k in cactual_prefix.arg_kinds if k in (ARG_POS, ARG_OPT)]
)
prefix_len = min(prefix_len, max_prefix_len)

# we could check the prefixes match here, but that should be caught elsewhere.
if not cactual_ps:
max_prefix_len = len([k for k in cactual.arg_kinds if k in (ARG_POS, ARG_OPT)])
prefix_len = min(prefix_len, max_prefix_len)
res.append(
Constraint(
param_spec,
Expand All @@ -970,7 +1043,17 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
)
)
else:
res.append(Constraint(param_spec, SUBTYPE_OF, cactual_ps))
# earlier, cactual_prefix = cactual_ps.prefix. thus, this is guaranteed
assert isinstance(cactual_prefix, Parameters)

constrained_by = cactual_ps.copy_modified(
prefix=cactual_prefix.copy_modified(
cactual_prefix.arg_types[prefix_len:],
cactual_prefix.arg_kinds[prefix_len:],
cactual_prefix.arg_names[prefix_len:],
)
)
res.append(Constraint(param_spec, SUBTYPE_OF, constrained_by))

# compare prefixes
cactual_prefix = cactual.copy_modified(
Expand Down
62 changes: 62 additions & 0 deletions mypy/test/testconstraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,65 @@ def test_var_length_tuple_with_fixed_length_tuple(self) -> None:
Instance(fx.std_tuplei, [fx.a]),
SUPERTYPE_OF,
)

def test_paramspec_constrained_with_concatenate(self) -> None:
# for legibility (and my own understanding), `Tester.normal()` is `Tester[P]`
# and `Tester.concatenate()` is `Tester[Concatenate[A, P]]`
# ... and 2nd arg to infer_constraints ends up on LHS of equality
fx = self.fx

# I don't think we can parametrize...
for direction in (SUPERTYPE_OF, SUBTYPE_OF):
print(f"direction is {direction}")
# equiv to: x: Tester[Q] = Tester.normal()
assert set(
infer_constraints(Instance(fx.gpsi, [fx.p]), Instance(fx.gpsi, [fx.q]), direction)
) == {
Constraint(type_var=fx.p, op=SUPERTYPE_OF, target=fx.q),
Constraint(type_var=fx.p, op=SUBTYPE_OF, target=fx.q),
}

# equiv to: x: Tester[Q] = Tester.concatenate()
assert set(
infer_constraints(
Instance(fx.gpsi, [fx.p_concatenate]), Instance(fx.gpsi, [fx.q]), direction
)
) == {
Constraint(type_var=fx.p_concatenate, op=SUPERTYPE_OF, target=fx.q),
Constraint(type_var=fx.p_concatenate, op=SUBTYPE_OF, target=fx.q),
}

# equiv to: x: Tester[Concatenate[B, Q]] = Tester.normal()
assert set(
infer_constraints(
Instance(fx.gpsi, [fx.p]), Instance(fx.gpsi, [fx.q_concatenate]), direction
)
) == {
Constraint(type_var=fx.p, op=SUPERTYPE_OF, target=fx.q_concatenate),
Constraint(type_var=fx.p, op=SUBTYPE_OF, target=fx.q_concatenate),
}

# equiv to: x: Tester[Concatenate[B, Q]] = Tester.concatenate()
assert set(
infer_constraints(
Instance(fx.gpsi, [fx.p_concatenate]),
Instance(fx.gpsi, [fx.q_concatenate]),
direction,
)
) == {
# this is correct as we assume other parts of mypy will warn that [B] != [A]
Constraint(type_var=fx.p, op=SUPERTYPE_OF, target=fx.q),
Constraint(type_var=fx.p, op=SUBTYPE_OF, target=fx.q),
}

# equiv to: x: Tester[Concatenate[A, Q]] = Tester.concatenate()
assert set(
infer_constraints(
Instance(fx.gpsi, [fx.p_concatenate]),
Instance(fx.gpsi, [fx.q_concatenate]),
direction,
)
) == {
Constraint(type_var=fx.p, op=SUPERTYPE_OF, target=fx.q),
Constraint(type_var=fx.p, op=SUBTYPE_OF, target=fx.q),
}
42 changes: 42 additions & 0 deletions mypy/test/typefixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from __future__ import annotations

from typing import Sequence

from mypy.nodes import (
ARG_OPT,
ARG_POS,
Expand All @@ -26,6 +28,9 @@
Instance,
LiteralType,
NoneType,
Parameters,
ParamSpecFlavor,
ParamSpecType,
Type,
TypeAliasType,
TypeOfAny,
Expand Down Expand Up @@ -238,6 +243,31 @@ def make_type_var_tuple(name: str, id: int, upper_bound: Type) -> TypeVarTupleTy
"GV2", mro=[self.oi], typevars=["T", "Ts", "S"], typevar_tuple_index=1
)

def make_parameter_specification(
name: str, id: int, concatenate: Sequence[Type]
) -> ParamSpecType:
return ParamSpecType(
name,
name,
id,
ParamSpecFlavor.BARE,
self.o,
AnyType(TypeOfAny.from_omitted_generics),
prefix=Parameters(
concatenate, [ARG_POS for _ in concatenate], [None for _ in concatenate]
),
)

self.p = make_parameter_specification("P", 1, [])
self.p_concatenate = make_parameter_specification("P", 1, [self.a])
self.q = make_parameter_specification("Q", 2, [])
self.q_concatenate = make_parameter_specification("Q", 2, [self.b])
self.q_concatenate_a = make_parameter_specification("Q", 2, [self.a])

self.gpsi = self.make_type_info(
"GPS", mro=[self.oi], typevars=["P"], paramspec_indexes={0}
)

def _add_bool_dunder(self, type_info: TypeInfo) -> None:
signature = CallableType([], [], [], Instance(self.bool_type_info, []), self.function)
bool_func = FuncDef("__bool__", [], Block([]))
Expand Down Expand Up @@ -299,6 +329,7 @@ def make_type_info(
bases: list[Instance] | None = None,
typevars: list[str] | None = None,
typevar_tuple_index: int | None = None,
paramspec_indexes: set[int] | None = None,
variances: list[int] | None = None,
) -> TypeInfo:
"""Make a TypeInfo suitable for use in unit tests."""
Expand Down Expand Up @@ -326,6 +357,17 @@ def make_type_info(
AnyType(TypeOfAny.from_omitted_generics),
)
)
elif paramspec_indexes is not None and id - 1 in paramspec_indexes:
v.append(
ParamSpecType(
n,
n,
id,
ParamSpecFlavor.BARE,
self.o,
AnyType(TypeOfAny.from_omitted_generics),
)
)
else:
if variances:
variance = variances[id - 1]
Expand Down
32 changes: 31 additions & 1 deletion test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ _P = ParamSpec("_P")

class Job(Generic[_P]):
def __init__(self, target: Callable[_P, None]) -> None:
self.target = target
...

def func(
action: Union[Job[int], Callable[[int], None]],
Expand Down Expand Up @@ -1535,6 +1535,36 @@ def identity(func: Callable[P, None]) -> Callable[P, None]: ...
def f(f: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: ...
[builtins fixtures/paramspec.pyi]

[case testComplicatedParamSpecReturnType]
# regression test for https://github.com/python/mypy/issues/15073
from typing import TypeVar, Callable
from typing_extensions import ParamSpec, Concatenate

R = TypeVar("R")
P = ParamSpec("P")

def f(
) -> Callable[[Callable[Concatenate[Callable[P, R], P], R]], Callable[P, R]]:
def r(fn: Callable[Concatenate[Callable[P, R], P], R]) -> Callable[P, R]: ...
return r
[builtins fixtures/paramspec.pyi]

[case testParamSpecToParamSpecAssignment]
# minimized from https://github.com/python/mypy/issues/15037
# ~ the same as https://github.com/python/mypy/issues/15065
from typing import Callable
from typing_extensions import Concatenate, ParamSpec

P = ParamSpec("P")

def f(f: Callable[Concatenate[int, P], None]) -> Callable[P, None]: ...

x: Callable[
[Callable[Concatenate[int, P], None]],
Callable[P, None],
] = f
[builtins fixtures/paramspec.pyi]

[case testParamSpecDecoratorAppliedToGeneric]
# flags: --new-type-inference
from typing import Callable, List, TypeVar
Expand Down

0 comments on commit 2aaeda4

Please sign in to comment.