# Patch overview (free-text preamble; everything below the first "diff --git" is the
# original patch text, unchanged).
#
# NOTE(review): the patch text below has had its line breaks collapsed -- the hunk
# headers declare multi-line hunks (e.g. "+202,10", "+987,36", "+1196,15") but each
# hunk body sits on a single physical line. The original line structure would need
# to be restored before this can be applied as a patch.
#
# File 1: src/tir/schedule/primitive/compute_at.cc
#   Hunk header context: struct BlockVarDomainInfo.
#   Existing branch (context lines): if the analyzer can prove that dom.min()/dom.max()
#   equal intersect.min()/intersect.max(), `bound` is reset to arith::IntSet::Nothing().
#   Added branch: otherwise, if the analyzer can prove that bound.min()/bound.max()
#   equal intersect.min()/intersect.max(), `dom` is overwritten with `bound` and
#   `bound` is then reset to arith::IntSet::Nothing().
#   NOTE(review): the definition of `intersect` is outside this hunk -- presumably the
#   intersection of `dom` and `bound`; confirm against the full file.
#
# File 2: tests/python/unittest/test_tir_schedule_compute_at.py
#   Adds two TVMScript prim_funcs:
#     - static_bound: block "B" writes B[vi, vj] = A[vi, vj] * 2.0 over T.grid(32, 1);
#       block "C" runs over T.grid(32, 32) but binds vj with extent 1 and is guarded
#       by T.where(j < 1).
#     - static_bound_after_compute_at: the expected module, with block "B" nested
#       under the outer `i` loop inside a T.grid(1, 1) and block "C" left under its
#       own `j` loop of extent 32 with the same T.where(j < 1) guard.
#   Adds test_compute_at_simplify_static_bound, which builds a tir.Schedule from
#   static_bound (debug_mask="all"), takes the first loop of block "C", calls
#   sch.compute_at(block "B", loop, preserve_unit_loops=True), asserts structural
#   equality with static_bound_after_compute_at, and runs verify_trace_roundtrip.
#   The last hunk's trailing context ends inside test_fail_subtree_complete_block;
#   the remainder of that function is not part of this patch excerpt.
diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index 88bf7663e85d..b811afb23614 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -202,6 +202,10 @@ struct BlockVarDomainInfo { if (analyzer->CanProveEqual(dom.min(), intersect.min()) && analyzer->CanProveEqual(dom.max(), intersect.max())) { bound = arith::IntSet::Nothing(); + } else if (analyzer->CanProveEqual(bound.min(), intersect.min()) && + analyzer->CanProveEqual(bound.max(), intersect.max())) { + dom = bound; + bound = arith::IntSet::Nothing(); } } }; diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index 4d081b507403..e1cf399d49a1 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -987,6 +987,36 @@ def tiled_repeat_op_after_compute_at(x: T.Buffer[(4,), "float32"], T_repeat: T.B ax0 = T.axis.spatial(64, i0_0 * 8 + i0_1) T_repeat[ax0] = T_add[ax0 // 16] +@T.prim_func +def static_bound(A: T.Buffer[(32, 1), "float32"], C: T.Buffer[(32, 1), "float32"]) -> None: + B = T.alloc_buffer((32, 1), "float32") + for i, j in T.grid(32, 1): + with T.block("B"): + vi = T.axis.spatial(32, i) + vj = T.axis.spatial(1, j) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(32, 32): + with T.block("C"): + vi = T.axis.spatial(32, i) + vj = T.axis.spatial(1, j) + T.where(j < 1) + C[vi, vj] = B[vi, vj] + 1.0 + +@T.prim_func +def static_bound_after_compute_at(A: T.Buffer[(32, 1), "float32"], C: T.Buffer[(32, 1), "float32"]) -> None: + B = T.alloc_buffer((32, 1), "float32") + for i in range(32): + for ax0, ax1 in T.grid(1, 1): + with T.block("B"): + vi = T.axis.spatial(32, i + ax0) + vj = T.axis.spatial(1, ax1) + B[vi, vj] = A[vi, vj] * 2.0 + for j in range(32): + with T.block("C"): + vi = T.axis.spatial(32, i) + vj = T.axis.spatial(1, j) + T.where(j < 1) + C[vi, vj] = B[vi, vj] + 1.0 
# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks # fmt: on @@ -1166,6 +1196,15 @@ def test_compact_dataflow(): verify_trace_roundtrip(sch=sch, mod=not_all_compact_data_flow) +def test_compute_at_simplify_static_bound(): + sch = tir.Schedule(static_bound, debug_mask="all") + block = sch.get_block("B") + loop, _ = sch.get_loops(sch.get_block("C")) + sch.compute_at(block, loop, preserve_unit_loops=True) + tvm.ir.assert_structural_equal(static_bound_after_compute_at, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=static_bound) + + def test_fail_subtree_complete_block(): sch = tir.Schedule(fail_subtree_compact_dataflow, debug_mask="all") block = sch.get_block("B_0")