Skip to content

Commit

Permalink
Handle interactions between recursive aliases and recursive instances (
Browse files Browse the repository at this point in the history
…#13328)

This is a follow-up for #13297 

The fix for infinite recursion is kind of simple, but it is hard to make inference infer something useful. Currently we handle all most common cases, but it is quite fragile (I however have few tricks left if people will complain about inference).
  • Loading branch information
ilevkivskyi authored Aug 5, 2022
1 parent b3eebe3 commit 608de81
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 55 deletions.
32 changes: 19 additions & 13 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@
is_optional,
remove_optional,
)
from mypy.typestate import TypeState
from mypy.typevars import fill_typevars
from mypy.util import split_module_names
from mypy.visitor import ExpressionVisitor
Expand Down Expand Up @@ -1429,6 +1430,22 @@ def infer_arg_types_in_empty_context(self, args: List[Expression]) -> List[Type]
res.append(arg_type)
return res

@contextmanager
def allow_unions(self, type_context: Type) -> Iterator[None]:
# This is a hack to better support inference for recursive types.
# When the outer context for a function call is known to be recursive,
# we solve type constraints inferred from arguments using unions instead
# of joins. This is a bit arbitrary, but in practice it works for most
# cases. A cleaner alternative would be to switch to single bin type
# inference, but this is a lot of work.
old = TypeState.infer_unions
if has_recursive_types(type_context):
TypeState.infer_unions = True
try:
yield
finally:
TypeState.infer_unions = old

