Skip to content

Commit

Permalink
[TensorIR] Fix ComputeAt with perfect symbolic bound (#14592)
Browse files Browse the repository at this point in the history
This PR fixes a case in compute-at where the generated predicate checks a
symbolic bound that is already perfectly satisfied; such a predicate can be proven and is now eliminated.

Test cases are added.
  • Loading branch information
tqchen authored Apr 11, 2023
1 parent 8554e7a commit f28fcd1
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/tir/schedule/primitive/compute_at.cc
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,18 @@ class ScopeReconstructor : private StmtMutator {
}
const arith::IntSet& pred_bound = iter_doms[i].bound;
if (!pred_bound.IsNothing()) {
// NOTE: Apply strong analyzer proofs to get rid of symbolic bound
if (pred_bound.HasLowerBound()) {
PrimExpr lower_bound = iter_values[i] >= pred_bound.min();
predicate = predicate && lower_bound;
if (!analyzer->CanProve(lower_bound, arith::ProofStrength::kSymbolicBound)) {
predicate = predicate && lower_bound;
}
}
if (pred_bound.HasUpperBound()) {
PrimExpr upper_bound = iter_values[i] < pred_bound.max() + 1;
predicate = predicate && upper_bound;
if (!analyzer->CanProve(upper_bound, arith::ProofStrength::kSymbolicBound)) {
predicate = predicate && upper_bound;
}
}
}
}
Expand Down
41 changes: 41 additions & 0 deletions tests/python/unittest/test_tir_schedule_compute_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,6 +1282,47 @@ def test_compute_at_simplify_static_bound(use_block_name):
verify_trace_roundtrip(sch=sch, mod=static_bound)


def test_compute_at_simplify_symbolic_predicate():
    """Check that compute-at emits no bound predicate for a perfect symbolic split.

    The loop over ``n * 32`` is split by a factor of 32 and the cache-read
    block of ``X`` is computed at the outer loop.  The expected module
    (``After``) contains no ``T.where`` predicate on the ``X_global`` block,
    so the structural-equality check fails if a predicate is generated.
    """

    @tvm.script.ir_module
    class Before:
        @T.prim_func
        def main(x: T.handle, y: T.handle, n: T.int64):
            X = T.match_buffer(x, (T.int64(8), n * 32), "float32")
            Y = T.match_buffer(y, (T.int64(8), n * 32), "float32")
            for i, k in T.grid(T.int64(8), n * 32):
                with T.block("Y"):
                    vi, vk = T.axis.remap("SS", [i, k])
                    Y[vi, vk] = X[vi, vk]

    @tvm.script.ir_module
    class After:
        @T.prim_func
        def main(x: T.handle, y: T.handle, n: T.int64):
            X = T.match_buffer(x, (T.int64(8), n * T.int64(32)))
            Y = T.match_buffer(y, (T.int64(8), n * T.int64(32)))
            X_global = T.alloc_buffer((T.int64(8), n * T.int64(32)))

            for i, k_0 in T.grid(T.int64(8), n):
                for ax0 in range(T.int64(32)):
                    with T.block("X_global"):
                        v0 = T.axis.spatial(T.int64(8), i)
                        v1 = T.axis.spatial(n * T.int64(32), k_0 * T.int64(32) + ax0)
                        X_global[v0, v1] = X[v0, v1]
                for k_1 in range(T.int64(32)):
                    with T.block("Y"):
                        vi = T.axis.spatial(T.int64(8), i)
                        vk = T.axis.spatial(n * T.int64(32), k_0 * T.int64(32) + k_1)
                        Y[vi, vk] = X_global[vi, vk]

    sch = tir.Schedule(Before, debug_mask="all")
    block = sch.get_block("Y")
    # Reuse the block handle instead of looking "Y" up a second time; only the
    # `k` loop is needed for the split.
    _, k = sch.get_loops(block)
    # Split the symbolic `n * 32` loop into an outer loop and an inner loop of extent 32.
    ko, _ = sch.split(k, [None, 32])
    cached_x = sch.cache_read(block, 0, "global")
    sch.compute_at(cached_x, ko)
    tvm.ir.assert_structural_equal(sch.mod, After)
    # Same trace round-trip check the neighbouring compute-at tests perform.
    verify_trace_roundtrip(sch=sch, mod=Before)


def test_compute_at_non_perfect_channel_group(use_block_name):
@T.prim_func
def grouped_channel_bias(
Expand Down

0 comments on commit f28fcd1

Please sign in to comment.