Skip to content

Commit

Permalink
Fix union simplification performance regression (#12519)
Browse files Browse the repository at this point in the history
#11962 can generate large unions with many Instance types with
last_known_value set. This caused our union simplification algorithm
to be extremely slow, as it hit an O(n**2) code path.

We already had a fast code path for unions of regular literal types. This
generalizes it for unions containing Instance types with last known
values (which behave similarly to literals in a literal type context).

Also fix a union simplification bug that I encountered while writing tests
for this change.

Work on #12408.
  • Loading branch information
JukkaL authored Apr 5, 2022
1 parent 0e8a03c commit 4ff8d04
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 20 deletions.
39 changes: 39 additions & 0 deletions mypy/test/testtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,12 @@ def test_simplified_union(self) -> None:
self.assert_simplified_union([fx.a, UnionType([fx.a])], fx.a)
self.assert_simplified_union([fx.b, UnionType([fx.c, UnionType([fx.d])])],
UnionType([fx.b, fx.c, fx.d]))

def test_simplified_union_with_literals(self) -> None:
fx = self.fx

self.assert_simplified_union([fx.lit1, fx.a], fx.a)
self.assert_simplified_union([fx.lit1, fx.lit2, fx.a], fx.a)
self.assert_simplified_union([fx.lit1, fx.lit1], fx.lit1)
self.assert_simplified_union([fx.lit1, fx.lit2], UnionType([fx.lit1, fx.lit2]))
self.assert_simplified_union([fx.lit1, fx.lit3], UnionType([fx.lit1, fx.lit3]))
Expand All @@ -481,6 +486,40 @@ def test_simplified_union(self) -> None:
self.assert_simplified_union([fx.lit1, fx.lit2_inst], UnionType([fx.lit1, fx.lit2_inst]))
self.assert_simplified_union([fx.lit1, fx.lit3_inst], UnionType([fx.lit1, fx.lit3_inst]))

def test_simplified_union_with_str_literals(self) -> None:
    """Check union simplification for ``str`` literal types."""
    fx = self.fx
    check = self.assert_simplified_union
    x, y, z = fx.lit_str1, fx.lit_str2, fx.lit_str3

    # Literals are subsumed when their fallback ``str`` type is also present.
    check([x, y, fx.str_type], fx.str_type)
    # Repeats of the same literal collapse to a single item.
    check([x, x, x], x)
    # Distinct literals are all retained.
    check([x, y, z], UnionType([x, y, z]))
    # ``fx.uninhabited`` contributes nothing to the simplified union.
    check([x, y, fx.uninhabited], UnionType([x, y]))

def test_simplified_union_with_str_instance_literals(self) -> None:
    """Check union simplification for ``str`` instances with a last known value."""
    fx = self.fx
    check = self.assert_simplified_union
    x_inst, y_inst, z_inst = fx.lit_str1_inst, fx.lit_str2_inst, fx.lit_str3_inst

    # Instances carrying a last known value are subsumed by plain ``str``.
    check([x_inst, y_inst, fx.str_type], fx.str_type)
    # Repeats of the same instance collapse to a single item.
    check([x_inst, x_inst, x_inst], x_inst)
    # Instances with different last known values are all retained.
    check([x_inst, y_inst, z_inst], UnionType([x_inst, y_inst, z_inst]))
    # ``fx.uninhabited`` contributes nothing to the simplified union.
    check([x_inst, y_inst, fx.uninhabited], UnionType([x_inst, y_inst]))

def test_simplified_union_with_mixed_str_literals(self) -> None:
    """Check union simplification when literal types and instances are mixed."""
    fx = self.fx
    check = self.assert_simplified_union
    x, y = fx.lit_str1, fx.lit_str2
    x_inst, z_inst = fx.lit_str1_inst, fx.lit_str3_inst

    # Literal types and an instance with a different value are all retained.
    check([x, y, z_inst], UnionType([x, y, z_inst]))
    # Only the repeated literal is dropped; the literal type and the instance
    # with the same last known value stay as separate items.
    check([x, x, x_inst], UnionType([x, x_inst]))

def assert_simplified_union(self, original: List[Type], union: Type) -> None:
assert_equal(make_simplified_union(original), union)
assert_equal(make_simplified_union(list(reversed(original))), union)
Expand Down
9 changes: 9 additions & 0 deletions mypy/test/typefixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def make_type_var(name: str, id: int, values: List[Type], upper_bound: Type,
variances=[COVARIANT]) # class tuple
self.type_typei = self.make_type_info('builtins.type') # class type
self.bool_type_info = self.make_type_info('builtins.bool')
self.str_type_info = self.make_type_info('builtins.str')
self.functioni = self.make_type_info('builtins.function') # function TODO
self.ai = self.make_type_info('A', mro=[self.oi]) # class A
self.bi = self.make_type_info('B', mro=[self.ai, self.oi]) # class B(A)
Expand Down Expand Up @@ -109,6 +110,7 @@ def make_type_var(name: str, id: int, values: List[Type], upper_bound: Type,
self.std_tuple = Instance(self.std_tuplei, [self.anyt]) # tuple
self.type_type = Instance(self.type_typei, []) # type
self.function = Instance(self.functioni, []) # function TODO
self.str_type = Instance(self.str_type_info, [])
self.a = Instance(self.ai, []) # A
self.b = Instance(self.bi, []) # B
self.c = Instance(self.ci, []) # C
Expand Down Expand Up @@ -163,6 +165,13 @@ def make_type_var(name: str, id: int, values: List[Type], upper_bound: Type,
self.lit3_inst = Instance(self.di, [], last_known_value=self.lit3)
self.lit4_inst = Instance(self.ai, [], last_known_value=self.lit4)

self.lit_str1 = LiteralType("x", self.str_type)
self.lit_str2 = LiteralType("y", self.str_type)
self.lit_str3 = LiteralType("z", self.str_type)
self.lit_str1_inst = Instance(self.str_type_info, [], last_known_value=self.lit_str1)
self.lit_str2_inst = Instance(self.str_type_info, [], last_known_value=self.lit_str2)
self.lit_str3_inst = Instance(self.str_type_info, [], last_known_value=self.lit_str3)

self.type_a = TypeType.make_normalized(self.a)
self.type_b = TypeType.make_normalized(self.b)
self.type_c = TypeType.make_normalized(self.c)
Expand Down
50 changes: 30 additions & 20 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,23 @@ def callable_corresponding_argument(typ: CallableType,
return by_name if by_name is not None else by_pos


def is_simple_literal(t: ProperType) -> bool:
"""
Whether a type is a simple enough literal to allow for fast Union simplification
def simple_literal_value_key(t: ProperType) -> Optional[Tuple[str, ...]]:
"""Return a hashable description of simple literal type.
Return None if not a simple literal type.
For now this means enum or string
The return value can be used to simplify away duplicate types in
unions by comparing keys for equality. For now enum, string or
Instance with string last_known_value are supported.
"""
return isinstance(t, LiteralType) and (
t.fallback.type.is_enum or t.fallback.type.fullname == 'builtins.str'
)
if isinstance(t, LiteralType):
if t.fallback.type.is_enum or t.fallback.type.fullname == 'builtins.str':
assert isinstance(t.value, str)
return 'literal', t.value, t.fallback.type.fullname
if isinstance(t, Instance):
if t.last_known_value is not None and isinstance(t.last_known_value.value, str):
return 'instance', t.last_known_value.value, t.type.fullname
return None


def make_simplified_union(items: Sequence[Type],
Expand Down Expand Up @@ -341,10 +349,20 @@ def make_simplified_union(items: Sequence[Type],
all_items.append(typ)
items = all_items

simplified_set = _remove_redundant_union_items(items, keep_erased)

# If more than one literal exists in the union, try to simplify
if (contract_literals and sum(isinstance(item, LiteralType) for item in simplified_set) > 1):
simplified_set = try_contracting_literals_in_union(simplified_set)

return UnionType.make_union(simplified_set, line, column)


def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) -> List[ProperType]:
from mypy.subtypes import is_proper_subtype

removed: Set[int] = set()
seen: Set[Tuple[str, str]] = set()
seen: Set[Tuple[str, ...]] = set()

# NB: having a separate fast path for Union of Literal and slow path for other things
# would arguably be cleaner, however it breaks down when simplifying the Union of two
Expand All @@ -354,10 +372,8 @@ def make_simplified_union(items: Sequence[Type],
if i in removed:
continue
# Avoid slow nested for loop for Union of Literal of strings/enums (issue #9169)
if is_simple_literal(item):
assert isinstance(item, LiteralType)
assert isinstance(item.value, str)
k = (item.value, item.fallback.type.fullname)
k = simple_literal_value_key(item)
if k is not None:
if k in seen:
removed.add(i)
continue
Expand All @@ -373,13 +389,13 @@ def make_simplified_union(items: Sequence[Type],
seen.add(k)
if safe_skip:
continue

# Keep track of the truishness info for deleted subtypes which can be relevant
cbt = cbf = False
for j, tj in enumerate(items):
# NB: we don't need to check literals as the fast path above takes care of that
if (
i != j
and not is_simple_literal(tj)
and is_proper_subtype(tj, item, keep_erased_types=keep_erased)
and is_redundant_literal_instance(item, tj) # XXX?
):
Expand All @@ -393,13 +409,7 @@ def make_simplified_union(items: Sequence[Type],
elif not item.can_be_false and cbf:
items[i] = true_or_false(item)

simplified_set = [items[i] for i in range(len(items)) if i not in removed]

# If more than one literal exists in the union, try to simplify
if (contract_literals and sum(isinstance(item, LiteralType) for item in simplified_set) > 1):
simplified_set = try_contracting_literals_in_union(simplified_set)

return UnionType.make_union(simplified_set, line, column)
return [items[i] for i in range(len(items)) if i not in removed]


def _get_type_special_method_bool_ret_type(t: Type) -> Optional[Type]:
Expand Down

0 comments on commit 4ff8d04

Please sign in to comment.