diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index c10555e74d07..ec23d21598ee 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -407,9 +407,15 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, // BufferStore(wb_buffers[i], BufferLoad(ct_buffers[i], {Integer(0)}), wb_indices)); wb_regions.push_back(BufferRegion(wb_buffers[i], region)); } + PrimExpr wb_predicate = const_true(); + for (const ForNode* loop : reduction_loops) { + if (loop->thread_binding.defined()) { + wb_predicate = wb_predicate && (loop->loop_var == IntImm(loop->loop_var->dtype, 0)); + } + } stmts.push_back(BlockRealize( /*iter_values=*/std::move(bindings), - /*predicate=*/const_true(), + /*predicate=*/wb_predicate, /*block=*/ Block(/*iter_vars=*/std::move(iter_vars), /*reads=*/std::move(ct_buffer_regions), diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py index dc0a6ce62a66..8b5c21224148 100644 --- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py +++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py @@ -93,6 +93,7 @@ def lowered_loop_split(a: T.handle, b: T.handle) -> None: ) with T.block("B_write_back"): vi = T.axis.S(128, i) + T.where(ki == 0) T.reads([reduce_temp0[0]]) T.writes([B[vi]]) B[vi] = reduce_temp0[0] @@ -136,6 +137,7 @@ def lowered_no_normal_reduction(a: T.handle, b: T.handle) -> None: ) with T.block("B_write_back"): vi = T.axis.spatial(128, i) + T.where(k == 0) T.reads([reduce_temp0[0]]) T.writes([B[vi]]) B[vi] = reduce_temp0[0] @@ -183,6 +185,7 @@ def lowered_two_bound_loops(a: T.handle, b: T.handle) -> None: ) with T.block("B_write_back"): vi = T.axis.spatial(128, i) + T.where(ko == 0 and ki == 0) T.reads([reduce_temp0[0]]) T.writes([B[vi]]) B[vi] = reduce_temp0[0] @@ 
-264,6 +267,7 @@ def lowered_multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> No ) with T.block("B_write_back"): vi = T.axis.spatial(16, i) + T.where(k0o == 0) T.reads([reduce_temp0[0]]) T.writes([B[vi]]) B[vi] = reduce_temp0[0] @@ -326,6 +330,7 @@ def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None: ) with T.block("B_write_back"): vi = T.axis.spatial(128, i) + T.where(ki == 0) T.reads([reduce_temp0[0]]) T.writes([B[vi]]) B[vi] = reduce_temp0[0] @@ -428,6 +433,7 @@ def lowered_single_reduction_loop_with_block_predicate( ) with T.block("T_softmax_maxelem_write_back"): i0_2 = T.axis.spatial(256, i0 + ax0) + T.where(ax1_1 == 0) T.reads(cross_thread_0[0]) T.writes(T_softmax_maxelem_shared[i0_2]) T_softmax_maxelem_shared[i0_2] = cross_thread_0[0] @@ -467,6 +473,7 @@ def lowered_single_reduction_loop_with_block_predicate( ) with T.block("T_softmax_expsum_write_back"): i0_4 = T.axis.spatial(256, i0 + ax0) + T.where(ax1_1 == 0) T.reads(cross_thread_1[0]) T.writes(T_softmax_expsum_shared[i0_4]) T_softmax_expsum_shared[i0_4] = cross_thread_1[0] @@ -636,6 +643,7 @@ def lowered_reducer_max(a: T.handle, b: T.handle) -> None: ) with T.block("B_write_back"): vi = T.axis.spatial(128, i) + T.where(k == 0) T.reads([reduce_temp0[0]]) T.writes([B[vi]]) B[vi] = reduce_temp0[0] @@ -676,6 +684,7 @@ def lowered_zero_rank_buffer(a: T.handle, b: T.handle) -> None: with T.block("B_write_back"): + T.where(k == 0) T.reads([reduce_temp0[0]]) T.writes([B[()]]) B[()] = reduce_temp0[0] @@ -865,6 +874,7 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: ) with T.block("T_softmax_maxelem_write_back"): i0_2 = T.axis.spatial(256, i0) + T.where(ax0_1 == 0) T.reads([reduce_temp0[0]]) T.writes([T_softmax_maxelem_shared[i0_2]]) T_softmax_maxelem_shared[i0_2] = reduce_temp0[0] @@ -907,6 +917,7 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: ) with T.block("T_softmax_expsum_write_back"): i0_4 = T.axis.spatial(256, i0) 
+ T.where(ax0_1 == 0) T.reads([reduce_temp1[0]]) T.writes([T_softmax_expsum_shared[i0_4]]) T_softmax_expsum_shared[i0_4] = reduce_temp1[0] @@ -1018,6 +1029,7 @@ def lowered_argmax_split( ) with T.block("argmax_write_back"): i = T.axis.spatial(128, i0) + T.where(i1_1 == 0) T.reads(cross_thread_argmax_v0[0], cross_thread_argmax_v1[0]) T.writes(argmax_v0[i], argmax_v1[i]) argmax_v0[i] = cross_thread_argmax_v0[0] @@ -1109,6 +1121,7 @@ def lowered_argmin_split_init_update_reordered( ) with T.block("argmin_write_back"): i = T.axis.spatial(128, i0) + T.where(i1_1 == 0) T.reads(cross_thread_argmin_v0[0], cross_thread_argmin_v1[0]) T.writes(argmin_v0[i], argmin_v1[i]) argmin_v0[i] = cross_thread_argmin_v0[0] @@ -1227,6 +1240,7 @@ def lowered_layer_norm_tuple_sum( ) with T.block("data_red_temp_write_back"): ax0 = T.axis.spatial(128, i0_fused) + T.where(i1_1 == 0) T.reads(cross_thread_data_red_temp_v0[0], cross_thread_data_red_temp_v1[0]) T.writes(data_red_temp_v0[ax0], data_red_temp_v1[ax0]) data_red_temp_v0[ax0] = cross_thread_data_red_temp_v0[0]