From 89fa2419e218ce979768774ae16b8dbf3c7d06d4 Mon Sep 17 00:00:00 2001
From: Ruihang Lai <lairuihangdongdong@qq.com>
Date: Sun, 23 Jan 2022 01:51:09 +0800
Subject: [PATCH] [BugFix][TIR] Fix cross-thread reduction when the single
 reduction loop has a predicate (#10016)

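Previously, LowerCrossThreadReduction only inserted an in-thread reduction
step when not all reduction-related loops were bound to thread axes. When
the single reduction loop is bound to a thread axis but the block-realize
carries a non-constant-true predicate (e.g. a `T.where` introduced because
the thread extent is larger than the reduction extent), the predicate was
not handled correctly by the lowering. With this patch, such a block is
also lowered through an in-thread reduction: each thread first accumulates
under the predicate into a local buffer, and the cross-thread
`tvm_thread_allreduce` then combines the per-thread partial results.

A minimal sketch of the pattern now handled (hypothetical function and
buffer names; the full softmax-style test cases are added below):

    @T.prim_func
    def predicated_row_sum(
        A: T.Buffer[(256, 256), "float32"], B: T.Buffer[(256,), "float32"]
    ) -> None:
        for i in T.serial(256):
            # The only reduction loop is thread-bound and over-provisioned
            # (512 threads for a 256-element reduction), so the block-realize
            # carries a non-constant-true predicate.
            for k in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("B"):
                    vi = T.axis.spatial(256, i)
                    vk = T.axis.reduce(256, k)
                    T.where(k < 256)
                    with T.init():
                        B[vi] = T.float32(0)
                    B[vi] = B[vi] + A[vi, vk]

The new test `test_single_reduction_loop_with_block_predicate` checks that
this pattern lowers to an in-thread init/update pair guarded by the
predicate, followed by the cross-thread allreduce and a write-back block.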
---
 .../lower_cross_thread_reduction.cc           |   8 +-
 ..._transform_lower_cross_thread_reduction.py | 163 ++++++++++++++++++
 2 files changed, 168 insertions(+), 3 deletions(-)

diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc
index 2eea869af516..4df38ff543b5 100644
--- a/src/tir/transforms/lower_cross_thread_reduction.cc
+++ b/src/tir/transforms/lower_cross_thread_reduction.cc
@@ -582,10 +582,12 @@ class CrossThreadReductionTransformer : public StmtMutator {
     PrimExpr combiner_rhs{nullptr};
     std::tie(n_bound_reduction_loops, reducer, combiner_rhs) =
         CheckCanApplyCrossThreadReduction(block, reduction_loops);
-    // Step 3. When not all the reduction-related loops are bound to thread axes, in-thread
-    // reduction is needed in this cross-thread reduction.
+    // Step 3. Before doing the cross-thread reduction, an in-thread reduction is needed when
+    //  - not all the reduction-related loops are bound to thread axes, or
+    //  - the block-realize has a non-constant-true predicate.
     bool need_in_thread_reduction =
-        n_bound_reduction_loops < static_cast<int>(reduction_loops.size());
+        n_bound_reduction_loops < static_cast<int>(reduction_loops.size()) ||
+        !is_one(realize->predicate);
     // Step 4. Create intermediate buffers, storing them in `ct_buffer` and
     // `it_buffer`. Let the scope block allocate these new buffers.
     std::vector<Buffer>& new_buffers = block2new_buffers_[block_stack_.back()];
diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
index 4fa3ab0c550c..5b3d7283f14f 100644
--- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
+++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
@@ -327,6 +327,162 @@ def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None:
                 B[vi] = reduce_temp0[0]
 
 
+@T.prim_func
+def single_reduction_loop_with_block_predicate(
+    A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
+) -> None:
+    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+    for i0 in T.serial(256):
+        for ax0, ax1_0 in T.grid(1, 1):
+            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                with T.block("T_softmax_maxelem"):
+                    i0_1 = T.axis.spatial(256, i0)
+                    k = T.axis.reduce(256, ax1_1)
+                    T.where(ax1_0 * 512 + ax1_1 < 256)
+                    T.reads(T_softmax_maxelem_shared[i0_1], A[i0_1, k])
+                    T.writes(T_softmax_maxelem_shared[i0_1])
+                    with T.init():
+                        T_softmax_maxelem_shared[i0_1] = T.float32(-3.4028234663852886e38)
+                    T_softmax_maxelem_shared[i0_1] = T.max(
+                        T_softmax_maxelem_shared[i0_1], A[i0_1, k]
+                    )
+        for ax0, ax1_0 in T.grid(1, 1):
+            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                with T.block("T_softmax_expsum"):
+                    i0_2 = T.axis.spatial(256, i0)
+                    k = T.axis.reduce(256, ax1_1)
+                    T.where(ax1_0 * 512 + ax1_1 < 256)
+                    T.reads(
+                        T_softmax_expsum_shared[i0_2], A[i0_2, k], T_softmax_maxelem_shared[i0_2]
+                    )
+                    T.writes(T_softmax_expsum_shared[i0_2])
+                    with T.init():
+                        T_softmax_expsum_shared[i0_2] = T.float32(0)
+                    T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
+                        A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32"
+                    )
+        for i1_0 in T.serial(1):
+            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                with T.block("T_softmax_norm"):
+                    i0_3 = T.axis.spatial(256, i0)
+                    i1 = T.axis.spatial(256, i1_1)
+                    T.where(i1_0 * 512 + i1_1 < 256)
+                    T.reads(
+                        A[i0_3, i1], T_softmax_maxelem_shared[i0_3], T_softmax_expsum_shared[i0_3]
+                    )
+                    T.writes(T_softmax_norm[i0_3, i1])
+                    T.block_attr({"axis": 1})
+                    T_softmax_norm[i0_3, i1] = (
+                        T.exp(A[i0_3, i1] - T_softmax_maxelem_shared[i0_3], dtype="float32")
+                        / T_softmax_expsum_shared[i0_3]
+                    )
+
+
+@T.prim_func
+def lowered_single_reduction_loop_with_block_predicate(
+    A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
+) -> None:
+    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+    cross_thread_0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
+    in_thread_0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
+    cross_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
+    in_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
+    for i0 in T.serial(256):
+        for ax0, ax1_0 in T.grid(1, 1):
+            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                with T.block("T_softmax_maxelem_in_thread_init"):
+                    T.reads()
+                    T.writes(in_thread_0[0])
+                    in_thread_0[0] = T.float32(-3.4028234663852886e38)
+                with T.block("T_softmax_maxelem_in_thread"):
+                    i0_1 = T.axis.spatial(256, i0)
+                    k = T.axis.reduce(256, ax1_1)
+                    T.where(ax1_0 * 512 + ax1_1 < 256)
+                    T.reads(A[i0_1, k], in_thread_0[0])
+                    T.writes(in_thread_0[0])
+                    in_thread_0[0] = T.max(in_thread_0[0], A[i0_1, k])
+                with T.block("T_softmax_maxelem_cross_thread"):
+                    T.reads(in_thread_0[0])
+                    T.writes(cross_thread_0[0])
+                    T.attr(
+                        T.comm_reducer(
+                            lambda x, y: T.max(x, y), [T.float32(-3.4028234663852886e38)]
+                        ),
+                        "reduce_scope",
+                        T.reinterpret(T.uint64(0), dtype="handle"),
+                    )
+                    T.evaluate(
+                        T.tvm_thread_allreduce(
+                            T.uint32(1),
+                            in_thread_0[0],
+                            True,
+                            cross_thread_0.data,
+                            ax1_1,
+                            dtype="handle",
+                        )
+                    )
+                with T.block("T_softmax_maxelem_write_back"):
+                    i0_2 = T.axis.spatial(256, i0)
+                    T.reads(cross_thread_0[0])
+                    T.writes(T_softmax_maxelem_shared[i0_2])
+                    T_softmax_maxelem_shared[i0_2] = cross_thread_0[0]
+        for ax0, ax1_0 in T.grid(1, 1):
+            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                with T.block("T_softmax_expsum_in_thread_init"):
+                    T.reads()
+                    T.writes(in_thread_1[0])
+                    in_thread_1[0] = T.float32(0)
+                with T.block("T_softmax_expsum_in_thread"):
+                    i0_3 = T.axis.spatial(256, i0)
+                    k = T.axis.reduce(256, ax1_1)
+                    T.where(ax1_0 * 512 + ax1_1 < 256)
+                    T.reads(A[i0_3, k], T_softmax_maxelem_shared[i0_3], in_thread_1[0])
+                    T.writes(in_thread_1[0])
+                    in_thread_1[0] = in_thread_1[0] + T.exp(
+                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3], dtype="float32"
+                    )
+                with T.block("T_softmax_expsum_cross_thread"):
+                    T.reads(in_thread_1[0])
+                    T.writes(cross_thread_1[0])
+                    T.attr(
+                        T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
+                        "reduce_scope",
+                        T.reinterpret(T.uint64(0), dtype="handle"),
+                    )
+                    T.evaluate(
+                        T.tvm_thread_allreduce(
+                            T.uint32(1),
+                            in_thread_1[0],
+                            True,
+                            cross_thread_1.data,
+                            ax1_1,
+                            dtype="handle",
+                        )
+                    )
+                with T.block("T_softmax_expsum_write_back"):
+                    i0_4 = T.axis.spatial(256, i0)
+                    T.reads(cross_thread_1[0])
+                    T.writes(T_softmax_expsum_shared[i0_4])
+                    T_softmax_expsum_shared[i0_4] = cross_thread_1[0]
+        for i1_0 in T.serial(1):
+            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                with T.block("T_softmax_norm"):
+                    i0_5 = T.axis.spatial(256, i0)
+                    i1 = T.axis.spatial(256, i1_1)
+                    T.where(i1_0 * 512 + i1_1 < 256)
+                    T.reads(
+                        A[i0_5, i1], T_softmax_maxelem_shared[i0_5], T_softmax_expsum_shared[i0_5]
+                    )
+                    T.writes(T_softmax_norm[i0_5, i1])
+                    T.block_attr({"axis": 1})
+                    T_softmax_norm[i0_5, i1] = (
+                        T.exp(A[i0_5, i1] - T_softmax_maxelem_shared[i0_5], dtype="float32")
+                        / T_softmax_expsum_shared[i0_5]
+                    )
+
+
 @T.prim_func
 def reducer_max(a: T.handle, b: T.handle) -> None:
     A = T.match_buffer(a, [128, 128], dtype="float32")
@@ -686,6 +842,13 @@ def test_with_block_predicate():
     _check(with_block_predicate, lowered_with_block_predicate)
 
 
+def test_single_reduction_loop_with_block_predicate():
+    _check(
+        single_reduction_loop_with_block_predicate,
+        lowered_single_reduction_loop_with_block_predicate,
+    )
+
+
 def test_reducer_max():
     _check(reducer_max, lowered_reducer_max)