Skip to content

Commit

Permalink
[TIR] Ignore Allocate/AllocateConst in BufferAllocationLocator (apach…
Browse files Browse the repository at this point in the history
…e#10998)

* [TIR] Ignore Allocate/AllocateConst in BufferAllocationLocator

Prior to this commit, the BufferAllocationLocator mutator used in the
PlanAndUpdateBufferAllocationLocation pass would erroneously insert an
entry to `BlockNode::alloc_buffers` for buffers allocated using
`Allocate` or `AllocateConst` nodes.  This error was introduced in
apache#9727, which deprecated `Load` and
`Store` nodes, replacing them with `BufferLoad` and `BufferStore`
nodes.  As a result, BufferAllocationLocator identified these as
buffers whose allocations should be moved to inner loops, rather than
as unmanaged allocations that should be ignored.

This commit restores the earlier behavior by only operating on buffer
allocations in `BlockNode::alloc_buffers`, and explicitly ignoring any
buffers whose allocation is done with `Allocate` or `AllocateConst`.

* Only inject opaque block if managed buffers exist.

Previously, all buffers found were managed buffers, so this check
wasn't needed.
  • Loading branch information
Lunderberg authored and altanh committed Apr 28, 2022
1 parent b2045e9 commit f11d1c9
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
33 changes: 25 additions & 8 deletions src/tir/transforms/plan_update_buffer_allocation_location.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,21 @@ class BufferAllocationLocator : public StmtExprMutator {
for (const Buffer& buf : it->second) {
buffer_data_to_buffer_.Set(buf->data, buf);
}
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
ICHECK(op != nullptr);
auto node = Downcast<For>(StmtMutator::VisitStmt_(op));

Array<Buffer> new_block_alloc_bufs;
for (const Buffer& buf : it->second) {
buffer_data_to_buffer_.erase(buf->data);
if (!unmanaged_allocations_.count(buf->data.get())) {
buffer_data_to_buffer_.erase(buf->data);
new_block_alloc_bufs.push_back(buf);
}
}
Stmt body = InjectOpaqueBlock(op->body, it->second);
ObjectPtr<ForNode> n = CopyOnWrite(op);
n->body = std::move(body);
return Stmt(n);

if (new_block_alloc_bufs.size()) {
node.CopyOnWrite()->body = InjectOpaqueBlock(node->body, new_block_alloc_bufs);
}

return std::move(node);
}

Stmt VisitStmt_(const BlockNode* op) final {
Expand Down Expand Up @@ -114,6 +119,16 @@ class BufferAllocationLocator : public StmtExprMutator {
return Stmt(n);
}

Stmt VisitStmt_(const AllocateNode* op) final {
unmanaged_allocations_.insert(op->buffer_var.get());
return StmtExprMutator::VisitStmt_(op);
}

Stmt VisitStmt_(const AllocateConstNode* op) final {
unmanaged_allocations_.insert(op->buffer_var.get());
return StmtExprMutator::VisitStmt_(op);
}

Stmt VisitStmt_(const BufferRealizeNode* op) final {
ICHECK(false) << "Internal Error: BufferRealizeNode is not allowed in TensorIR.";
throw;
Expand Down Expand Up @@ -151,6 +166,8 @@ class BufferAllocationLocator : public StmtExprMutator {
std::unordered_map<const StmtNode*, Array<Buffer>> alloc_buffers_;
/*! \brief The buffer already allocated during recursive visiting. */
Map<Var, Buffer> buffer_data_to_buffer_;
/*! \brief Buffers that are allocated outside of the BlockNode, and should not be moved. */
std::unordered_set<const VarNode*> unmanaged_allocations_;
};

PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {
Expand Down
2 changes: 2 additions & 0 deletions tests/python/unittest/test_tir_transform_extract_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def _visit(stmt):
for n, f in mod.functions.items():
tvm.tir.stmt_functor.post_order_visit(f.body, _visit)

tvm.lower(mod)


if __name__ == "__main__":
test_const_extraction()

0 comments on commit f11d1c9

Please sign in to comment.