Skip to content

Commit

Permalink
Fix compute inline not to overwrite annotated opaque accesses (#9509)
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongtest-intellif authored Nov 15, 2021
1 parent 3f9b72d commit 22ba652
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/tir/schedule/primitive/compute_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,11 @@ class BaseInliner : public StmtExprMutator {
// Step 2. Update `BlockNode::reads` and `BlockNode::writes`
Array<BufferRegion> reads = std::move(block->reads);
Array<BufferRegion> writes = std::move(block->writes);
if (!is_scope_root) {
auto f_access_inline_buffer = [this](const BufferRegion& access) {
return access->buffer.same_as(this->inlined_buffer_);
};
if (!is_scope_root && (std::any_of(reads.begin(), reads.end(), f_access_inline_buffer) ||
std::any_of(writes.begin(), writes.end(), f_access_inline_buffer))) {
Array<Array<BufferRegion>> inspected = GetBlockReadWriteRegion(block, buffer_var_map_);
reads = std::move(inspected[0]);
writes = std::move(inspected[1]);
Expand Down
65 changes: 65 additions & 0 deletions tests/python/unittest/test_tir_schedule_compute_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,63 @@ def elementwise_multi_loads_inlined(a: T.handle, c: T.handle) -> None:
C[vi, vj] = A[vi, vj] * 2.0 + A[vi, vj + 1] * 2.0 + A[vi, vj + 2] * 2.0


# Input workload: an opaque block with hand-annotated partial regions, followed by
# two elementwise stages ("BB" then "B") over the first 512 elements.
@T.prim_func
def access_opaque_ptr_then_elemwise(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [1024], dtype="float32")
    B = T.match_buffer(b, [1024], dtype="float32")
    A_cache = T.alloc_buffer([1024], dtype="float32")
    BB = T.alloc_buffer([1024], dtype="float32")
    with T.block("opaque"):
        # The regions below are annotated by hand and cover only elements [0, 512);
        # the body touches the buffers solely through tvm_access_ptr.
        T.reads([A[0:512]])
        T.writes([A_cache[0:512]])
        T.evaluate(
            T.tvm_access_ptr(
                T.type_annotation(dtype="float32"), A.data, 0, 512, "r", dtype="handle"
            )
        )
        T.evaluate(
            T.tvm_access_ptr(
                T.type_annotation(dtype="float32"), A_cache.data, 0, 512, "w", dtype="handle"
            )
        )
    # Intermediate stage: BB = A_cache * 2
    for i in T.serial(0, 512):
        with T.block("BB"):
            vi = T.axis.spatial(512, i)
            BB[vi] = A_cache[vi] * 2.0
    # Final stage: B = BB + 1
    for i in T.serial(0, 512):
        with T.block("B"):
            vi = T.axis.spatial(512, i)
            B[vi] = BB[vi] + 1.0


# Reference result: same workload with the "BB" stage folded into block "B".
@T.prim_func
def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [1024])
    B = T.match_buffer(b, [1024])
    A_cache = T.alloc_buffer([1024])
    with T.block("opaque"):
        # The hand-annotated partial regions [0, 512) appear here unchanged.
        T.reads(A[0:512])
        T.writes(A_cache[0:512])
        T.evaluate(
            T.tvm_access_ptr(
                T.type_annotation(dtype="float32"), A.data, 0, 512, "r", dtype="handle"
            )
        )
        T.evaluate(
            T.tvm_access_ptr(
                T.type_annotation(dtype="float32"), A_cache.data, 0, 512, "w", dtype="handle"
            )
        )
    # B now reads A_cache directly: B = A_cache * 2 + 1
    for i in range(512):
        with T.block("B"):
            vi = T.axis.remap("S", [i])
            T.reads([A_cache[vi]])
            T.writes([B[vi]])
            B[vi] = A_cache[vi] * 2.0 + 1.0


# pylint: enable=no-member,invalid-name,unused-variable


Expand Down Expand Up @@ -417,5 +474,13 @@ def test_compute_inline_multi_loads():
verify_trace_roundtrip(sch=sch, mod=elementwise_multi_loads)


def test_compute_inline_with_opaque_access():
    """Inline block "BB" and check the result against the expected function,
    in which the opaque block's annotated reads/writes are unchanged."""
    schedule = tir.Schedule(access_opaque_ptr_then_elemwise, debug_mask="all")
    schedule.compute_inline(schedule.get_block("BB"))
    tvm.ir.assert_structural_equal(access_opaque_ptr_then_elemwise_inline, schedule.mod["main"])


if __name__ == "__main__":
    # Run this file under pytest, forwarding any extra command-line arguments,
    # and use pytest's return value as the process exit status.
    pytest_args = [__file__, *sys.argv[1:]]
    sys.exit(pytest.main(pytest_args))

0 comments on commit 22ba652

Please sign in to comment.