Skip to content

Commit

Permalink
[Schedule][Bugfix] Fix decompose padding wrt the single child subtree (
Browse files Browse the repository at this point in the history
…apache#13646)

Fix bug when decompose padding wrt the single child subtree
  • Loading branch information
wrongtest-intellif authored and fzi-peccia committed Mar 27, 2023
1 parent f554f35 commit 5a043c9
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/tir/schedule/primitive/decompose_padding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ class PaddingInfoAnalyzer {

// Step 3. Analyze in-bound write region.
PrimExpr in_bound_predicate = RewritePredicate(pad_predicate && realize->predicate);
if (analyzer_->CanProveEqual(in_bound_predicate, 1)) {
SetError("The in-bound predicate is trivial");
return false;
}
Array<Range> in_bound_region = this->EstimateInBoundRegion(
/*iter_values=*/realize->iter_values, /*dom_map=*/dom_map,
/*in_bound_predicate=*/in_bound_predicate);
Expand Down Expand Up @@ -439,13 +443,14 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref,
analyzer.Bind(cur_loop->loop_var, range);
loops.push_back(cur_loop);

if (!found_const_filling_pos) {
if (cur_loop.same_as(const_filling_pos)) {
found_const_filling_pos = true;
if (cur_loop.same_as(const_filling_pos)) {
ICHECK(!found_const_filling_pos);
found_const_filling_pos = true;
if (!found_in_bound_filling_pos) {
found_in_bound_filling_pos = true;
in_bound_filling_pos = cur_loop;
}
}

if (!found_in_bound_filling_pos) {
} else if (!found_in_bound_filling_pos) {
if (!cur_loop->body->IsInstance<ForNode>() &&
!cur_loop->body->IsInstance<BlockRealizeNode>()) {
found_in_bound_filling_pos = true;
Expand Down
63 changes: 63 additions & 0 deletions tests/python/unittest/test_tir_schedule_decompose_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,5 +309,68 @@ def pooling_decompose_3(
check_decompose_padding(sum_pool_2d, sch.mod["main"], pooling_decompose_3, check_run=True)


def test_decompose_wrt_single_child_subtree():
"""Test the case when the decompose position is under the single child subtree"""

@T.prim_func
def pad_op(
x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer([1, 16, 231, 231], dtype="int8")
):
for i0, i1, i2, i3 in T.grid(1, 16, 231, 231):
with T.block("pad_temp"):
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
y[ax0, ax1, ax2, ax3] = T.if_then_else(
3 <= ax2 and ax2 < 228 and 3 <= ax3 and ax3 < 228,
x[ax0, ax1, ax2 - 3, ax3 - 3],
T.int8(0),
dtype="int8",
)

@T.prim_func
def pad_op_after(
x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer[(1, 16, 231, 231), "int8"]
):
for i0, i1 in T.grid(1, 16):
for i2, i3 in T.grid(231, 231):
with T.block("pad_temp_pad_const"):
ax0 = T.axis.spatial(1, 0)
ax1, ax2, ax3 = T.axis.remap("SSS", [i1, i2, i3])
y[ax0, ax1, ax2, ax3] = T.int8(0)
for i2, i3 in T.grid(225, 225):
with T.block("pad_temp"):
ax0 = T.axis.spatial(1, 0)
ax1, ax2, ax3 = T.axis.remap("SSS", [i1, i2, i3])
y[ax0, ax1, ax2 + 3, ax3 + 3] = x[ax0, ax1, ax2, ax3]

sch = tir.Schedule(pad_op, debug_mask="all")
pad = sch.get_block("pad_temp")
_, _, h, _ = sch.get_loops(pad)
sch.decompose_padding(pad, h)
check_decompose_padding(pad_op, sch.mod["main"], pad_op_after, check_run=True)


def test_not_to_decompose_trivial_predicate():
"""Test the case when the padding condition is trivial"""

@T.prim_func
def trivial_pad(
x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer([1, 16, 225, 225], dtype="int8")
):
for i0, i1, i2, i3 in T.grid(1, 16, 225, 225):
with T.block("pad_temp"):
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
y[ax0, ax1, ax2, ax3] = T.if_then_else(
0 <= ax2 and ax2 < 225 and 0 <= ax3 and ax3 < 225,
x[ax0, ax1, ax2, ax3],
T.int8(0),
dtype="int8",
)

sch = tir.Schedule(trivial_pad, debug_mask="all")
pad = sch.get_block("pad_temp")
_, _, h, _ = sch.get_loops(pad)
assert not sch.can_decompose_padding(pad, h)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 5a043c9

Please sign in to comment.