Skip to content

Commit

Permalink
[TIR] Fix opaque access in buffer locator pass and match_buffer in re…
Browse files Browse the repository at this point in the history
…gion detector (#8855)

* init

* fix

* Update src/tir/transforms/plan_update_buffer_allocation_location.cc

Co-authored-by: Ruihang Lai <[email protected]>

* Update src/tir/transforms/plan_update_buffer_allocation_location.cc

Co-authored-by: Ruihang Lai <[email protected]>

* address

Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
  • Loading branch information
3 people authored Aug 28, 2021
1 parent 1df6c27 commit 7214f52
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 20 deletions.
7 changes: 5 additions & 2 deletions src/tir/analysis/block_access_region_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,11 @@ void BlockReadWriteDetector::operator()(const Stmt& stmt) {
ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey();
for (const MatchBufferRegion& match_buffer : block->match_buffers) {
const Var& target_var = match_buffer->buffer->data;
match_buffers_[target_var.get()] = match_buffer;
buffer_var_map_.Set(target_var, match_buffer->buffer);
const Var& source_var = match_buffer->source->buffer->data;
if (buffer_var_map_.find(source_var) != buffer_var_map_.end()) {
match_buffers_[target_var.get()] = match_buffer;
buffer_var_map_.Set(target_var, match_buffer->buffer);
}
}
StmtExprVisitor::operator()(stmt);
}
Expand Down
39 changes: 27 additions & 12 deletions src/tir/transforms/plan_update_buffer_allocation_location.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ class BufferAllocationLocator : public StmtExprMutator {

Stmt VisitStmt_(const BlockNode* op) final {
ICHECK(!op->init.defined());
bool is_root = is_root_;
is_root_ = false;
Array<Buffer> alloc_buffers;
auto it = alloc_buffers_.find(op);
if (it != alloc_buffers_.end()) {
Expand All @@ -85,11 +83,23 @@ class BufferAllocationLocator : public StmtExprMutator {
buffer_data_to_buffer_.Set(buf->data, buf);
}
}
for (const MatchBufferRegion match_buffer : op->match_buffers) {
const Var& target_var = match_buffer->buffer->data;
const Var& source_var = match_buffer->source->buffer->data;
ICHECK(buffer_data_to_buffer_.count(source_var));
buffer_data_to_buffer_.Set(target_var, match_buffer->buffer);
}
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<BlockNode>();
ICHECK(op != nullptr);

// Ignore buffer allocated inside the block when getting access region.
// No longer consider buffers created by match_buffer inside the block when updating access
// region.
for (const MatchBufferRegion match_buffer : op->match_buffers) {
const Var& target_var = match_buffer->buffer->data;
buffer_data_to_buffer_.erase(target_var);
}
// No longer consider buffers allocated inside the block when updating access region.
if (it != alloc_buffers_.end()) {
for (const Buffer& buf : it->second) {
buffer_data_to_buffer_.erase(buf->data);
Expand All @@ -98,12 +108,9 @@ class BufferAllocationLocator : public StmtExprMutator {

ObjectPtr<BlockNode> n = CopyOnWrite(op);
n->alloc_buffers = std::move(alloc_buffers);
// The read/write regions of root block are always empty.
if (!is_root) {
// Recalculate block access region
CollectReadWrite(GetRef<Block>(op), &n->reads, &n->writes);
}

// Erase buffer allocated inside the block from access region.
n->reads = RemoveRedundantBufferRegion(n->reads);
n->writes = RemoveRedundantBufferRegion(n->writes);
return Stmt(n);
}

Expand All @@ -127,8 +134,18 @@ class BufferAllocationLocator : public StmtExprMutator {
return std::move(realize);
}

Array<BufferRegion> RemoveRedundantBufferRegion(const Array<BufferRegion>& region) const {
Array<BufferRegion> result;
for (const BufferRegion& buffer_region : region) {
if (buffer_data_to_buffer_.count(buffer_region->buffer->data)) {
result.push_back(buffer_region);
}
}
return result;
}

void CollectReadWrite(const Block& block, Array<BufferRegion>* reads,
Array<BufferRegion>* writes) {
Array<BufferRegion>* writes) const {
Array<Array<BufferRegion>> access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
*reads = access[0];
*writes = access[1];
Expand All @@ -142,8 +159,6 @@ class BufferAllocationLocator : public StmtExprMutator {
std::unordered_map<const StmtNode*, Array<Buffer>> alloc_buffers_;
/*! \brief The buffer already allocated during recursive visiting. */
Map<Var, Buffer> buffer_data_to_buffer_;
/*! \brief indicate the whether the block is root. */
bool is_root_{true};
};

PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {
Expand Down
21 changes: 15 additions & 6 deletions tests/python/unittest/test_tir_analysis_get_block_access_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,20 +114,29 @@ def test_match_buffer():
root_block = match_buffer_func.body.block
block = root_block.body.body.body.block
block_inner = block.body[0].body.body.block
alloc_buffers = func.body.block.alloc_buffers
alloc_buffers = match_buffer_func.body.block.alloc_buffers
buffer_var_map = {buf.data: buf for buf in alloc_buffers}

# Check inner block AAA
ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map)
tvm.ir.assert_structural_equal(block_inner.reads, ret[0])
tvm.ir.assert_structural_equal(block_inner.writes, ret[1])

# Check block
ret = tir.analysis.get_block_access_region(block, buffer_var_map)
tvm.ir.assert_structural_equal(block.writes, ret[1])
# B is opaque access
tvm.ir.assert_structural_equal(block.reads, ret[2])

# Check inner block AAA without updating buffer_var_map
ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map)
# Since AA is not in the buffer_var_map, region of AA will not be collected.
tvm.ir.assert_structural_equal([], ret[1])

# Check inner block AAA
for match_buffer in block.match_buffers:
target_buffer = match_buffer.buffer
buffer_var_map[target_buffer.data] = target_buffer

ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map)
tvm.ir.assert_structural_equal(block_inner.reads, ret[0])
tvm.ir.assert_structural_equal(block_inner.writes, ret[1])


if __name__ == "__main__":
test_block_access_region_detector()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,63 @@ def transformed_match_buffer_func() -> None:
C1[()] = 0


@tvm.script.tir
def opaque_access(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, [1024])
B = tir.match_buffer(b, [1024])
A_cache = tir.alloc_buffer([1024])
for i in tir.serial(0, 8):
with tir.block([8]) as [vi]:
with tir.block([8]) as [v]:
tir.bind(v, vi)
tir.reads([A[(v * 128) : ((v * 128) + 128)]])
tir.writes([A_cache[(v * 128) : ((v * 128) + 128)]])
tir.evaluate(
tir.call_extern(
"test",
A_cache.data,
(v * 128),
128,
A.data,
(v * 128),
128,
dtype="float32",
)
)
for j in tir.serial(0, 128):
with tir.block([1024]) as [v]:
tir.bind(v, ((vi * 128) + j))
tir.reads([A_cache[v]])
tir.writes([B[v]])
B[v] = A_cache[v]


@tvm.script.tir
def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, [1024])
B = tir.match_buffer(b, [1024])
for i in tir.serial(0, 8):
with tir.block([8]) as [vi]:
tir.reads(A[vi * 128 : vi * 128 + 128])
tir.writes(B[vi * 128 : vi * 128 + 128])
A_cache = tir.alloc_buffer([1024])
with tir.block([8]) as [v]:
tir.bind(v, vi)
tir.reads([A[v * 128 : v * 128 + 128]])
tir.writes([A_cache[v * 128 : v * 128 + 128]])
tir.evaluate(
tir.call_extern(
"test", A_cache.data, v * 128, 128, A.data, v * 128, 128, dtype="float32"
)
)
for j in tir.serial(0, 128):
with tir.block([1024]) as [v]:
tir.bind(v, ((vi * 128) + j))
tir.reads([A_cache[v]])
tir.writes([B[v]])
B[v] = A_cache[v]


def test_elementwise():
_check(element_func, transformed_element_func)

Expand All @@ -149,6 +206,10 @@ def test_match_buffer_allocation():
_check(match_buffer_func, transformed_match_buffer_func)


def test_opaque_access():
_check(opaque_access, transformed_opaque_access)


def test_lower_te():
x = te.placeholder((1,))
y = te.compute((1,), lambda i: x[i] + 2)
Expand All @@ -164,4 +225,5 @@ def test_lower_te():
test_elementwise()
test_locate_buffer_allocation()
test_match_buffer_allocation()
test_opaque_access()
test_lower_te()

0 comments on commit 7214f52

Please sign in to comment.