From add5eebe9c204bd193d1937bbfa11176b7a19410 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Wed, 30 Mar 2022 14:48:19 -0500
Subject: [PATCH] [Pass][Bugfix] Disable re-use of non-flat buffers in StorageRewrite. (#10787)

* [Pass][Bugfix] Disable re-use of non-flat buffers in StorageRewrite.

As a follow-up to https://github.com/apache/tvm/pull/9727, this restricts
StorageRewrite so that it only modifies flat memory buffers.  When rewriting,
the existing algorithm in StorageRewrite flattens N-d allocations into 1-d
allocations, which prevents them from being exposed to the codegen.

* Bugfix, flattening of Allocate/AllocateConst extents

Previously, these extents were ignored entirely.  This worked so long as all
allocations were 1-d, since `StorageRewrite` erroneously flattened merged
arrays into 1-d.
---
 src/tir/transforms/storage_flatten.cc | 97 ++++++++++++++++++++++++++-
 src/tir/transforms/storage_rewrite.cc | 77 ++++++++++++++++-----
 2 files changed, 155 insertions(+), 19 deletions(-)

diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index 2bfc8420b025e..0923517634373 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -1405,12 +1405,25 @@ class StorageFlattener : public StmtExprMutator {
   // rather than a buffer_var.
   Stmt VisitStmt_(const AllocateNode* op) final {
     buffer_var_defines_.insert(op->buffer_var.get());
-    return StmtExprMutator::VisitStmt_(op);
+    auto stmt = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
+    return Allocate(stmt->buffer_var, stmt->dtype, FlattenExtents(stmt), stmt->condition,
+                    stmt->body, stmt->annotations, stmt->span);
   }
 
   Stmt VisitStmt_(const AllocateConstNode* op) final {
     buffer_var_defines_.insert(op->buffer_var.get());
-    return StmtExprMutator::VisitStmt_(op);
+    auto stmt = Downcast<AllocateConst>(StmtExprMutator::VisitStmt_(op));
+    ObjectRef data_or_idx;
+    if (stmt->data) {
+      data_or_idx = stmt->data.value();
+    } else if (stmt->irmod_storage_idx) {
+      data_or_idx = stmt->irmod_storage_idx.value();
+    } else {
+      LOG(FATAL) << "Neither data array nor data index specified for allocation of const "
+                 << op->buffer_var->name_hint;
+    }
+    return AllocateConst(stmt->buffer_var, stmt->dtype, FlattenExtents(stmt), data_or_idx,
+                         stmt->body, stmt->span);
   }
 
   Stmt VisitStmt_(const LetStmtNode* op) final {
@@ -1598,6 +1611,82 @@ class StorageFlattener : public StmtExprMutator {
   }
 
  private:
+  // Helper function for visiting Allocate and AllocateConst.  If, in
+  // the future, these are updated to hold a buffer (Buffer) object
+  // rather than a buffer_var (Var), this function can be replaced
+  // with a call to GetBufferEntry.
+  template <typename Node>
+  Array<PrimExpr> FlattenExtents(const Node& node) {
+    arith::Analyzer analyzer;
+
+    // If an allocation has extents that match the buffer
+    auto is_compatible_buffer = [&](const Buffer& buffer) {
+      if (buffer->shape.size() != node->extents.size()) {
+        return false;
+      }
+      for (size_t i = 0; i < buffer->shape.size(); i++) {
+        if (!analyzer.CanProveEqual(buffer->shape[i], node->extents[i])) {
+          return false;
+        }
+      }
+
+      return true;
+    };
+
+    auto int_array_equal = [](const Array<IntImm>& a, const Array<IntImm>& b) {
+      if (a.size() != b.size()) {
+        return false;
+      }
+
+      for (size_t i = 0; i < a.size(); i++) {
+        if (a[i]->value != b[i]->value) {
+          return false;
+        }
+      }
+
+      return true;
+    };
+
+    Array<IntImm> axis_separators;
+    auto it = buffer_var_map_.find(node->buffer_var.get());
+    if (it != buffer_var_map_.end()) {
+      const auto& buffers = it->second;
+      if (buffers.size() == 0) {
+        // No buffers use this allocation, treat as flat and optimize
+        // out later.
+      } else if (buffers.size() == 1) {
+        // Only one buffer uses this allocation, so use its axis
+        // separators.
+        axis_separators = buffers[0]->axis_separators;
+      } else {
+        // Try to find a buffer using this allocation with a matching
+        // shape.
+        Buffer compatible_buffer;
+        for (const auto& buffer : buffers) {
+          if (is_compatible_buffer(buffer)) {
+            ICHECK(!compatible_buffer.defined() ||
+                   int_array_equal(compatible_buffer->axis_separators, buffer->axis_separators))
+                << "Cannot determine axis separators to use when flattening "
+                << node->buffer_var->name_hint
+                << ", multiple buffer objects found with conflicting axis separators";
+            compatible_buffer = buffer;
+          }
+        }
+        ICHECK(compatible_buffer.defined())
+            << "Cannot determine axis separators to use when flattening "
+            << node->buffer_var->name_hint << ", no buffers found with matching shape";
+        axis_separators = compatible_buffer->axis_separators;
+      }
+    }
+
+    // Use GetFlattenedBuffer to determine the flattened shape of the
+    // output.  We only need the shape and axis separators defined,
+    // everything else can be dummy values.
+    Buffer dummy_buffer =
+        decl_buffer(node->extents, DataType::Float(32), "buffer", "", axis_separators);
+    return dummy_buffer.GetFlattenedBuffer()->shape;
+  }
+
   // The buffer entry in the flatten map
   struct DimAlignInfo {
     int align_factor{0};
@@ -1665,6 +1754,10 @@ class StorageFlattener : public StmtExprMutator {
   // Set of vars that have occurred in an AllocateNode, but haven't
   // yet occurred in a BufferLoad/BufferStore.
   std::unordered_set<const VarNode*> buffer_var_defines_;
+  // Map from an allocation variable to the buffer(s) that it backs.
+  // Used to determine the axis_separators that should be used when
+  // flattening the extents of an AllocateNode.
+  std::unordered_map<const VarNode*, std::vector<Buffer>> buffer_var_map_;
   // Buffer map
   std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
   // The extern buffer map, updated to include flattened buffers.
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 0534f31c34235..d1a37e18ac693 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -76,6 +76,8 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
   };
   // The scope of each allocation
   struct AllocEntry {
+    // The number of physical dimensions of the allocation.
+    size_t num_physical_dimensions{0};
     // scope level
     size_t level{0};
     // allocation stmt
@@ -85,8 +87,16 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
   void VisitStmt_(const AllocateNode* op) final {
     size_t level = scope_.size();
     const VarNode* buf = op->buffer_var.get();
-    alloc_info_[buf].alloc = op;
-    alloc_info_[buf].level = level;
+
+    AllocEntry entry;
+    entry.alloc = op;
+    entry.level = level;
+    // Since StorageRewrite occurs after StorageFlatten/FlattenBuffer,
+    // all allocations specify the extent of each physical dimension,
+    // which is a single dimension for flat memory spaces.
+    entry.num_physical_dimensions = op->extents.size();
+    alloc_info_[buf] = entry;
+
     StmtExprVisitor::VisitStmt_(op);
   }
 
@@ -104,6 +114,12 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size());
       scope_[it->second.level].touched.push_back(buf);
+
+      ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions)
+          << "Buffer " << op->buffer->name << " is allocated with "
+          << it->second.num_physical_dimensions
+          << " physical dimensions, but is accessed as having "
+          << op->buffer->axis_separators.size() + 1 << " physical dimensions" << std::endl;
     }
     StmtEntry e = scope_.back();
     scope_.pop_back();
@@ -125,6 +141,12 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
       scope_[it->second.level].touched.push_back(buf);
+
+      ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions)
+          << "Buffer " << op->buffer->name << " is allocated with "
+          << it->second.num_physical_dimensions
+          << " physical dimensions, but is accessed as having "
+          << op->buffer->axis_separators.size() + 1 << " physical dimensions" << std::endl;
     }
   }
 
