From f11d1c9a0669b430a961118687c58ef50691987d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 14 Apr 2022 15:18:09 -0500 Subject: [PATCH] [TIR] Ignore Allocate/AllocateConst in BufferAllocationLocator (#10998) * [TIR] Ignore Allocate/AllocateConst in BufferAllocationLocator Prior to this commit, the BufferAllocationLocator mutator used in the PlanAndUpdateBufferAllocationLocation pass would erroneously insert an entry to `BlockNode::alloc_buffers` for buffers allocated using `Allocate` or `AllocateConst` nodes. This error was introduced in https://github.com/apache/tvm/pull/9727, which deprecated `Load` and `Store` nodes, replacing them with `BufferLoad` and `BufferStore` nodes. As a result, BufferAllocationLocator identified these as buffers whose allocations should be moved to inner loops, rather than as unmanaged allocations that should be ignored. This commit restores the earlier behavior by only operating on buffer allocations in `BlockNode::alloc_buffers`, and explicitly ignoring any buffers whose allocation is done with `Allocate` or `AllocateConst`. * Only inject opaque block if managed buffers exist. Previously, all buffers found were managed buffers, so this check wasn't needed. --- .../plan_update_buffer_allocation_location.cc | 33 ++++++++++++++----- .../test_tir_transform_extract_constants.py | 2 ++ 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index 6b495b3bf4b5..81dfceb40d32 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -61,16 +61,21 @@ class BufferAllocationLocator : public StmtExprMutator { for (const Buffer& buf : it->second) { buffer_data_to_buffer_.Set(buf->data, buf); } - Stmt stmt = StmtMutator::VisitStmt_(op); - op = stmt.as(); - ICHECK(op != nullptr); + auto node = Downcast(StmtMutator::VisitStmt_(op)); + + Array new_block_alloc_bufs; for (const Buffer& buf : it->second) { - buffer_data_to_buffer_.erase(buf->data); + if (!unmanaged_allocations_.count(buf->data.get())) { + buffer_data_to_buffer_.erase(buf->data); + new_block_alloc_bufs.push_back(buf); + } } - Stmt body = InjectOpaqueBlock(op->body, it->second); - ObjectPtr n = CopyOnWrite(op); - n->body = std::move(body); - return Stmt(n); + + if (new_block_alloc_bufs.size()) { + node.CopyOnWrite()->body = InjectOpaqueBlock(node->body, new_block_alloc_bufs); + } + + return std::move(node); } Stmt VisitStmt_(const BlockNode* op) final { @@ -114,6 +119,16 @@ class BufferAllocationLocator : public StmtExprMutator { return Stmt(n); } + Stmt VisitStmt_(const AllocateNode* op) final { + unmanaged_allocations_.insert(op->buffer_var.get()); + return StmtExprMutator::VisitStmt_(op); + } + + Stmt VisitStmt_(const AllocateConstNode* op) final { + unmanaged_allocations_.insert(op->buffer_var.get()); + return StmtExprMutator::VisitStmt_(op); + } + Stmt VisitStmt_(const BufferRealizeNode* op) final { ICHECK(false) << "Internal Error: BufferRealizeNode is not allowed in TensorIR."; throw; @@ -151,6 +166,8 @@ class BufferAllocationLocator : public StmtExprMutator { std::unordered_map> alloc_buffers_; /*! \brief The buffer already allocated during recursive visiting. */ Map buffer_data_to_buffer_; + /*! \brief Buffers that are allocated outside of the BlockNode, and should not be moved. */ + std::unordered_set unmanaged_allocations_; }; PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) { diff --git a/tests/python/unittest/test_tir_transform_extract_constants.py b/tests/python/unittest/test_tir_transform_extract_constants.py index 9636a9bdde4c..cb49e7286fbb 100644 --- a/tests/python/unittest/test_tir_transform_extract_constants.py +++ b/tests/python/unittest/test_tir_transform_extract_constants.py @@ -59,6 +59,8 @@ def _visit(stmt): for n, f in mod.functions.items(): tvm.tir.stmt_functor.post_order_visit(f.body, _visit) + tvm.lower(mod) + if __name__ == "__main__": test_const_extraction()