From befdc4e63100bdf25b10ae7a54b40408e2e1805a Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Sun, 5 Mar 2023 07:50:32 -0500
Subject: [PATCH] [Fix][TIR] LowerCrossThreadReduction with write-back
 predicate (#14199)

Prior to this PR, the cross-thread reduction lowering pass did not add a
store predicate to the write-back block. The rationale was that, for any
given write-back buffer position, every thread stores the same value in
the write-back block, so omitting the predicate could not make the result
wrong.

However, we recently noticed that some GPU backend compilers flag this
behavior (multiple threads writing to the same position) as a race
condition and raise a compilation error. These compilers do not take into
account that all the stored values are identical and insist on
complaining. Therefore, the write-back block still needs a predicate for
things to work, and this PR adds it.

I have run integration tests locally to confirm that the generated
kernels are correct and produce numerically correct results.
---
 src/tir/transforms/lower_cross_thread_reduction.cc |  8 +++++++-
 ...t_tir_transform_lower_cross_thread_reduction.py | 14 ++++++++++++++
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc
index c10555e74d07..ec23d21598ee 100644
--- a/src/tir/transforms/lower_cross_thread_reduction.cc
+++ b/src/tir/transforms/lower_cross_thread_reduction.cc
@@ -407,9 +407,15 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, //
         BufferStore(wb_buffers[i], BufferLoad(ct_buffers[i], {Integer(0)}), wb_indices));
     wb_regions.push_back(BufferRegion(wb_buffers[i], region));
   }
+  PrimExpr wb_predicate = const_true();
+  for (const ForNode* loop : reduction_loops) {
+    if (loop->thread_binding.defined()) {
+      wb_predicate = wb_predicate && (loop->loop_var == IntImm(loop->loop_var->dtype, 0));
+    }
+  }
   stmts.push_back(BlockRealize(
       /*iter_values=*/std::move(bindings),
-      /*predicate=*/const_true(),
+      /*predicate=*/wb_predicate,
       /*block=*/
       Block(/*iter_vars=*/std::move(iter_vars),
             /*reads=*/std::move(ct_buffer_regions),
diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
index dc0a6ce62a66..8b5c21224148 100644
--- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
+++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
@@ -93,6 +93,7 @@ def lowered_loop_split(a: T.handle, b: T.handle) -> None:
                 )
             with T.block("B_write_back"):
                 vi = T.axis.S(128, i)
+                T.where(ki == 0)
                 T.reads([reduce_temp0[0]])
                 T.writes([B[vi]])
                 B[vi] = reduce_temp0[0]
@@ -136,6 +137,7 @@ def lowered_no_normal_reduction(a: T.handle, b: T.handle) -> None:
                 )
             with T.block("B_write_back"):
                 vi = T.axis.spatial(128, i)
+                T.where(k == 0)
                 T.reads([reduce_temp0[0]])
                 T.writes([B[vi]])
                 B[vi] = reduce_temp0[0]
@@ -183,6 +185,7 @@ def lowered_two_bound_loops(a: T.handle, b: T.handle) -> None:
                 )
             with T.block("B_write_back"):
                 vi = T.axis.spatial(128, i)
+                T.where(ko == 0 and ki == 0)
                 T.reads([reduce_temp0[0]])
                 T.writes([B[vi]])
                 B[vi] = reduce_temp0[0]
@@ -264,6 +267,7 @@ def lowered_multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> No
                 )
             with T.block("B_write_back"):
                 vi = T.axis.spatial(16, i)
+                T.where(k0o == 0)
                 T.reads([reduce_temp0[0]])
                 T.writes([B[vi]])
                 B[vi] = reduce_temp0[0]
@@ -326,6 +330,7 @@ def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None:
                 )
             with T.block("B_write_back"):
                 vi = T.axis.spatial(128, i)
+                T.where(ki == 0)
                 T.reads([reduce_temp0[0]])
                 T.writes([B[vi]])
                 B[vi] = reduce_temp0[0]
@@ -428,6 +433,7 @@ def lowered_single_reduction_loop_with_block_predicate(
                 )
             with T.block("T_softmax_maxelem_write_back"):
                 i0_2 = T.axis.spatial(256, i0 + ax0)
+                T.where(ax1_1 == 0)
                 T.reads(cross_thread_0[0])
                 T.writes(T_softmax_maxelem_shared[i0_2])
                 T_softmax_maxelem_shared[i0_2] = cross_thread_0[0]
@@ -467,6 +473,7 @@ def lowered_single_reduction_loop_with_block_predicate(
                 )
             with T.block("T_softmax_expsum_write_back"):
                 i0_4 = T.axis.spatial(256, i0 + ax0)
+                T.where(ax1_1 == 0)
                 T.reads(cross_thread_1[0])
                 T.writes(T_softmax_expsum_shared[i0_4])
                 T_softmax_expsum_shared[i0_4] = cross_thread_1[0]
@@ -636,6 +643,7 @@ def lowered_reducer_max(a: T.handle, b: T.handle) -> None:
                 )
             with T.block("B_write_back"):
                 vi = T.axis.spatial(128, i)
+                T.where(k == 0)
                 T.reads([reduce_temp0[0]])
                 T.writes([B[vi]])
                 B[vi] = reduce_temp0[0]
@@ -676,6 +684,7 @@ def lowered_zero_rank_buffer(a: T.handle, b: T.handle) -> None:
             with T.block("B_write_back"):
                 T.reads([reduce_temp0[0]])
                 T.writes([B[()]])
+                T.where(k == 0)
                 B[()] = reduce_temp0[0]
@@ -865,6 +874,7 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
                 )
             with T.block("T_softmax_maxelem_write_back"):
                 i0_2 = T.axis.spatial(256, i0)
+                T.where(ax0_1 == 0)
                 T.reads([reduce_temp0[0]])
                 T.writes([T_softmax_maxelem_shared[i0_2]])
                 T_softmax_maxelem_shared[i0_2] = reduce_temp0[0]
@@ -907,6 +917,7 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
                 )
             with T.block("T_softmax_expsum_write_back"):
                 i0_4 = T.axis.spatial(256, i0)
+                T.where(ax0_1 == 0)
                 T.reads([reduce_temp1[0]])
                 T.writes([T_softmax_expsum_shared[i0_4]])
                 T_softmax_expsum_shared[i0_4] = reduce_temp1[0]
@@ -1018,6 +1029,7 @@ def lowered_argmax_split(
                 )
             with T.block("argmax_write_back"):
                 i = T.axis.spatial(128, i0)
+                T.where(i1_1 == 0)
                 T.reads(cross_thread_argmax_v0[0], cross_thread_argmax_v1[0])
                 T.writes(argmax_v0[i], argmax_v1[i])
                 argmax_v0[i] = cross_thread_argmax_v0[0]
@@ -1109,6 +1121,7 @@ def lowered_argmin_split_init_update_reordered(
                 )
             with T.block("argmin_write_back"):
                 i = T.axis.spatial(128, i0)
+                T.where(i1_1 == 0)
                 T.reads(cross_thread_argmin_v0[0], cross_thread_argmin_v1[0])
                 T.writes(argmin_v0[i], argmin_v1[i])
                 argmin_v0[i] = cross_thread_argmin_v0[0]
@@ -1227,6 +1240,7 @@ def lowered_layer_norm_tuple_sum(
                 )
             with T.block("data_red_temp_write_back"):
                 ax0 = T.axis.spatial(128, i0_fused)
+                T.where(i1_1 == 0)
                 T.reads(cross_thread_data_red_temp_v0[0], cross_thread_data_red_temp_v1[0])
                 T.writes(data_red_temp_v0[ax0], data_red_temp_v1[ax0])
                 data_red_temp_v0[ax0] = cross_thread_data_red_temp_v0[0]
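
Editor's note: to see the new predicate end to end, the following minimal sketch (not part of the patch) runs the pass from Python on a kernel written in the same style as the ones in the modified test file. The kernel name no_normal_reduction and its shapes are illustrative; the only APIs assumed are TVMScript (tvm.script.tir) and tvm.tir.transform.LowerCrossThreadReduction, both already used by the test file above.

# Sketch only: exercise LowerCrossThreadReduction on a reduction loop bound to
# threadIdx.x and print the lowered module. With this patch, the generated
# "B_write_back" block is expected to carry the predicate T.where(k == 0).
import tvm
from tvm.script import tir as T


@T.prim_func
def no_normal_reduction(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    for i in T.serial(0, 128):
        # The reduction loop is bound to threadIdx.x, so the pass rewrites it
        # into a cross-thread allreduce followed by a write-back block.
        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B"):
                vi, vk = T.axis.remap("SR", [i, k])
                T.reads([B[vi], A[vi, vk]])
                T.writes([B[vi]])
                with T.init():
                    B[vi] = T.float32(0)
                B[vi] = B[vi] + A[vi, vk]


mod = tvm.IRModule.from_expr(no_normal_reduction)
mod = tvm.tir.transform.LowerCrossThreadReduction()(mod)
# Only thread 0 of the binding now performs the write-back store.
print(mod["main"].script())

Structurally, the write-back BlockRealize now has the predicate k == 0 instead of const_true(), matching the C++ change above.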