Skip to content

Commit

Permalink
[TIR][Schedule] simpilfy compute_at static bound (apache#10307)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy authored Feb 18, 2022
1 parent ffd517a commit f3ea291
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/tir/schedule/primitive/compute_at.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ struct BlockVarDomainInfo {
if (analyzer->CanProveEqual(dom.min(), intersect.min()) &&
analyzer->CanProveEqual(dom.max(), intersect.max())) {
bound = arith::IntSet::Nothing();
} else if (analyzer->CanProveEqual(bound.min(), intersect.min()) &&
analyzer->CanProveEqual(bound.max(), intersect.max())) {
dom = bound;
bound = arith::IntSet::Nothing();
}
}
};
Expand Down
39 changes: 39 additions & 0 deletions tests/python/unittest/test_tir_schedule_compute_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,36 @@ def tiled_repeat_op_after_compute_at(x: T.Buffer[(4,), "float32"], T_repeat: T.B
ax0 = T.axis.spatial(64, i0_0 * 8 + i0_1)
T_repeat[ax0] = T_add[ax0 // 16]

@T.prim_func
def static_bound(A: T.Buffer[(32, 1), "float32"], C: T.Buffer[(32, 1), "float32"]) -> None:
B = T.alloc_buffer((32, 1), "float32")
for i, j in T.grid(32, 1):
with T.block("B"):
vi = T.axis.spatial(32, i)
vj = T.axis.spatial(1, j)
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(32, 32):
with T.block("C"):
vi = T.axis.spatial(32, i)
vj = T.axis.spatial(1, j)
T.where(j < 1)
C[vi, vj] = B[vi, vj] + 1.0

@T.prim_func
def static_bound_after_compute_at(A: T.Buffer[(32, 1), "float32"], C: T.Buffer[(32, 1), "float32"]) -> None:
B = T.alloc_buffer((32, 1), "float32")
for i in range(32):
for ax0, ax1 in T.grid(1, 1):
with T.block("B"):
vi = T.axis.spatial(32, i + ax0)
vj = T.axis.spatial(1, ax1)
B[vi, vj] = A[vi, vj] * 2.0
for j in range(32):
with T.block("C"):
vi = T.axis.spatial(32, i)
vj = T.axis.spatial(1, j)
T.where(j < 1)
C[vi, vj] = B[vi, vj] + 1.0
# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
# fmt: on

Expand Down Expand Up @@ -1166,6 +1196,15 @@ def test_compact_dataflow():
verify_trace_roundtrip(sch=sch, mod=not_all_compact_data_flow)


def test_compute_at_simplify_static_bound():
sch = tir.Schedule(static_bound, debug_mask="all")
block = sch.get_block("B")
loop, _ = sch.get_loops(sch.get_block("C"))
sch.compute_at(block, loop, preserve_unit_loops=True)
tvm.ir.assert_structural_equal(static_bound_after_compute_at, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=static_bound)


def test_fail_subtree_complete_block():
sch = tir.Schedule(fail_subtree_compact_dataflow, debug_mask="all")
block = sch.get_block("B_0")
Expand Down

0 comments on commit f3ea291

Please sign in to comment.