@@ -530,6 +552,10 @@ class StoragePlanRewriter : public StmtExprMutator {
     uint64_t const_nbits{0};
     // The storage scope.
     StorageScope scope;
+    // The physical dimensionality of the allocations.  Since StorageRewrite
+    // is applied after StorageFlatten/FlattenBuffer, this is the size of
+    // `AllocateNode::extents`, and must be updated if the pass is moved earlier.
+    size_t ndim;
     // Allocs that shares this entry.
     std::vector<const AllocateNode*> allocs;
     // The children of this entry, not including itself.
@@ -629,8 +655,8 @@ class StoragePlanRewriter : public StmtExprMutator {
         // simply use the original allocation.
         PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
                             make_const(DataType::Int(32), 1), e->allocs[0]->extents);
-        e->new_alloc =
-            Allocate(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, Evaluate(0));
+        e->new_alloc = Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
+                                e->allocs[0]->condition, Evaluate(0));
         if (IsSpecialTaggedMemory(e->scope)) {
           MemoryInfo info = GetMemoryInfo(e->scope.to_string());
           uint64_t total_elem = e->const_nbits / e->elem_type.bits();
@@ -641,8 +667,13 @@ class StoragePlanRewriter : public StmtExprMutator {
         // Build a merged allocation
         PrimExpr combo_size;
         for (const AllocateNode* op : e->allocs) {
-          PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
-                              make_const(DataType::Int(32), 1), op->extents);
+          ICHECK_EQ(op->extents.size(), 1)
+              << "Buffer var " << op->buffer_var->name_hint
+              << " was identified as a re-usable allocation, but has " << op->extents.size()
+              << " physical dimensions.  "
" + << "Currently, only flat 1-d memory spaces should be identified as re-usable " + "allocations."; + PrimExpr sz = op->extents[0]; auto nbits = op->dtype.bits() * op->dtype.lanes(); if (const auto* imm = sz.as()) { if (imm->value > std::numeric_limits::max() / nbits) { @@ -790,7 +821,8 @@ class StoragePlanRewriter : public StmtExprMutator { for (const VarNode* var : it->second.gen) { ICHECK(alloc_info.count(var)); - const AllocateNode* alloc = alloc_info.at(var).alloc; + const AllocEntry& entry = alloc_info.at(var); + const AllocateNode* alloc = entry.alloc; auto storage_scope = StorageScope::Create(GetPtrStorageScope(GetRef(var))); StorageEntry* dst_entry = nullptr; // inplace detection @@ -818,7 +850,8 @@ class StoragePlanRewriter : public StmtExprMutator { } } if (dst_entry == nullptr) { - dst_entry = FindAlloc(alloc, thread_scope_, storage_scope); + dst_entry = + FindAlloc(alloc, thread_scope_, storage_scope, entry.num_physical_dimensions); } dst_entry->allocs.emplace_back(alloc); alloc_map_[var] = dst_entry; @@ -871,24 +904,34 @@ class StoragePlanRewriter : public StmtExprMutator { } StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope, - const StorageScope& scope) { + const StorageScope& scope, size_t num_physical_dimensions) { ICHECK(op != nullptr); // skip plan for local variable, // compiler can do a better job with register allocation. const uint64_t match_range = 16; uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes(); uint64_t const_nbits = static_cast(op->ConstantAllocationSize() * op_elem_bits); + + // If the size of the array isn't known at compile-time, it must + // have its own allocation with size determined at runtime. + bool is_known_size = (const_nbits != 0); + + // Currently, only flat memory spaces can be re-used. Packing + // into N-d space (e.g. 2-d texture memory on GPUs) will require + // more in-depth algorithms. + bool is_flat_memory_space = (num_physical_dimensions == 1); + // disable reuse of small arrays, they will be lowered to registers in LLVM // This rules only apply if we are using non special memory - if (scope.tag.length() == 0) { - if (scope.rank >= StorageRank::kWarp || op->dtype.is_handle()) { - return NewAlloc(op, attach_scope, scope, const_nbits); - } - if (const_nbits > 0 && const_nbits <= 32) { - return NewAlloc(op, attach_scope, scope, const_nbits); - } + bool is_small_array = + (scope.tag.length() == 0) && (scope.rank >= StorageRank::kWarp || op->dtype.is_handle() || + (is_known_size && const_nbits <= 32)); + + if (is_small_array || !is_flat_memory_space) { + return NewAlloc(op, attach_scope, scope, const_nbits); } - if (const_nbits != 0) { + + if (is_known_size) { // constant allocation. auto begin = const_free_map_.lower_bound(const_nbits / match_range); auto mid = const_free_map_.lower_bound(const_nbits);