[BugFix][TIR] Fix cross-thread reduction when single reduction loop with predicate (apache#10016)
MasterJH5574 authored and ylc committed Feb 16, 2022
1 parent c1637c8 commit f5b3b4b
Showing 2 changed files with 168 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/tir/transforms/lower_cross_thread_reduction.cc
@@ -582,10 +582,12 @@ class CrossThreadReductionTransformer : public StmtMutator {
     PrimExpr combiner_rhs{nullptr};
     std::tie(n_bound_reduction_loops, reducer, combiner_rhs) =
         CheckCanApplyCrossThreadReduction(block, reduction_loops);
-    // Step 3. When not all the reduction-related loops are bound to thread axes, in-thread
-    // reduction is needed in this cross-thread reduction.
+    // Step 3. Before doing the cross-thread reduction, in-thread reduction is needed when
+    //  - not all the reduction-related loops are bound to thread axes, or
+    //  - the block-realize has a non-constant-true predicate.
     bool need_in_thread_reduction =
-        n_bound_reduction_loops < static_cast<int>(reduction_loops.size());
+        n_bound_reduction_loops < static_cast<int>(reduction_loops.size()) ||
+        !is_one(realize->predicate);
     // Step 4. Create intermediate buffers, storing them in `ct_buffer` and
     // `it_buffer`. Let the scope block allocate these new buffers.
     std::vector<Buffer>& new_buffers = block2new_buffers_[block_stack_.back()];
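
The added `!is_one(realize->predicate)` check covers the case named in the commit title: a single reduction loop whose thread binding is wider than the loop extent, so the block-realize carries a predicate. Below is a minimal sketch, not part of this commit, of how such a predicate typically arises and how the pass can be invoked directly; the `row_max` workload and schedule are assumptions made up for illustration.

import tvm
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def row_max(A: T.Buffer[(256, 256), "float32"], B: T.Buffer[(256,), "float32"]) -> None:
    for i, k in T.grid(256, 256):
        with T.block("row_max"):
            vi, vk = T.axis.remap("SR", [i, k])
            with T.init():
                B[vi] = T.min_value("float32")
            B[vi] = T.max(B[vi], A[vi, vk])


sch = tir.Schedule(tvm.IRModule.from_expr(row_max))
k = sch.get_loops(sch.get_block("row_max"))[1]
# Splitting the 256-extent reduction loop by 512 threads leaves a single
# thread-bound reduction loop guarded by the predicate k_0 * 512 + k_1 < 256.
k_0, k_1 = sch.split(k, factors=[None, 512])
sch.bind(k_1, "threadIdx.x")
# With this fix, the pass first reduces in-thread and only then performs the
# cross-thread allreduce, so inactive lanes cannot pollute the result.
lowered = tvm.tir.transform.LowerCrossThreadReduction()(sch.mod)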
@@ -327,6 +327,162 @@ def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None:
            B[vi] = reduce_temp0[0]


@T.prim_func
def single_reduction_loop_with_block_predicate(
    A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
) -> None:
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    for i0 in T.serial(256):
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_maxelem_shared[i0_1], A[i0_1, k])
                    T.writes(T_softmax_maxelem_shared[i0_1])
                    with T.init():
                        T_softmax_maxelem_shared[i0_1] = T.float32(-3.4028234663852886e38)
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k]
                    )
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(
                        T_softmax_expsum_shared[i0_2], A[i0_2, k], T_softmax_maxelem_shared[i0_2]
                    )
                    T.writes(T_softmax_expsum_shared[i0_2])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                        A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32"
                    )
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(
                        A[i0_3, i1], T_softmax_maxelem_shared[i0_3], T_softmax_expsum_shared[i0_3]
                    )
                    T.writes(T_softmax_norm[i0_3, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(A[i0_3, i1] - T_softmax_maxelem_shared[i0_3], dtype="float32")
                        / T_softmax_expsum_shared[i0_3]
                    )
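
The `T.where(ax1_0 * 512 + ax1_1 < 256)` predicate above means only 256 of the 512 bound threads hold valid data. A plain-NumPy sketch of the semantics the lowering has to preserve (illustrative only, not TVM code): lanes whose predicate is false contribute the reducer identity, so the allreduce over all 512 lanes still equals the max over the 256 valid elements. This is what the `*_in_thread_init` blocks in the lowered function below establish.

import numpy as np

row = np.random.rand(256).astype("float32")
identity = np.float32(-3.4028234663852886e38)  # identity of the max reducer
lanes = np.full(512, identity, dtype="float32")
for tx in range(512):
    if tx < 256:  # the T.where(ax1_0 * 512 + ax1_1 < 256) predicate
        lanes[tx] = max(lanes[tx], row[tx])
assert lanes.max() == row.max()  # the 512-lane allreduce is unaffected by inactive lanes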


@T.prim_func
def lowered_single_reduction_loop_with_block_predicate(
    A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
) -> None:
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    cross_thread_0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    in_thread_0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    cross_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    in_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i0 in T.serial(256):
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem_in_thread_init"):
                    T.reads()
                    T.writes(in_thread_0[0])
                    in_thread_0[0] = T.float32(-3.4028234663852886e38)
                with T.block("T_softmax_maxelem_in_thread"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(A[i0_1, k], in_thread_0[0])
                    T.writes(in_thread_0[0])
                    in_thread_0[0] = T.max(in_thread_0[0], A[i0_1, k])
                with T.block("T_softmax_maxelem_cross_thread"):
                    T.reads(in_thread_0[0])
                    T.writes(cross_thread_0[0])
                    T.attr(
                        T.comm_reducer(
                            lambda x, y: T.max(x, y), [T.float32(-3.4028234663852886e38)]
                        ),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            in_thread_0[0],
                            True,
                            cross_thread_0.data,
                            ax1_1,
                            dtype="handle",
                        )
                    )
                with T.block("T_softmax_maxelem_write_back"):
                    i0_2 = T.axis.spatial(256, i0)
                    T.reads(cross_thread_0[0])
                    T.writes(T_softmax_maxelem_shared[i0_2])
                    T_softmax_maxelem_shared[i0_2] = cross_thread_0[0]
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum_in_thread_init"):
                    T.reads()
                    T.writes(in_thread_1[0])
                    in_thread_1[0] = T.float32(0)
                with T.block("T_softmax_expsum_in_thread"):
                    i0_3 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(A[i0_3, k], T_softmax_maxelem_shared[i0_3], in_thread_1[0])
                    T.writes(in_thread_1[0])
                    in_thread_1[0] = in_thread_1[0] + T.exp(
                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3], dtype="float32"
                    )
                with T.block("T_softmax_expsum_cross_thread"):
                    T.reads(in_thread_1[0])
                    T.writes(cross_thread_1[0])
                    T.attr(
                        T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            in_thread_1[0],
                            True,
                            cross_thread_1.data,
                            ax1_1,
                            dtype="handle",
                        )
                    )
                with T.block("T_softmax_expsum_write_back"):
                    i0_4 = T.axis.spatial(256, i0)
                    T.reads(cross_thread_1[0])
                    T.writes(T_softmax_expsum_shared[i0_4])
                    T_softmax_expsum_shared[i0_4] = cross_thread_1[0]
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_5 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(
                        A[i0_5, i1], T_softmax_maxelem_shared[i0_5], T_softmax_expsum_shared[i0_5]
                    )
                    T.writes(T_softmax_norm[i0_5, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_5, i1] = (
                        T.exp(A[i0_5, i1] - T_softmax_maxelem_shared[i0_5], dtype="float32")
                        / T_softmax_expsum_shared[i0_5]
                    )


@T.prim_func
def reducer_max(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128], dtype="float32")
@@ -686,6 +842,13 @@ def test_with_block_predicate():
    _check(with_block_predicate, lowered_with_block_predicate)


def test_single_reduction_loop_with_block_predicate():
    _check(
        single_reduction_loop_with_block_predicate,
        lowered_single_reduction_loop_with_block_predicate,
    )
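
For reference, a sketch of what the `_check` helper used above presumably does; its actual definition lives earlier in this test file and is not part of this hunk, so the body here is an assumption.

import tvm


def _check_sketch(original, transformed):
    # Apply the pass under test and require the result to match the
    # hand-written lowered function (hypothetical stand-in for `_check`).
    mod = tvm.IRModule.from_expr(original)
    mod = tvm.tir.transform.LowerCrossThreadReduction()(mod)
    tvm.ir.assert_structural_equal(
        mod["main"], transformed.with_attr("global_symbol", "main"), True
    )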


def test_reducer_max():
    _check(reducer_max, lowered_reducer_max)

