From 0070071d461dd57e2dc9b8a215333212167e13c8 Mon Sep 17 00:00:00 2001
From: Jukka Lehtosalo <jukka.lehtosalo@iki.fi>
Date: Thu, 29 Dec 2022 13:40:54 +0000
Subject: [PATCH] [mypyc] Fixes to union simplification (#14364)

Flatten nested unions before simplifying unions.

When a for loop iterates over a value whose type is a union of sequence
types, also build the loop item type as a simplified union. This fixes a
crash introduced in #14363.
---
 mypyc/ir/rtypes.py                            | 37 ++++++++++
 mypyc/irbuild/builder.py                      | 13 +++-
 mypyc/irbuild/mapper.py                       | 13 +---
 mypyc/test-data/irbuild-lists.test            | 70 ++++++++++++++++++-
 .../test/{test_subtype.py => test_typeops.py} | 26 ++++++-
 5 files changed, 141 insertions(+), 18 deletions(-)
 rename mypyc/test/{test_subtype.py => test_typeops.py} (64%)

diff --git a/mypyc/ir/rtypes.py b/mypyc/ir/rtypes.py
index 7fe8a940e4c2..babfe0770f35 100644
--- a/mypyc/ir/rtypes.py
+++ b/mypyc/ir/rtypes.py
@@ -797,6 +797,30 @@ def __init__(self, items: list[RType]) -> None:
         self.items_set = frozenset(items)
         self._ctype = "PyObject *"
 
+    @staticmethod
+    def make_simplified_union(items: list[RType]) -> RType:
+        """Return a normalized union that covers the given items.
+
+        Nested unions are flattened and duplicates are dropped, keeping the
+        first occurrence of each item. If a single item remains, it is
+        returned as is instead of being wrapped in an RUnion. Overlapping
+        items are *not* simplified: [object, str] stays a two-item union.
+        """
+        items = flatten_nested_unions(items)
+        assert items  # At least one item is required (new_items[0] is read below)
+
+        # Remove duplicate items using set + list to preserve item order
+        seen = set()  # Items already emitted; only used for membership tests
+        new_items = []  # Deduplicated items in first-seen order
+        for item in items:
+            if item not in seen:
+                new_items.append(item)
+                seen.add(item)
+        if len(new_items) > 1:
+            return RUnion(new_items)
+        else:
+            return new_items[0]  # A single distinct item needs no union wrapper
+
     def accept(self, visitor: RTypeVisitor[T]) -> T:
         return visitor.visit_runion(self)
 
@@ -823,6 +847,19 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> RUnion:
         return RUnion(types)
 
 
+def flatten_nested_unions(types: list[RType]) -> list[RType]:
+    if not any(isinstance(t, RUnion) for t in types):
+        return types  # Fast path: no unions; the input list itself is returned (not a copy)
+
+    flat_items: list[RType] = []  # Result; relative order of items is preserved
+    for t in types:
+        if isinstance(t, RUnion):
+            flat_items.extend(flatten_nested_unions(t.items))  # Recurse: items may be unions
+        else:
+            flat_items.append(t)
+    return flat_items
+
+
 def optional_value_type(rtype: RType) -> RType | None:
     """If rtype is the union of none_rprimitive and another type X, return X.
 
diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py
index 6310c25c64fb..792697970785 100644
--- a/mypyc/irbuild/builder.py
+++ b/mypyc/irbuild/builder.py
@@ -53,6 +53,7 @@
     Type,
     TypeOfAny,
     UninhabitedType,
+    UnionType,
     get_proper_type,
 )
 from mypy.util import split_target
@@ -85,6 +86,7 @@
     RInstance,
     RTuple,
     RType,
+    RUnion,
     bitmap_rprimitive,
     c_int_rprimitive,
     c_pyssize_t_rprimitive,
@@ -864,8 +866,15 @@ def extract_int(self, e: Expression) -> int | None:
             return None
 
     def get_sequence_type(self, expr: Expression) -> RType:
-        target_type = get_proper_type(self.types[expr])
-        assert isinstance(target_type, Instance)
+        return self.get_sequence_type_from_type(self.types[expr])
+
+    def get_sequence_type_from_type(self, target_type: Type) -> RType:
+        target_type = get_proper_type(target_type)
+        if isinstance(target_type, UnionType):
+            return RUnion.make_simplified_union(
+                [self.get_sequence_type_from_type(item) for item in target_type.items]
+            )
+        assert isinstance(target_type, Instance), target_type
         if target_type.type.fullname == "builtins.str":
             return str_rprimitive
         else:
diff --git a/mypyc/irbuild/mapper.py b/mypyc/irbuild/mapper.py
index a108766644ce..dddb35230fd5 100644
--- a/mypyc/irbuild/mapper.py
+++ b/mypyc/irbuild/mapper.py
@@ -116,18 +116,7 @@ def type_to_rtype(self, typ: Type | None) -> RType:
         elif isinstance(typ, NoneTyp):
             return none_rprimitive
         elif isinstance(typ, UnionType):
-            # Remove redundant items using set + list to preserve item order
-            seen = set()
-            items = []
-            for item in typ.items:
-                rtype = self.type_to_rtype(item)
-                if rtype not in seen:
-                    items.append(rtype)
-                    seen.add(rtype)
-            if len(items) > 1:
-                return RUnion(items)
-            else:
-                return items[0]
+            return RUnion.make_simplified_union([self.type_to_rtype(item) for item in typ.items])
         elif isinstance(typ, AnyType):
             return object_rprimitive
         elif isinstance(typ, TypeType):
diff --git a/mypyc/test-data/irbuild-lists.test b/mypyc/test-data/irbuild-lists.test
index b82217465fef..cb9687a2f942 100644
--- a/mypyc/test-data/irbuild-lists.test
+++ b/mypyc/test-data/irbuild-lists.test
@@ -430,14 +430,20 @@ L5:
     return 1
 
 [case testSimplifyListUnion]
-from typing import List, Union
+from typing import List, Union, Optional
 
-def f(a: Union[List[str], List[bytes], int]) -> int:
+def narrow(a: Union[List[str], List[bytes], int]) -> int:
     if isinstance(a, list):
         return len(a)
     return a
+def loop(a: Union[List[str], List[bytes]]) -> None:
+    for x in a:
+         pass
+def nested_union(a: Union[List[str], List[Optional[str]]]) -> None:
+    for x in a:
+        pass
 [out]
-def f(a):
+def narrow(a):
     a :: union[list, int]
     r0 :: object
     r1 :: int32
@@ -465,3 +471,61 @@ L1:
 L2:
     r8 = unbox(int, a)
     return r8
+def loop(a):
+    a :: list
+    r0 :: short_int
+    r1 :: ptr
+    r2 :: native_int
+    r3 :: short_int
+    r4 :: bit
+    r5 :: object
+    r6, x :: union[str, bytes]
+    r7 :: short_int
+L0:
+    r0 = 0
+L1:
+    r1 = get_element_ptr a ob_size :: PyVarObject
+    r2 = load_mem r1 :: native_int*
+    keep_alive a
+    r3 = r2 << 1
+    r4 = r0 < r3 :: signed
+    if r4 goto L2 else goto L4 :: bool
+L2:
+    r5 = CPyList_GetItemUnsafe(a, r0)
+    r6 = cast(union[str, bytes], r5)
+    x = r6
+L3:
+    r7 = r0 + 2
+    r0 = r7
+    goto L1
+L4:
+    return 1
+def nested_union(a):
+    a :: list
+    r0 :: short_int
+    r1 :: ptr
+    r2 :: native_int
+    r3 :: short_int
+    r4 :: bit
+    r5 :: object
+    r6, x :: union[str, None]
+    r7 :: short_int
+L0:
+    r0 = 0
+L1:
+    r1 = get_element_ptr a ob_size :: PyVarObject
+    r2 = load_mem r1 :: native_int*
+    keep_alive a
+    r3 = r2 << 1
+    r4 = r0 < r3 :: signed
+    if r4 goto L2 else goto L4 :: bool
+L2:
+    r5 = CPyList_GetItemUnsafe(a, r0)
+    r6 = cast(union[str, None], r5)
+    x = r6
+L3:
+    r7 = r0 + 2
+    r0 = r7
+    goto L1
+L4:
+    return 1
diff --git a/mypyc/test/test_subtype.py b/mypyc/test/test_typeops.py
similarity index 64%
rename from mypyc/test/test_subtype.py
rename to mypyc/test/test_typeops.py
index 4a0d8737c852..f414edd1a2bb 100644
--- a/mypyc/test/test_subtype.py
+++ b/mypyc/test/test_typeops.py
@@ -1,16 +1,19 @@
-"""Test cases for is_subtype and is_runtime_subtype."""
+"""Test cases for various RType operations."""
 
 from __future__ import annotations
 
 import unittest
 
 from mypyc.ir.rtypes import (
+    RUnion,
     bit_rprimitive,
     bool_rprimitive,
     int32_rprimitive,
     int64_rprimitive,
     int_rprimitive,
+    object_rprimitive,
     short_int_rprimitive,
+    str_rprimitive,
 )
 from mypyc.rt_subtype import is_runtime_subtype
 from mypyc.subtype import is_subtype
@@ -50,3 +53,24 @@ def test_bit(self) -> None:
     def test_bool(self) -> None:
         assert not is_runtime_subtype(bool_rprimitive, bit_rprimitive)
         assert not is_runtime_subtype(bool_rprimitive, int_rprimitive)
+
+
+class TestUnionSimplification(unittest.TestCase):
+    def test_simple_type_result(self) -> None:  # One item: returned unwrapped
+        assert RUnion.make_simplified_union([int_rprimitive]) == int_rprimitive
+
+    def test_remove_duplicate(self) -> None:  # Duplicates collapse to a single type
+        assert RUnion.make_simplified_union([int_rprimitive, int_rprimitive]) == int_rprimitive
+
+    def test_cannot_simplify(self) -> None:  # Overlapping items (object) are kept as is
+        assert RUnion.make_simplified_union(
+            [int_rprimitive, str_rprimitive, object_rprimitive]
+        ) == RUnion([int_rprimitive, str_rprimitive, object_rprimitive])
+
+    def test_nested(self) -> None:  # Nested unions are flattened, then deduplicated
+        assert RUnion.make_simplified_union(
+            [int_rprimitive, RUnion([str_rprimitive, int_rprimitive])]
+        ) == RUnion([int_rprimitive, str_rprimitive])
+        assert RUnion.make_simplified_union(
+            [int_rprimitive, RUnion([str_rprimitive, RUnion([int_rprimitive])])]
+        ) == RUnion([int_rprimitive, str_rprimitive])