From 608de81d7617fd715f2f4bdca6b8c15c73caabbc Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 5 Aug 2022 17:33:19 +0100 Subject: [PATCH] Handle interactions between recursive aliases and recursive instances (#13328) This is a follow-up for #13297. The fix for infinite recursion is kind of simple, but it is hard to make inference infer something useful. Currently we handle all of the most common cases, but it is quite fragile (I do, however, have a few tricks left if people complain about inference). --- mypy/checkexpr.py | 32 ++++--- mypy/constraints.py | 35 ++++++- mypy/infer.py | 3 +- mypy/solve.py | 8 +- mypy/subtypes.py | 18 +--- mypy/typeops.py | 31 ++++-- mypy/typestate.py | 12 ++- test-data/unit/check-recursive-types.test | 110 ++++++++++++++++++++++ 8 files changed, 194 insertions(+), 55 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index aa6d8e63f5f7..0753ee80c113 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -154,6 +154,7 @@ is_optional, remove_optional, ) +from mypy.typestate import TypeState from mypy.typevars import fill_typevars from mypy.util import split_module_names from mypy.visitor import ExpressionVisitor @@ -1429,6 +1430,22 @@ def infer_arg_types_in_empty_context(self, args: List[Expression]) -> List[Type] res.append(arg_type) return res + @contextmanager + def allow_unions(self, type_context: Type) -> Iterator[None]: + # This is a hack to better support inference for recursive types. + # When the outer context for a function call is known to be recursive, + # we solve type constraints inferred from arguments using unions instead + # of joins. This is a bit arbitrary, but in practice it works for most + # cases. A cleaner alternative would be to switch to single bin type + # inference, but this is a lot of work. 
+ old = TypeState.infer_unions + if has_recursive_types(type_context): + TypeState.infer_unions = True + try: + yield + finally: + TypeState.infer_unions = old + def infer_arg_types_in_context( self, callee: CallableType, @@ -1448,7 +1465,8 @@ def infer_arg_types_in_context( for i, actuals in enumerate(formal_to_actual): for ai in actuals: if not arg_kinds[ai].is_star(): - res[ai] = self.accept(args[ai], callee.arg_types[i]) + with self.allow_unions(callee.arg_types[i]): + res[ai] = self.accept(args[ai], callee.arg_types[i]) # Fill in the rest of the argument types. for i, t in enumerate(res): @@ -1568,17 +1586,6 @@ def infer_function_type_arguments( else: pass1_args.append(arg) - # This is a hack to better support inference for recursive types. - # When the outer context for a function call is known to be recursive, - # we solve type constraints inferred from arguments using unions instead - # of joins. This is a bit arbitrary, but in practice it works for most - # cases. A cleaner alternative would be to switch to single bin type - # inference, but this is a lot of work. - ctx = self.type_context[-1] - if ctx and has_recursive_types(ctx): - infer_unions = True - else: - infer_unions = False inferred_args = infer_function_type_arguments( callee_type, pass1_args, @@ -1586,7 +1593,6 @@ def infer_function_type_arguments( formal_to_actual, context=self.argument_infer_context(), strict=self.chk.in_checked_function(), - infer_unions=infer_unions, ) if 2 in arg_pass_nums: diff --git a/mypy/constraints.py b/mypy/constraints.py index 0ca6a3e085f0..b4c3cf6f28c9 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -42,6 +42,8 @@ UnpackType, callable_with_ellipsis, get_proper_type, + has_recursive_types, + has_type_vars, is_named_instance, is_union_with_any, ) @@ -141,14 +143,19 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> List[Cons The constraints are represented as Constraint objects. 
""" if any( - get_proper_type(template) == get_proper_type(t) for t in reversed(TypeState._inferring) + get_proper_type(template) == get_proper_type(t) + and get_proper_type(actual) == get_proper_type(a) + for (t, a) in reversed(TypeState.inferring) ): return [] - if isinstance(template, TypeAliasType) and template.is_recursive: + if has_recursive_types(template): # This case requires special care because it may cause infinite recursion. - TypeState._inferring.append(template) + if not has_type_vars(template): + # Return early on an empty branch. + return [] + TypeState.inferring.append((template, actual)) res = _infer_constraints(template, actual, direction) - TypeState._inferring.pop() + TypeState.inferring.pop() return res return _infer_constraints(template, actual, direction) @@ -216,13 +223,18 @@ def _infer_constraints(template: Type, actual: Type, direction: int) -> List[Con # When the template is a union, we are okay with leaving some # type variables indeterminate. This helps with some special # cases, though this isn't very principled. - return any_constraints( + result = any_constraints( [ infer_constraints_if_possible(t_item, actual, direction) for t_item in template.items ], eager=False, ) + if result: + return result + elif has_recursive_types(template) and not has_recursive_types(actual): + return handle_recursive_union(template, actual, direction) + return [] # Remaining cases are handled by ConstraintBuilderVisitor. return template.accept(ConstraintBuilderVisitor(actual, direction)) @@ -279,6 +291,19 @@ def merge_with_any(constraint: Constraint) -> Constraint: ) +def handle_recursive_union(template: UnionType, actual: Type, direction: int) -> List[Constraint]: + # This is a hack to special-case things like Union[T, Inst[T]] in recursive types. Although + # it is quite arbitrary, it is a relatively common pattern, so we should handle it well. 
+ # This function may be called when inferring against such union resulted in different + # constraints for each item. Normally we give up in such case, but here we instead split + # the union in two parts, and try inferring sequentially. + non_type_var_items = [t for t in template.items if not isinstance(t, TypeVarType)] + type_var_items = [t for t in template.items if isinstance(t, TypeVarType)] + return infer_constraints( + UnionType.make_union(non_type_var_items), actual, direction + ) or infer_constraints(UnionType.make_union(type_var_items), actual, direction) + + def any_constraints(options: List[Optional[List[Constraint]]], eager: bool) -> List[Constraint]: """Deduce what we can from a collection of constraint lists. diff --git a/mypy/infer.py b/mypy/infer.py index 1c00d2904702..d3ad0bc19f9b 100644 --- a/mypy/infer.py +++ b/mypy/infer.py @@ -34,7 +34,6 @@ def infer_function_type_arguments( formal_to_actual: List[List[int]], context: ArgumentInferContext, strict: bool = True, - infer_unions: bool = False, ) -> List[Optional[Type]]: """Infer the type arguments of a generic function. @@ -56,7 +55,7 @@ def infer_function_type_arguments( # Solve constraints. type_vars = callee_type.type_var_ids() - return solve_constraints(type_vars, constraints, strict, infer_unions=infer_unions) + return solve_constraints(type_vars, constraints, strict) def infer_type_arguments( diff --git a/mypy/solve.py b/mypy/solve.py index 918308625742..90bbd5b9d3b5 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -17,13 +17,11 @@ UnionType, get_proper_type, ) +from mypy.typestate import TypeState def solve_constraints( - vars: List[TypeVarId], - constraints: List[Constraint], - strict: bool = True, - infer_unions: bool = False, + vars: List[TypeVarId], constraints: List[Constraint], strict: bool = True ) -> List[Optional[Type]]: """Solve type constraints. 
@@ -55,7 +53,7 @@ def solve_constraints( if bottom is None: bottom = c.target else: - if infer_unions: + if TypeState.infer_unions: # This deviates from the general mypy semantics because # recursive types are union-heavy in 95% of cases. bottom = UnionType.make_union([bottom, c.target]) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 5756c581e53a..5a8c5e38b2fa 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -145,14 +145,7 @@ def is_subtype( ), "Don't pass both context and individual flags" if TypeState.is_assumed_subtype(left, right): return True - if ( - # TODO: recursive instances like `class str(Sequence[str])` can also cause - # issues, so we also need to include them in the assumptions stack - isinstance(left, TypeAliasType) - and isinstance(right, TypeAliasType) - and left.is_recursive - and right.is_recursive - ): + if mypy.typeops.is_recursive_pair(left, right): # This case requires special care because it may cause infinite recursion. # Our view on recursive types is known under a fancy name of iso-recursive mu-types. # Roughly this means that a recursive type is defined as an alias where right hand side @@ -205,12 +198,7 @@ def is_proper_subtype( ), "Don't pass both context and individual flags" if TypeState.is_assumed_proper_subtype(left, right): return True - if ( - isinstance(left, TypeAliasType) - and isinstance(right, TypeAliasType) - and left.is_recursive - and right.is_recursive - ): + if mypy.typeops.is_recursive_pair(left, right): # Same as for non-proper subtype, see detailed comment there for explanation. 
with pop_on_exit(TypeState.get_assumptions(is_proper=True), left, right): return _is_subtype(left, right, subtype_context, proper_subtype=True) @@ -874,7 +862,7 @@ def visit_type_alias_type(self, left: TypeAliasType) -> bool: assert False, f"This should be never called, got {left}" -T = TypeVar("T", Instance, TypeAliasType) +T = TypeVar("T", bound=Type) @contextmanager diff --git a/mypy/typeops.py b/mypy/typeops.py index f7b14c710cc2..ef3ec1de24c9 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -63,13 +63,25 @@ def is_recursive_pair(s: Type, t: Type) -> bool: - """Is this a pair of recursive type aliases?""" - return ( - isinstance(s, TypeAliasType) - and isinstance(t, TypeAliasType) - and s.is_recursive - and t.is_recursive - ) + """Is this a pair of recursive types? + + There may be more cases, and we may be forced to use e.g. has_recursive_types() + here, but this function is called in very hot code, so we try to keep it simple + and return True only in cases we know may have problems. 
+ """ + if isinstance(s, TypeAliasType) and s.is_recursive: + return ( + isinstance(get_proper_type(t), Instance) + or isinstance(t, TypeAliasType) + and t.is_recursive + ) + if isinstance(t, TypeAliasType) and t.is_recursive: + return ( + isinstance(get_proper_type(s), Instance) + or isinstance(s, TypeAliasType) + and s.is_recursive + ) + return False def tuple_fallback(typ: TupleType) -> Instance: @@ -81,9 +93,8 @@ def tuple_fallback(typ: TupleType) -> Instance: return typ.partial_fallback items = [] for item in typ.items: - proper_type = get_proper_type(item) - if isinstance(proper_type, UnpackType): - unpacked_type = get_proper_type(proper_type.type) + if isinstance(item, UnpackType): + unpacked_type = get_proper_type(item.type) if isinstance(unpacked_type, TypeVarTupleType): items.append(unpacked_type.upper_bound) elif isinstance(unpacked_type, TupleType): diff --git a/mypy/typestate.py b/mypy/typestate.py index 389dc9c2a358..a1d2ab972a11 100644 --- a/mypy/typestate.py +++ b/mypy/typestate.py @@ -9,7 +9,7 @@ from mypy.nodes import TypeInfo from mypy.server.trigger import make_trigger -from mypy.types import Instance, Type, TypeAliasType, get_proper_type +from mypy.types import Instance, Type, get_proper_type # Represents that the 'left' instance is a subtype of the 'right' instance SubtypeRelationship: _TypeAlias = Tuple[Instance, Instance] @@ -80,10 +80,12 @@ class TypeState: # recursive type aliases. Normally, one would pass type assumptions as an additional # arguments to is_subtype(), but this would mean updating dozens of related functions # threading this through all callsites (see also comment for TypeInfo.assuming). - _assuming: Final[List[Tuple[TypeAliasType, TypeAliasType]]] = [] - _assuming_proper: Final[List[Tuple[TypeAliasType, TypeAliasType]]] = [] + _assuming: Final[List[Tuple[Type, Type]]] = [] + _assuming_proper: Final[List[Tuple[Type, Type]]] = [] # Ditto for inference of generic constraints against recursive type aliases. 
- _inferring: Final[List[TypeAliasType]] = [] + inferring: Final[List[Tuple[Type, Type]]] = [] + # Whether to use joins or unions when solving constraints, see checkexpr.py for details. + infer_unions: ClassVar = False # N.B: We do all of the accesses to these properties through # TypeState, instead of making these classmethods and accessing @@ -109,7 +111,7 @@ def is_assumed_proper_subtype(left: Type, right: Type) -> bool: return False @staticmethod - def get_assumptions(is_proper: bool) -> List[Tuple[TypeAliasType, TypeAliasType]]: + def get_assumptions(is_proper: bool) -> List[Tuple[Type, Type]]: if is_proper: return TypeState._assuming_proper return TypeState._assuming diff --git a/test-data/unit/check-recursive-types.test b/test-data/unit/check-recursive-types.test index ac2065c55f18..04b7d634d4a9 100644 --- a/test-data/unit/check-recursive-types.test +++ b/test-data/unit/check-recursive-types.test @@ -60,6 +60,22 @@ x: Nested[int] = [1, [2, [3]]] x = [1, [Bad()]] # E: List item 0 has incompatible type "Bad"; expected "Union[int, Nested[int]]" [builtins fixtures/isinstancelist.pyi] +[case testRecursiveAliasGenericInferenceNested] +# flags: --enable-recursive-aliases +from typing import Union, TypeVar, Sequence, List + +T = TypeVar("T") +class A: ... +class B(A): ... + +Nested = Sequence[Union[T, Nested[T]]] + +def flatten(arg: Nested[T]) -> List[T]: ... 
+reveal_type(flatten([[B(), B()]])) # N: Revealed type is "builtins.list[__main__.B]" +reveal_type(flatten([[[[B()]]]])) # N: Revealed type is "builtins.list[__main__.B]" +reveal_type(flatten([[B(), [[B()]]]])) # N: Revealed type is "builtins.list[__main__.B]" +[builtins fixtures/isinstancelist.pyi] + [case testRecursiveAliasNewStyleSupported] # flags: --enable-recursive-aliases from test import A @@ -278,3 +294,97 @@ if isinstance(b[0], Sequence): a = b[0] x = a # E: Incompatible types in assignment (expression has type "Sequence[Union[B, NestedB]]", variable has type "int") [builtins fixtures/isinstancelist.pyi] + +[case testRecursiveAliasWithRecursiveInstance] +# flags: --enable-recursive-aliases +from typing import Sequence, Union, TypeVar + +class A: ... +T = TypeVar("T") +Nested = Sequence[Union[T, Nested[T]]] +class B(Sequence[B]): ... + +a: Nested[A] +aa: Nested[A] +b: B +a = b # OK +a = [[b]] # OK +b = aa # E: Incompatible types in assignment (expression has type "Nested[A]", variable has type "B") + +def join(a: T, b: T) -> T: ... +reveal_type(join(a, b)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" +reveal_type(join(b, a)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" +[builtins fixtures/isinstancelist.pyi] + +[case testRecursiveAliasWithRecursiveInstanceInference] +# flags: --enable-recursive-aliases +from typing import Sequence, Union, TypeVar, List + +T = TypeVar("T") +Nested = Sequence[Union[T, Nested[T]]] +class B(Sequence[B]): ... + +nb: Nested[B] = [B(), [B(), [B()]]] +lb: List[B] + +def foo(x: Nested[T]) -> T: ... +reveal_type(foo(lb)) # N: Revealed type is "__main__.B" +reveal_type(foo([B(), [B(), [B()]]])) # N: Revealed type is "__main__.B" + +NestedInv = List[Union[T, NestedInv[T]]] +nib: NestedInv[B] = [B(), [B(), [B()]]] +def bar(x: NestedInv[T]) -> T: ... 
+reveal_type(bar(nib)) # N: Revealed type is "__main__.B" +[builtins fixtures/isinstancelist.pyi] + +[case testRecursiveAliasTopUnion] +# flags: --enable-recursive-aliases +from typing import Sequence, Union, TypeVar, List + +class A: ... +class B(A): ... + +T = TypeVar("T") +PlainNested = Union[T, Sequence[PlainNested[T]]] + +x: PlainNested[A] +y: PlainNested[B] = [B(), [B(), [B()]]] +x = y # OK + +xx: PlainNested[B] +yy: PlainNested[A] +xx = yy # E: Incompatible types in assignment (expression has type "PlainNested[A]", variable has type "PlainNested[B]") + +def foo(arg: PlainNested[T]) -> T: ... +lb: List[B] +reveal_type(foo([B(), [B(), [B()]]])) # N: Revealed type is "__main__.B" +reveal_type(foo(lb)) # N: Revealed type is "__main__.B" +reveal_type(foo(xx)) # N: Revealed type is "__main__.B" +[builtins fixtures/isinstancelist.pyi] + +[case testRecursiveAliasInferenceExplicitNonRecursive] +# flags: --enable-recursive-aliases +from typing import Sequence, Union, TypeVar, List + +T = TypeVar("T") +Nested = Sequence[Union[T, Nested[T]]] +PlainNested = Union[T, Sequence[PlainNested[T]]] + +def foo(x: Nested[T]) -> T: ... +def bar(x: PlainNested[T]) -> T: ... + +class A: ... +a: A +la: List[A] +lla: List[Union[A, List[A]]] +llla: List[Union[A, List[Union[A, List[A]]]]] + +reveal_type(foo(la)) # N: Revealed type is "__main__.A" +reveal_type(foo(lla)) # N: Revealed type is "__main__.A" +reveal_type(foo(llla)) # N: Revealed type is "__main__.A" + +reveal_type(bar(a)) # N: Revealed type is "__main__.A" +reveal_type(bar(la)) # N: Revealed type is "__main__.A" +reveal_type(bar(lla)) # N: Revealed type is "__main__.A" +reveal_type(bar(llla)) # N: Revealed type is "__main__.A" +[builtins fixtures/isinstancelist.pyi]