diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 76033c4890a5..b922138057e9 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -104,6 +104,9 @@ bool Analyzer::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) { const auto* clhs = lhs.as(); const auto* crhs = rhs.as(); if (clhs && crhs) return clhs->value == crhs->value; + if (lhs->dtype.is_handle() || rhs->dtype.is_handle()) { + return lhs.same_as(rhs); + } return CanProve(lhs - rhs == 0); } diff --git a/src/arith/interval_set.h b/src/arith/interval_set.h index eb308dd385a4..98fe5bdc2bc6 100644 --- a/src/arith/interval_set.h +++ b/src/arith/interval_set.h @@ -59,7 +59,13 @@ class IntervalSetNode : public IntSetNode { /*! \return Whether the interval has lower bound. */ bool HasLowerBound() const { return !is_neg_inf(min_value) && !IsEmpty(); } /*! \return Whether the interval is a single point. */ - bool IsSinglePoint() const { return min_value.same_as(max_value); } + bool IsSinglePoint() const { + if (min_value.same_as(max_value)) { + return true; + } + Analyzer analyzer; + return analyzer.CanProveEqual(min_value, max_value); + } /*! \return whether interval represent nothing */ bool IsEmpty() const { // during computations, either extreme could occur. diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index b811afb23614..2a349f8fe61e 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -428,6 +428,12 @@ void UpdateBlockVarDomain(const arith::IntSet& provided, const arith::IntSet& re const arith::IntSet& required_bound, std::unordered_map* iter_doms, arith::Analyzer* analyzer) { + if (provided.IsSinglePoint() && is_const_int(provided.min())) { + ICHECK(required.IsSinglePoint() && analyzer->CanProveEqual(provided.min(), required.min())); + ICHECK(required_bound.IsSinglePoint() && + analyzer->CanProveEqual(provided.min(), required_bound.min())); + return; + } auto var_with_dom = SolveBlockVarDomain(provided, required, analyzer); auto var_with_bound = SolveBlockVarDomain(provided, required_bound, analyzer); const Var& var = var_with_dom.first; diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index e1cf399d49a1..25b69aa6de19 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -582,6 +582,40 @@ def tiled_after_reverse_compute_at(a: T.handle, c: T.handle) -> None: C[vi, vj] = B[vi, vj] + 1.0 +@T.prim_func +def tiled_trivial_binding(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [1, 128, 128], "float32") + B = T.alloc_buffer([1, 128, 128], "float32") + C = T.match_buffer(c, [1, 128, 128], "float32") + for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16): + with T.block("B"): + vi = T.axis.S(128, i_0 * 16 + i_1) + vj = T.axis.S(128, j_0 * 16 + j_1) + B[0, vi, vj] = A[0, vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[0, vi, vj] = B[0, vi, vj] + 1.0 + + +@T.prim_func +def tiled_trivial_binding_after_reverse_compute_at(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [1, 128, 128], "float32") + B = T.alloc_buffer([1, 128, 128], "float32") + C = T.match_buffer(c, [1, 128, 128], "float32") + for i_0, j_0, i_1 in T.grid(8, 8, 16): + for j_1 in T.serial(0, 16): + with T.block("B"): + vi = T.axis.S(128, i_0 * 16 + i_1) + vj = T.axis.S(128, j_0 * 16 + j_1) + B[0, vi, vj] = A[0, vi, vj] * 2.0 + for j_1 in T.serial(0, 16): + with T.block("C"): + vi = T.axis.S(128, i_0 * 16 + i_1) + vj = T.axis.S(128, j_0 * 16 + j_1) + C[0, vi, vj] = B[0, vi, vj] + 1.0 + + @T.prim_func def factorized(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [16, 16, 16], "float32") @@ -1149,6 +1183,15 @@ def test_reverse_compute_at_tiled(): verify_trace_roundtrip(sch=sch, mod=tiled) +def test_reverse_compute_at_tiled_trivial_binding(): + sch = tir.Schedule(tiled_trivial_binding, debug_mask="all") + block = sch.get_block("C") + _, _, loop, _ = sch.get_loops(sch.get_block("B")) + sch.reverse_compute_at(block, loop, preserve_unit_loops=False) + tvm.ir.assert_structural_equal(tiled_trivial_binding_after_reverse_compute_at, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=tiled_trivial_binding) + + def test_reverse_compute_at_blockized_2(): sch = tir.Schedule(blockized_2, debug_mask="all") block = sch.get_block("C")