Skip to content

Commit

Permalink
[TIR] Fix reverse_compute_at for trivial region with trivial block var (apache#11234)
Browse files Browse the repository at this point in the history

* [TIR] Fix reverse_compute_at for trivial region with trivial block var

* Prevent handle arithmetics
  • Loading branch information
vinx13 authored and Sergey Shtin committed May 17, 2022
1 parent f11ea79 commit eb09860
Showing 4 changed files with 59 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
@@ -104,6 +104,9 @@ bool Analyzer::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) {
// Fast path: when both sides are integer immediates, compare their values directly.
const auto* clhs = lhs.as<IntImmNode>();
const auto* crhs = rhs.as<IntImmNode>();
if (clhs && crhs) return clhs->value == crhs->value;
// Handle-typed expressions are never subtracted: if either side has a handle dtype,
// the answer is decided by same_as() instead of building `lhs - rhs`.
if (lhs->dtype.is_handle() || rhs->dtype.is_handle()) {
return lhs.same_as(rhs);
}
// General case: ask CanProve whether the difference of the two expressions is zero.
return CanProve(lhs - rhs == 0);
}

8 changes: 7 additions & 1 deletion src/arith/interval_set.h
Original file line number Diff line number Diff line change
@@ -59,7 +59,13 @@ class IntervalSetNode : public IntSetNode {
/*! \return True when min_value is not negative infinity and the interval is not empty. */
bool HasLowerBound() const {
  if (is_neg_inf(min_value)) {
    return false;
  }
  return !IsEmpty();
}
/*! \return Whether the interval is a single point. */
bool IsSinglePoint() const { return min_value.same_as(max_value); }
bool IsSinglePoint() const {
  // Only when the two bounds are not the very same expression do we need to
  // ask a freshly constructed analyzer whether they can be proven equal.
  if (!min_value.same_as(max_value)) {
    Analyzer analyzer;
    return analyzer.CanProveEqual(min_value, max_value);
  }
  // Identical bound expressions: trivially a single point.
  return true;
}
/*! \return whether interval represent nothing */
bool IsEmpty() const {
// during computations, either extreme could occur.
6 changes: 6 additions & 0 deletions src/tir/schedule/primitive/compute_at.cc
Original file line number Diff line number Diff line change
@@ -428,6 +428,12 @@ void UpdateBlockVarDomain(const arith::IntSet& provided, const arith::IntSet& re
const arith::IntSet& required_bound,
std::unordered_map<const VarNode*, BlockVarDomainInfo>* iter_doms,
arith::Analyzer* analyzer) {
if (provided.IsSinglePoint() && is_const_int(provided.min())) {
ICHECK(required.IsSinglePoint() && analyzer->CanProveEqual(provided.min(), required.min()));
ICHECK(required_bound.IsSinglePoint() &&
analyzer->CanProveEqual(provided.min(), required_bound.min()));
return;
}
auto var_with_dom = SolveBlockVarDomain(provided, required, analyzer);
auto var_with_bound = SolveBlockVarDomain(provided, required_bound, analyzer);
const Var& var = var_with_dom.first;
43 changes: 43 additions & 0 deletions tests/python/unittest/test_tir_schedule_compute_at.py
Original file line number Diff line number Diff line change
@@ -582,6 +582,40 @@ def tiled_after_reverse_compute_at(a: T.handle, c: T.handle) -> None:
C[vi, vj] = B[vi, vj] + 1.0


# Input fixture for the reverse_compute_at test with a trivial (extent-1) leading dimension.
# Block "B" computes B = A * 2.0 under a tiled 8 x 8 x 16 x 16 loop nest; block "C" computes
# C = B + 1.0 under a plain 128 x 128 grid. Every buffer access uses the constant 0 for the
# first dimension, so that dimension never gets a block var of its own.
@T.prim_func
def tiled_trivial_binding(a: T.handle, c: T.handle) -> None:
# All three buffers share the shape [1, 128, 128]; B is an intermediate allocated in the function.
A = T.match_buffer(a, [1, 128, 128], "float32")
B = T.alloc_buffer([1, 128, 128], "float32")
C = T.match_buffer(c, [1, 128, 128], "float32")
# Producer: the 128-wide spatial axes are rebuilt from the outer/inner tile loops.
for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16):
with T.block("B"):
vi = T.axis.S(128, i_0 * 16 + i_1)
vj = T.axis.S(128, j_0 * 16 + j_1)
B[0, vi, vj] = A[0, vi, vj] * 2.0
# Consumer: untiled loops bound one-to-one to the two spatial block vars.
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[0, vi, vj] = B[0, vi, vj] + 1.0


# Expected IR for test_reverse_compute_at_tiled_trivial_binding: the result of applying
# reverse_compute_at to block "C" of tiled_trivial_binding at the third loop (i_1) of block "B".
# Block "C" now sits under the shared i_0 / j_0 / i_1 loops with its own 16-iteration j_1 loop,
# and its block vars use the same tiled bindings as block "B".
@T.prim_func
def tiled_trivial_binding_after_reverse_compute_at(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [1, 128, 128], "float32")
B = T.alloc_buffer([1, 128, 128], "float32")
C = T.match_buffer(c, [1, 128, 128], "float32")
# Loops shared by producer and consumer.
for i_0, j_0, i_1 in T.grid(8, 8, 16):
# Producer keeps its innermost j_1 loop.
for j_1 in T.serial(0, 16):
with T.block("B"):
vi = T.axis.S(128, i_0 * 16 + i_1)
vj = T.axis.S(128, j_0 * 16 + j_1)
B[0, vi, vj] = A[0, vi, vj] * 2.0
# Consumer gets a separate j_1 loop of the same extent; the first buffer dimension
# is still indexed with the constant 0.
for j_1 in T.serial(0, 16):
with T.block("C"):
vi = T.axis.S(128, i_0 * 16 + i_1)
vj = T.axis.S(128, j_0 * 16 + j_1)
C[0, vi, vj] = B[0, vi, vj] + 1.0


@T.prim_func
def factorized(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, [16, 16, 16], "float32")
@@ -1149,6 +1183,15 @@ def test_reverse_compute_at_tiled():
verify_trace_roundtrip(sch=sch, mod=tiled)


def test_reverse_compute_at_tiled_trivial_binding():
    # Move consumer block "C" under the third loop of producer block "B" and check
    # the resulting module against the expected prim_func, then replay the trace.
    schedule = tir.Schedule(tiled_trivial_binding, debug_mask="all")
    consumer = schedule.get_block("C")
    producer = schedule.get_block("B")
    # Block "B" must have exactly four loops; only the third one is used.
    _, _, target_loop, _ = schedule.get_loops(producer)
    schedule.reverse_compute_at(consumer, target_loop, preserve_unit_loops=False)
    tvm.ir.assert_structural_equal(
        tiled_trivial_binding_after_reverse_compute_at, schedule.mod["main"]
    )
    verify_trace_roundtrip(sch=schedule, mod=tiled_trivial_binding)


def test_reverse_compute_at_blockized_2():
sch = tir.Schedule(blockized_2, debug_mask="all")
block = sch.get_block("C")

0 comments on commit eb09860

Please sign in to comment.