def infer_arg_types_in_context(
self,
callee: CallableType,
Expand All @@ -1448,7 +1465,8 @@ def infer_arg_types_in_context(
for i, actuals in enumerate(formal_to_actual):
for ai in actuals:
if not arg_kinds[ai].is_star():
res[ai] = self.accept(args[ai], callee.arg_types[i])
with self.allow_unions(callee.arg_types[i]):
res[ai] = self.accept(args[ai], callee.arg_types[i])

# Fill in the rest of the argument types.
for i, t in enumerate(res):
Expand Down Expand Up @@ -1568,25 +1586,13 @@ def infer_function_type_arguments(
else:
pass1_args.append(arg)

# This is a hack to better support inference for recursive types.
# When the outer context for a function call is known to be recursive,
# we solve type constraints inferred from arguments using unions instead
# of joins. This is a bit arbitrary, but in practice it works for most
# cases. A cleaner alternative would be to switch to single bin type
# inference, but this is a lot of work.
ctx = self.type_context[-1]
if ctx and has_recursive_types(ctx):
infer_unions = True
else:
infer_unions = False
inferred_args = infer_function_type_arguments(
callee_type,
pass1_args,
arg_kinds,
formal_to_actual,
context=self.argument_infer_context(),
strict=self.chk.in_checked_function(),
infer_unions=infer_unions,
)

if 2 in arg_pass_nums:
Expand Down
35 changes: 30 additions & 5 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
UnpackType,
callable_with_ellipsis,
get_proper_type,
has_recursive_types,
has_type_vars,
is_named_instance,
is_union_with_any,
)
Expand Down Expand Up @@ -141,14 +143,19 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> List[Cons
The constraints are represented as Constraint objects.
"""
if any(
get_proper_type(template) == get_proper_type(t) for t in reversed(TypeState._inferring)
get_proper_type(template) == get_proper_type(t)
and get_proper_type(actual) == get_proper_type(a)
for (t, a) in reversed(TypeState.inferring)
):
return []
if isinstance(template, TypeAliasType) and template.is_recursive:
if has_recursive_types(template):
# This case requires special care because it may cause infinite recursion.
TypeState._inferring.append(template)
if not has_type_vars(template):
# Return early on an empty branch.
return []
TypeState.inferring.append((template, actual))
res = _infer_constraints(template, actual, direction)
TypeState._inferring.pop()
TypeState.inferring.pop()
return res
return _infer_constraints(template, actual, direction)

Expand Down Expand Up @@ -216,13 +223,18 @@ def _infer_constraints(template: Type, actual: Type, direction: int) -> List[Con
# When the template is a union, we are okay with leaving some
# type variables indeterminate. This helps with some special
# cases, though this isn't very principled.
return any_constraints(
result = any_constraints(
[
infer_constraints_if_possible(t_item, actual, direction)
for t_item in template.items
],
eager=False,
)
if result:
return result
elif has_recursive_types(template) and not has_recursive_types(actual):
return handle_recursive_union(template, actual, direction)
return []

# Remaining cases are handled by ConstraintBuilderVisitor.
return template.accept(ConstraintBuilderVisitor(actual, direction))
Expand Down Expand Up @@ -279,6 +291,19 @@ def merge_with_any(constraint: Constraint) -> Constraint:
)


def handle_recursive_union(template: UnionType, actual: Type, direction: int) -> List[Constraint]:
# This is a hack to special-case things like Union[T, Inst[T]] in recursive types. Although
# it is quite arbitrary, it is a relatively common pattern, so we should handle it well.
# This function may be called when inferring against such union resulted in different
# constraints for each item. Normally we give up in such case, but here we instead split
# the union in two parts, and try inferring sequentially.
non_type_var_items = [t for t in template.items if not isinstance(t, TypeVarType)]
type_var_items = [t for t in template.items if isinstance(t, TypeVarType)]
return infer_constraints(
UnionType.make_union(non_type_var_items), actual, direction
) or infer_constraints(UnionType.make_union(type_var_items), actual, direction)


def any_constraints(options: List[Optional[List[Constraint]]], eager: bool) -> List[Constraint]:
"""Deduce what we can from a collection of constraint lists.
Expand Down
3 changes: 1 addition & 2 deletions mypy/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def infer_function_type_arguments(
formal_to_actual: List[List[int]],
context: ArgumentInferContext,
strict: bool = True,
infer_unions: bool = False,
) -> List[Optional[Type]]:
"""Infer the type arguments of a generic function.
Expand All @@ -56,7 +55,7 @@ def infer_function_type_arguments(

# Solve constraints.
type_vars = callee_type.type_var_ids()
return solve_constraints(type_vars, constraints, strict, infer_unions=infer_unions)
return solve_constraints(type_vars, constraints, strict)


def infer_type_arguments(
Expand Down
8 changes: 3 additions & 5 deletions mypy/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@
UnionType,
get_proper_type,
)
from mypy.typestate import TypeState


def solve_constraints(
vars: List[TypeVarId],
constraints: List[Constraint],
strict: bool = True,
infer_unions: bool = False,
vars: List[TypeVarId], constraints: List[Constraint], strict: bool = True
) -> List[Optional[Type]]:
"""Solve type constraints.
Expand Down Expand Up @@ -55,7 +53,7 @@ def solve_constraints(
if bottom is None:
bottom = c.target
else:
if infer_unions:
if TypeState.infer_unions:
# This deviates from the general mypy semantics because
# recursive types are union-heavy in 95% of cases.
bottom = UnionType.make_union([bottom, c.target])
Expand Down
18 changes: 3 additions & 15 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,7 @@ def is_subtype(
), "Don't pass both context and individual flags"
if TypeState.is_assumed_subtype(left, right):
return True
if (
# TODO: recursive instances like `class str(Sequence[str])` can also cause
# issues, so we also need to include them in the assumptions stack
isinstance(left, TypeAliasType)
and isinstance(right, TypeAliasType)
and left.is_recursive
and right.is_recursive
):
if mypy.typeops.is_recursive_pair(left, right):
# This case requires special care because it may cause infinite recursion.
# Our view on recursive types is known under a fancy name of iso-recursive mu-types.
# Roughly this means that a recursive type is defined as an alias where right hand side
Expand Down Expand Up @@ -205,12 +198,7 @@ def is_proper_subtype(
), "Don't pass both context and individual flags"
if TypeState.is_assumed_proper_subtype(left, right):
return True
if (
isinstance(left, TypeAliasType)
and isinstance(right, TypeAliasType)
and left.is_recursive
and right.is_recursive
):
if mypy.typeops.is_recursive_pair(left, right):
# Same as for non-proper subtype, see detailed comment there for explanation.
with pop_on_exit(TypeState.get_assumptions(is_proper=True), left, right):
return _is_subtype(left, right, subtype_context, proper_subtype=True)
Expand Down Expand Up @@ -874,7 +862,7 @@ def visit_type_alias_type(self, left: TypeAliasType) -> bool:
assert False, f"This should be never called, got {left}"


T = TypeVar("T", Instance, TypeAliasType)
T = TypeVar("T", bound=Type)


@contextmanager
Expand Down
31 changes: 21 additions & 10 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,25 @@


def is_recursive_pair(s: Type, t: Type) -> bool:
"""Is this a pair of recursive type aliases?"""
return (
isinstance(s, TypeAliasType)
and isinstance(t, TypeAliasType)
and s.is_recursive
and t.is_recursive
)
"""Is this a pair of recursive types?
There may be more cases, and we may be forced to use e.g. has_recursive_types()
here, but this function is called in very hot code, so we try to keep it simple
and return True only in cases we know may have problems.
"""
if isinstance(s, TypeAliasType) and s.is_recursive:
return (
isinstance(get_proper_type(t), Instance)
or isinstance(t, TypeAliasType)
and t.is_recursive
)
if isinstance(t, TypeAliasType) and t.is_recursive:
return (
isinstance(get_proper_type(s), Instance)
or isinstance(s, TypeAliasType)
and s.is_recursive
)
return False


def tuple_fallback(typ: TupleType) -> Instance:
Expand All @@ -81,9 +93,8 @@ def tuple_fallback(typ: TupleType) -> Instance:
return typ.partial_fallback
items = []
for item in typ.items:
proper_type = get_proper_type(item)
if isinstance(proper_type, UnpackType):
unpacked_type = get_proper_type(proper_type.type)
if isinstance(item, UnpackType):
unpacked_type = get_proper_type(item.type)
if isinstance(unpacked_type, TypeVarTupleType):
items.append(unpacked_type.upper_bound)
elif isinstance(unpacked_type, TupleType):
Expand Down
12 changes: 7 additions & 5 deletions mypy/typestate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from mypy.nodes import TypeInfo
from mypy.server.trigger import make_trigger
from mypy.types import Instance, Type, TypeAliasType, get_proper_type
from mypy.types import Instance, Type, get_proper_type

# Represents that the 'left' instance is a subtype of the 'right' instance
SubtypeRelationship: _TypeAlias = Tuple[Instance, Instance]
Expand Down Expand Up @@ -80,10 +80,12 @@ class TypeState:
# recursive type aliases. Normally, one would pass type assumptions as an additional
# arguments to is_subtype(), but this would mean updating dozens of related functions
# threading this through all callsites (see also comment for TypeInfo.assuming).
_assuming: Final[List[Tuple[TypeAliasType, TypeAliasType]]] = []
_assuming_proper: Final[List[Tuple[TypeAliasType, TypeAliasType]]] = []
_assuming: Final[List[Tuple[Type, Type]]] = []
_assuming_proper: Final[List[Tuple[Type, Type]]] = []
# Ditto for inference of generic constraints against recursive type aliases.
_inferring: Final[List[TypeAliasType]] = []
inferring: Final[List[Tuple[Type, Type]]] = []
# Whether to use joins or unions when solving constraints, see checkexpr.py for details.
infer_unions: ClassVar = False

# N.B: We do all of the accesses to these properties through
# TypeState, instead of making these classmethods and accessing
Expand All @@ -109,7 +111,7 @@ def is_assumed_proper_subtype(left: Type, right: Type) -> bool:
return False

@staticmethod
def get_assumptions(is_proper: bool) -> List[Tuple[TypeAliasType, TypeAliasType]]:
def get_assumptions(is_proper: bool) -> List[Tuple[Type, Type]]:
if is_proper:
return TypeState._assuming_proper
return TypeState._assuming
Expand Down
Loading

0 comments on commit 608de81

Please sign in to comment.