From f28fcd1239315059b66c3b61badc99ffc8181c92 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 11 Apr 2023 18:01:08 -0400 Subject: [PATCH] [TensorIR] Fix ComputeAt with perfect symbolic bound (#14592) This PR fixes a case where we have a perfect symbolic bound in compute_at that can be eliminated. Test cases are added. --- src/tir/schedule/primitive/compute_at.cc | 9 +++- .../unittest/test_tir_schedule_compute_at.py | 41 +++++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index 75ea308de8a3..b161bf954dbd 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -271,13 +271,18 @@ class ScopeReconstructor : private StmtMutator { } const arith::IntSet& pred_bound = iter_doms[i].bound; if (!pred_bound.IsNothing()) { + // NOTE: Apply strong analyzer proofs to get rid of symbolic bound if (pred_bound.HasLowerBound()) { PrimExpr lower_bound = iter_values[i] >= pred_bound.min(); - predicate = predicate && lower_bound; + if (!analyzer->CanProve(lower_bound, arith::ProofStrength::kSymbolicBound)) { + predicate = predicate && lower_bound; + } } if (pred_bound.HasUpperBound()) { PrimExpr upper_bound = iter_values[i] < pred_bound.max() + 1; - predicate = predicate && upper_bound; + if (!analyzer->CanProve(upper_bound, arith::ProofStrength::kSymbolicBound)) { + predicate = predicate && upper_bound; + } } } } diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index 364a43acda49..0623fb02f3d6 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -1282,6 +1282,47 @@ def test_compute_at_simplify_static_bound(use_block_name): verify_trace_roundtrip(sch=sch, mod=static_bound) +def test_compute_at_simplify_symbolic_predicate(): @tvm.script.ir_module class Before: 
@T.prim_func + def main(x: T.handle, y: T.handle, n: T.int64): + X = T.match_buffer(x, (T.int64(8), n * 32), "float32") + Y = T.match_buffer(y, (T.int64(8), n * 32), "float32") + for i, k in T.grid(T.int64(8), n * 32): + with T.block("Y"): + vi, vk = T.axis.remap("SS", [i, k]) + Y[vi, vk] = X[vi, vk] + + @tvm.script.ir_module + class After: + @T.prim_func + def main(x: T.handle, y: T.handle, n: T.int64): + X = T.match_buffer(x, (T.int64(8), n * T.int64(32))) + Y = T.match_buffer(y, (T.int64(8), n * T.int64(32))) + X_global = T.alloc_buffer((T.int64(8), n * T.int64(32))) + + for i, k_0 in T.grid(T.int64(8), n): + for ax0 in range(T.int64(32)): + with T.block("X_global"): + v0 = T.axis.spatial(T.int64(8), i) + v1 = T.axis.spatial(n * T.int64(32), k_0 * T.int64(32) + ax0) + X_global[v0, v1] = X[v0, v1] + for k_1 in range(T.int64(32)): + with T.block("Y"): + vi = T.axis.spatial(T.int64(8), i) + vk = T.axis.spatial(n * T.int64(32), k_0 * T.int64(32) + k_1) + Y[vi, vk] = X_global[vi, vk] + + sch = tir.Schedule(Before, debug_mask="all") + block = sch.get_block("Y") + i, k = sch.get_loops(sch.get_block("Y")) + ko, ki = sch.split(k, [None, 32]) + XX = sch.cache_read(block, 0, "global") + sch.compute_at(XX, ko) + tvm.ir.assert_structural_equal(sch.mod, After) + + def test_compute_at_non_perfect_channel_group(use_block_name): @T.prim_func def grouped_channel_bias(