Skip to content

Commit

Permalink
Make join of recursive types more robust (python#13808)
Browse files Browse the repository at this point in the history
Fixes python#13795

Calculating tuple fallbacks on the fly creates a cycle between joins and
subtyping. Although IMO this is conceptually not a right thing, it is
hard to get rid of (unless we want to use unions in the fallbacks, cc
@JukkaL). So instead I re-worked how `join_types()` works w.r.t.
`get_proper_type()`, essentially it now follows the golden rule "Always
pass on the original type if possible".
  • Loading branch information
ilevkivskyi authored Oct 8, 2022
1 parent 9f39120 commit dbe9a88
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 25 deletions.
11 changes: 7 additions & 4 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@
StarType,
TupleType,
Type,
TypeAliasType,
TypedDictType,
TypeOfAny,
TypeType,
Expand Down Expand Up @@ -195,10 +196,12 @@ class TooManyUnions(Exception):
"""


def allow_fast_container_literal(t: ProperType) -> bool:
def allow_fast_container_literal(t: Type) -> bool:
if isinstance(t, TypeAliasType) and t.is_recursive:
return False
t = get_proper_type(t)
return isinstance(t, Instance) or (
isinstance(t, TupleType)
and all(allow_fast_container_literal(get_proper_type(it)) for it in t.items)
isinstance(t, TupleType) and all(allow_fast_container_literal(it) for it in t.items)
)


Expand Down Expand Up @@ -4603,7 +4606,7 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F
#
# TODO: Always create a union or at least in more cases?
if isinstance(get_proper_type(self.type_context[-1]), UnionType):
res = make_simplified_union([if_type, full_context_else_type])
res: Type = make_simplified_union([if_type, full_context_else_type])
else:
res = join.join_types(if_type, else_type)

Expand Down
38 changes: 24 additions & 14 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

from typing import overload

import mypy.typeops
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT
Expand Down Expand Up @@ -131,7 +133,6 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType:
best = res
assert best is not None
for promote in t.type._promote:
promote = get_proper_type(promote)
if isinstance(promote, Instance):
res = self.join_instances(promote, s)
if is_better(res, best):
Expand Down Expand Up @@ -182,17 +183,29 @@ def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
return declaration


def trivial_join(s: Type, t: Type) -> ProperType:
def trivial_join(s: Type, t: Type) -> Type:
"""Return one of types (expanded) if it is a supertype of other, otherwise top type."""
if is_subtype(s, t):
return get_proper_type(t)
return t
elif is_subtype(t, s):
return get_proper_type(s)
return s
else:
return object_or_any_from_type(get_proper_type(t))


def join_types(s: Type, t: Type, instance_joiner: InstanceJoiner | None = None) -> ProperType:
@overload
def join_types(
s: ProperType, t: ProperType, instance_joiner: InstanceJoiner | None = None
) -> ProperType:
...


@overload
def join_types(s: Type, t: Type, instance_joiner: InstanceJoiner | None = None) -> Type:
...


def join_types(s: Type, t: Type, instance_joiner: InstanceJoiner | None = None) -> Type:
"""Return the least upper bound of s and t.
For example, the join of 'int' and 'object' is 'object'.
Expand Down Expand Up @@ -443,7 +456,7 @@ def visit_tuple_type(self, t: TupleType) -> ProperType:
if self.s.length() == t.length():
items: list[Type] = []
for i in range(t.length()):
items.append(self.join(t.items[i], self.s.items[i]))
items.append(join_types(t.items[i], self.s.items[i]))
return TupleType(items, fallback)
else:
return fallback
Expand Down Expand Up @@ -487,7 +500,7 @@ def visit_partial_type(self, t: PartialType) -> ProperType:

def visit_type_type(self, t: TypeType) -> ProperType:
if isinstance(self.s, TypeType):
return TypeType.make_normalized(self.join(t.item, self.s.item), line=t.line)
return TypeType.make_normalized(join_types(t.item, self.s.item), line=t.line)
elif isinstance(self.s, Instance) and self.s.type.fullname == "builtins.type":
return self.s
else:
Expand All @@ -496,9 +509,6 @@ def visit_type_type(self, t: TypeType) -> ProperType:
def visit_type_alias_type(self, t: TypeAliasType) -> ProperType:
assert False, f"This should be never called, got {t}"

def join(self, s: Type, t: Type) -> ProperType:
return join_types(s, t)

def default(self, typ: Type) -> ProperType:
typ = get_proper_type(typ)
if isinstance(typ, Instance):
Expand Down Expand Up @@ -654,19 +664,19 @@ def object_or_any_from_type(typ: ProperType) -> ProperType:
return AnyType(TypeOfAny.implementation_artifact)


def join_type_list(types: list[Type]) -> ProperType:
def join_type_list(types: list[Type]) -> Type:
if not types:
# This is a little arbitrary but reasonable. Any empty tuple should be compatible
# with all variable length tuples, and this makes it possible.
return UninhabitedType()
joined = get_proper_type(types[0])
joined = types[0]
for t in types[1:]:
joined = join_types(joined, t)
return joined


def unpack_callback_protocol(t: Instance) -> Type | None:
def unpack_callback_protocol(t: Instance) -> ProperType | None:
assert t.type.is_protocol
if t.type.protocol_members == ["__call__"]:
return find_member("__call__", t, t, is_operator=True)
return get_proper_type(find_member("__call__", t, t, is_operator=True))
return None
13 changes: 9 additions & 4 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2542,9 +2542,9 @@ class PromoteExpr(Expression):

__slots__ = ("type",)

type: mypy.types.Type
type: mypy.types.ProperType

def __init__(self, type: mypy.types.Type) -> None:
def __init__(self, type: mypy.types.ProperType) -> None:
super().__init__()
self.type = type

Expand Down Expand Up @@ -2769,7 +2769,7 @@ class is generic then it will be a type constructor of higher kind.
# even though it's not a subclass in Python. The non-standard
# `@_promote` decorator introduces this, and there are also
# several builtin examples, in particular `int` -> `float`.
_promote: list[mypy.types.Type]
_promote: list[mypy.types.ProperType]

# This is used for promoting native integer types such as 'i64' to
# 'int'. (_promote is used for the other direction.) This only
Expand Down Expand Up @@ -3100,7 +3100,12 @@ def deserialize(cls, data: JsonDict) -> TypeInfo:
ti.type_vars = data["type_vars"]
ti.has_param_spec_type = data["has_param_spec_type"]
ti.bases = [mypy.types.Instance.deserialize(b) for b in data["bases"]]
ti._promote = [mypy.types.deserialize_type(p) for p in data["_promote"]]
_promote = []
for p in data["_promote"]:
t = mypy.types.deserialize_type(p)
assert isinstance(t, mypy.types.ProperType)
_promote.append(t)
ti._promote = _promote
ti.declared_metaclass = (
None
if data["declared_metaclass"] is None
Expand Down
1 change: 1 addition & 0 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4945,6 +4945,7 @@ def visit_conditional_expr(self, expr: ConditionalExpr) -> None:
def visit__promote_expr(self, expr: PromoteExpr) -> None:
analyzed = self.anal_type(expr.type)
if analyzed is not None:
assert isinstance(analyzed, ProperType), "Cannot use type aliases for promotions"
expr.type = analyzed

def visit_yield_expr(self, e: YieldExpr) -> None:
Expand Down
4 changes: 2 additions & 2 deletions mypy/semanal_classprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
Var,
)
from mypy.options import Options
from mypy.types import Instance, Type
from mypy.types import Instance, ProperType

# Hard coded type promotions (shared between all Python versions).
# These add extra ad-hoc edges to the subtyping relation. For example,
Expand Down Expand Up @@ -155,7 +155,7 @@ def add_type_promotion(
This includes things like 'int' being compatible with 'float'.
"""
defn = info.defn
promote_targets: list[Type] = []
promote_targets: list[ProperType] = []
for decorator in defn.decorators:
if isinstance(decorator, CallExpr):
analyzed = decorator.analyzed
Expand Down
24 changes: 23 additions & 1 deletion test-data/unit/check-recursive-types.test
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,14 @@ m: A
s: str = n.x # E: Incompatible types in assignment (expression has type "Tuple[A, int]", variable has type "str")
reveal_type(m[0]) # N: Revealed type is "builtins.str"
lst = [m, n]
reveal_type(lst[0]) # N: Revealed type is "Tuple[builtins.object, builtins.object]"

# Unfortunately, join of two recursive types is not very precise.
reveal_type(lst[0]) # N: Revealed type is "builtins.object"

# These just should not crash
lst1 = [m]
lst2 = [m, m]
lst3 = [m, m, m]
[builtins fixtures/tuple.pyi]

[case testMutuallyRecursiveNamedTuplesClasses]
Expand Down Expand Up @@ -786,3 +793,18 @@ class B:
y: B.Foo
reveal_type(y) # N: Revealed type is "typing.Sequence[typing.Sequence[...]]"
[builtins fixtures/tuple.pyi]

[case testNoCrashOnRecursiveTupleFallback]
from typing import Union, Tuple

Tree1 = Union[str, Tuple[Tree1]]
Tree2 = Union[str, Tuple[Tree2, Tree2]]
Tree3 = Union[str, Tuple[Tree3, Tree3, Tree3]]

def test1() -> Tree1:
return 42 # E: Incompatible return value type (got "int", expected "Union[str, Tuple[Tree1]]")
def test2() -> Tree2:
return 42 # E: Incompatible return value type (got "int", expected "Union[str, Tuple[Tree2, Tree2]]")
def test3() -> Tree3:
return 42 # E: Incompatible return value type (got "int", expected "Union[str, Tuple[Tree3, Tree3, Tree3]]")
[builtins fixtures/tuple.pyi]

0 comments on commit dbe9a88

Please sign in to comment.