diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc
index 2eea869af516..4df38ff543b5 100644
--- a/src/tir/transforms/lower_cross_thread_reduction.cc
+++ b/src/tir/transforms/lower_cross_thread_reduction.cc
@@ -582,10 +582,12 @@ class CrossThreadReductionTransformer : public StmtMutator {
     PrimExpr combiner_rhs{nullptr};
     std::tie(n_bound_reduction_loops, reducer, combiner_rhs) =
         CheckCanApplyCrossThreadReduction(block, reduction_loops);
-    // Step 3. When not all the reduction-related loops are bound to thread axes, in-thread
-    // reduction is needed in this cross-thread reduction.
+    // Step 3. Before doing the cross-thread reduction, in-thread reduction is needed when
+    //  - not all the reduction-related loops are bound to thread axes, or
+    //  - the block-realize has a non-constant-true predicate.
     bool need_in_thread_reduction =
-        n_bound_reduction_loops < static_cast<int>(reduction_loops.size());
+        n_bound_reduction_loops < static_cast<int>(reduction_loops.size()) ||
+        !is_one(realize->predicate);
     // Step 4. Create intermediate buffers, storing them in `ct_buffer` and
     // `it_buffer`. Let the scope block allocate these new buffers.
     std::vector<Buffer>& new_buffers = block2new_buffers_[block_stack_.back()];
diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
index 4fa3ab0c550c..5b3d7283f14f 100644
--- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
+++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
@@ -327,6 +327,162 @@ def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None:
                 B[vi] = reduce_temp0[0]
 
 
+@T.prim_func
+def single_reduction_loop_with_block_predicate(
+    A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
+) -> None:
+    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+    for i0 in T.serial(256):
+        for ax0, ax1_0 in T.grid(1, 1):
+            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                with T.block("T_softmax_maxelem"):
+                    i0_1 = T.axis.spatial(256, i0)
+                    k = T.axis.reduce(256, ax1_1)
+                    T.where(ax1_0 * 512 + ax1_1 < 256)
+                    T.reads(T_softmax_maxelem_shared[i0_1], A[i0_1, k])
+                    T.writes(T_softmax_maxelem_shared[i0_1])
+                    with T.init():
+                        T_softmax_maxelem_shared[i0_1] = T.float32(-3.4028234663852886e38)
+                    T_softmax_maxelem_shared[i0_1] = T.max(
+                        T_softmax_maxelem_shared[i0_1], A[i0_1, k]
+                    )
+        for ax0, ax1_0 in T.grid(1, 1):
+            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                with T.block("T_softmax_expsum"):
+                    i0_2 = T.axis.spatial(256, i0)
+                    k = T.axis.reduce(256, ax1_1)
+                    T.where(ax1_0 * 512 + ax1_1 < 256)
+                    T.reads(
+                        T_softmax_expsum_shared[i0_2], A[i0_2, k], T_softmax_maxelem_shared[i0_2]
+                    )
+                    T.writes(T_softmax_expsum_shared[i0_2])
+                    with T.init():
+                        T_softmax_expsum_shared[i0_2] = T.float32(0)
+                    T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
+                        A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32"
+                    )
+        for i1_0 in T.serial(1):
+            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                with T.block("T_softmax_norm"):
+                    i0_3 = T.axis.spatial(256, i0)
+                    i1 = T.axis.spatial(256, i1_1)
+                    T.where(i1_0 * 512 + i1_1 < 256)
+                    T.reads(
+                        A[i0_3, i1], T_softmax_maxelem_shared[i0_3], T_softmax_expsum_shared[i0_3]
+                    )
+                    T.writes(T_softmax_norm[i0_3, i1])
+                    T.block_attr({"axis": 1})
+                    T_softmax_norm[i0_3, i1] = (
+                        T.exp(A[i0_3, i1] - T_softmax_maxelem_shared[i0_3], dtype="float32")
+                        / T_softmax_expsum_shared[i0_3]
+                    )
+
+
+@T.prim_func
+def lowered_single_reduction_loop_with_block_predicate(
+    A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
+) -> None:
+    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
+    cross_thread_0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
+    in_thread_0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
+    cross_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
+    in_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
+    for i0 in T.serial(256):
+        for ax0, ax1_0 in T.grid(1, 1):
+            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                with T.block("T_softmax_maxelem_in_thread_init"):
+                    T.reads()
+                    T.writes(in_thread_0[0])
+                    in_thread_0[0] = T.float32(-3.4028234663852886e38)
+                with T.block("T_softmax_maxelem_in_thread"):
+                    i0_1 = T.axis.spatial(256, i0)
+                    k = T.axis.reduce(256, ax1_1)
+                    T.where(ax1_0 * 512 + ax1_1 < 256)
+                    T.reads(A[i0_1, k], in_thread_0[0])
+                    T.writes(in_thread_0[0])
+                    in_thread_0[0] = T.max(in_thread_0[0], A[i0_1, k])
+                with T.block("T_softmax_maxelem_cross_thread"):
+                    T.reads(in_thread_0[0])
+                    T.writes(cross_thread_0[0])
+                    T.attr(
+                        T.comm_reducer(
+                            lambda x, y: T.max(x, y), [T.float32(-3.4028234663852886e38)]
+                        ),
+                        "reduce_scope",
+                        T.reinterpret(T.uint64(0), dtype="handle"),
+                    )
+                    T.evaluate(
+                        T.tvm_thread_allreduce(
+                            T.uint32(1),
+                            in_thread_0[0],
+                            True,
+                            cross_thread_0.data,
+                            ax1_1,
+                            dtype="handle",
+                        )
+                    )
+                with T.block("T_softmax_maxelem_write_back"):
+                    i0_2 = T.axis.spatial(256, i0)
+                    T.reads(cross_thread_0[0])
+                    T.writes(T_softmax_maxelem_shared[i0_2])
+                    T_softmax_maxelem_shared[i0_2] = cross_thread_0[0]
+        for ax0, ax1_0 in T.grid(1, 1):
+            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                with T.block("T_softmax_expsum_in_thread_init"):
+                    T.reads()
+                    T.writes(in_thread_1[0])
+                    in_thread_1[0] = T.float32(0)
+                with T.block("T_softmax_expsum_in_thread"):
+                    i0_3 = T.axis.spatial(256, i0)
+                    k = T.axis.reduce(256, ax1_1)
+                    T.where(ax1_0 * 512 + ax1_1 < 256)
+                    T.reads(A[i0_3, k], T_softmax_maxelem_shared[i0_3], in_thread_1[0])
+                    T.writes(in_thread_1[0])
+                    in_thread_1[0] = in_thread_1[0] + T.exp(
+                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3], dtype="float32"
+                    )
+                with T.block("T_softmax_expsum_cross_thread"):
+                    T.reads(in_thread_1[0])
+                    T.writes(cross_thread_1[0])
+                    T.attr(
+                        T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
+                        "reduce_scope",
+                        T.reinterpret(T.uint64(0), dtype="handle"),
+                    )
+                    T.evaluate(
+                        T.tvm_thread_allreduce(
+                            T.uint32(1),
+                            in_thread_1[0],
+                            True,
+                            cross_thread_1.data,
+                            ax1_1,
+                            dtype="handle",
+                        )
+                    )
+                with T.block("T_softmax_expsum_write_back"):
+                    i0_4 = T.axis.spatial(256, i0)
+                    T.reads(cross_thread_1[0])
+                    T.writes(T_softmax_expsum_shared[i0_4])
+                    T_softmax_expsum_shared[i0_4] = cross_thread_1[0]
+        for i1_0 in T.serial(1):
+            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
+                with T.block("T_softmax_norm"):
+                    i0_5 = T.axis.spatial(256, i0)
+                    i1 = T.axis.spatial(256, i1_1)
+                    T.where(i1_0 * 512 + i1_1 < 256)
+                    T.reads(
+                        A[i0_5, i1], T_softmax_maxelem_shared[i0_5], T_softmax_expsum_shared[i0_5]
+                    )
+                    T.writes(T_softmax_norm[i0_5, i1])
+                    T.block_attr({"axis": 1})
+                    T_softmax_norm[i0_5, i1] = (
+                        T.exp(A[i0_5, i1] - T_softmax_maxelem_shared[i0_5], dtype="float32")
+                        / T_softmax_expsum_shared[i0_5]
+                    )
+
+
 @T.prim_func
 def reducer_max(a: T.handle, b: T.handle) -> None:
     A = T.match_buffer(a, [128, 128], dtype="float32")
@@ -686,6 +842,13 @@ def test_with_block_predicate():
     _check(with_block_predicate, lowered_with_block_predicate)
 
 
+def test_single_reduction_loop_with_block_predicate():
+    _check(
+        single_reduction_loop_with_block_predicate,
+        lowered_single_reduction_loop_with_block_predicate,
+    )
+
+
 def test_reducer_max():
     _check(reducer_max, lowered_reducer_max)
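
For context outside the patch itself: below is a minimal sketch of how the new test fixtures above are presumably exercised, assuming the test module's `_check` helper follows the usual pattern of wrapping the original PrimFunc in an IRModule, running the LowerCrossThreadReduction pass, and structurally comparing the result with the hand-written lowered function. The helper body shown here is an assumption for illustration, not code introduced by this diff.

# Hypothetical reconstruction of the `_check` helper referenced by the tests;
# the real helper is defined elsewhere in
# tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py and
# may differ in detail.
import tvm


def _check(original, transformed):
    # Wrap the TVMScript PrimFunc into an IRModule, run the pass under test,
    # and compare the result against the expected lowered function.
    mod = tvm.IRModule.from_expr(original)
    mod = tvm.tir.transform.LowerCrossThreadReduction()(mod)
    tvm.ir.assert_structural_equal(mod["main"], transformed, True)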