Skip to content

Commit

Permalink
[Pass][Bugfix] Disable re-use of non-flat buffers in StorageRewrite. (a…
Browse files Browse the repository at this point in the history
…pache#10787)

* [Pass][Bugfix] Disable re-use of non-flat buffers in StorageRewrite.

As a follow-up to apache#9727, this change
restricts StorageRewrite to only modify flat memory buffers.  When
rewriting, the existing algorithm in StorageRewrite flattens N-d
allocations into 1-d allocations, preventing them from being exposed
to the codegen.

* Bugfix, flattening of Allocate/AllocateConst extents

Previously, these were ignored entirely.  This worked so long as all
allocations were 1-d, as `StorageRewrite` erroneously flattened merged
arrays into 1-d.
  • Loading branch information
Lunderberg authored and mehrdadh committed Apr 11, 2022
1 parent ff2c0ab commit c6a6233
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 19 deletions.
97 changes: 95 additions & 2 deletions src/tir/transforms/storage_flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1405,12 +1405,25 @@ class StorageFlattener : public StmtExprMutator {
// rather than a buffer_var.
Stmt VisitStmt_(const AllocateNode* op) final {
  // Record that this var was introduced by an Allocate node.
  buffer_var_defines_.insert(op->buffer_var.get());

  // Visit the children first, then rebuild the allocation with its
  // extents flattened.  Returning the visited node as-is would leave
  // the original extents untouched.
  auto stmt = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
  return Allocate(stmt->buffer_var, stmt->dtype, FlattenExtents(stmt), stmt->condition,
                  stmt->body, stmt->annotations, stmt->span);
}

Stmt VisitStmt_(const AllocateConstNode* op) final {
  // Record that this var was introduced by an AllocateConst node.
  buffer_var_defines_.insert(op->buffer_var.get());

  // Visit the children first, then rebuild the allocation with its
  // extents flattened.  Returning the visited node as-is would leave
  // the original extents untouched.
  auto stmt = Downcast<AllocateConst>(StmtExprMutator::VisitStmt_(op));

  // The constant's payload is given either directly as a data array
  // or as an index into the IRModule's storage; forward whichever is
  // present to the rebuilt node.
  ObjectRef data_or_idx;
  if (stmt->data) {
    data_or_idx = stmt->data.value();
  } else if (stmt->irmod_storage_idx) {
    data_or_idx = stmt->irmod_storage_idx.value();
  } else {
    LOG(FATAL) << "Neither data array nor data index specified for allocation of const "
               << op->buffer_var->name_hint;
  }
  return AllocateConst(stmt->buffer_var, stmt->dtype, FlattenExtents(stmt), data_or_idx,
                       stmt->body, stmt->span);
}

Stmt VisitStmt_(const LetStmtNode* op) final {
Expand Down Expand Up @@ -1598,6 +1611,82 @@ class StorageFlattener : public StmtExprMutator {
}

private:
// Computes the flattened extents of an Allocate or AllocateConst
// node.  These nodes currently hold a buffer_var (Var) rather than a
// buffer (Buffer) object; if that changes in the future, this helper
// can be replaced with a call to GetBufferEntry.
template <typename Node>
Array<PrimExpr> FlattenExtents(const Node& node) {
  arith::Analyzer analyzer;

  // True when `candidate` has as many dimensions as the allocation
  // has extents, and every dimension can be proven equal to the
  // corresponding extent.
  auto shape_matches_extents = [&](const Buffer& candidate) {
    const size_t ndim = candidate->shape.size();
    if (ndim != node->extents.size()) {
      return false;
    }
    for (size_t dim = 0; dim < ndim; dim++) {
      if (!analyzer.CanProveEqual(candidate->shape[dim], node->extents[dim])) {
        return false;
      }
    }
    return true;
  };

  // Element-wise comparison of two arrays of integer immediates.
  auto int_array_equal = [](const Array<IntImm>& a, const Array<IntImm>& b) {
    const size_t len = a.size();
    if (len != b.size()) {
      return false;
    }
    for (size_t pos = 0; pos < len; pos++) {
      if (a[pos]->value != b[pos]->value) {
        return false;
      }
    }
    return true;
  };

  // Left empty when no buffer is known to use this allocation, in
  // which case the allocation is treated as flat.
  Array<IntImm> axis_separators;
  auto found = buffer_var_map_.find(node->buffer_var.get());
  const bool has_users = (found != buffer_var_map_.end()) && !found->second.empty();
  if (has_users) {
    const auto& users = found->second;
    if (users.size() == 1) {
      // A single buffer uses this allocation, so its axis separators
      // are taken without checking the shape.
      axis_separators = users[0]->axis_separators;
    } else {
      // Several buffers share the allocation: pick one whose shape
      // matches the extents, and require that all matching buffers
      // agree on their axis separators.
      Buffer compatible_buffer;
      for (const auto& buffer : users) {
        if (!shape_matches_extents(buffer)) {
          continue;
        }
        ICHECK(!compatible_buffer.defined() ||
               int_array_equal(compatible_buffer->axis_separators, buffer->axis_separators))
            << "Cannot determine axis separators to use when flattening "
            << node->buffer_var->name_hint
            << ", multiple buffer objects found with conflicting axis separators";
        compatible_buffer = buffer;
      }
      ICHECK(compatible_buffer.defined())
          << "Cannot determine axis separators to use when flattening "
          << node->buffer_var->name_hint << ", no buffers found with matching shape";
      axis_separators = compatible_buffer->axis_separators;
    }
  }

  // GetFlattenedBuffer only needs the shape and axis separators to
  // produce the flattened shape; the dtype and name passed here are
  // placeholders.
  Buffer shape_only_buffer =
      decl_buffer(node->extents, DataType::Float(32), "buffer", "", axis_separators);
  return shape_only_buffer.GetFlattenedBuffer()->shape;
}

// The buffer entry in the flatten map
struct DimAlignInfo {
int align_factor{0};
Expand Down Expand Up @@ -1665,6 +1754,10 @@ class StorageFlattener : public StmtExprMutator {
// Set of vars that have occurred in an AllocateNode, but haven't
// yet occurred in a BufferLoad/BufferStore.
std::unordered_set<const VarNode*> buffer_var_defines_;
// Map from an allocation variable to the buffer(s) that it backs.
// Used to determine the axis_separators that should be
// used for flattening the extents of an AllocateNode.
std::unordered_map<const VarNode*, std::vector<Buffer>> buffer_var_map_;
// Buffer map
std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
// The extern buffer map, updated to include flattened buffers.
Expand Down
77 changes: 60 additions & 17 deletions src/tir/transforms/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
};
// The scope of each allocation
struct AllocEntry {
// The number of physical dimensions of the allocation.
size_t num_physical_dimensions{0};
// scope level
size_t level{0};
// allocation stmt
Expand All @@ -85,8 +87,16 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
void VisitStmt_(const AllocateNode* op) final {
size_t level = scope_.size();
const VarNode* buf = op->buffer_var.get();
alloc_info_[buf].alloc = op;
alloc_info_[buf].level = level;

AllocEntry entry;
entry.alloc = op;
entry.level = level;
// Since StorageRewrite occurs after StorageFlatten/FlattenBuffer,
// all allocations specify the extent of physical dimensions, and
// is 1 for flat memory spaces.
entry.num_physical_dimensions = op->extents.size();
alloc_info_[buf] = entry;

StmtExprVisitor::VisitStmt_(op);
}

Expand All @@ -104,6 +114,12 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size());
scope_[it->second.level].touched.push_back(buf);

ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions)
<< "Buffer " << op->buffer->name << " is allocated with "
<< it->second.num_physical_dimensions
<< " physical dimensions, but is accessed as having "
<< op->buffer->axis_separators.size() + 1 << " physical dimensions" << std::endl;
}
StmtEntry e = scope_.back();
scope_.pop_back();
Expand All @@ -125,6 +141,12 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
scope_[it->second.level].touched.push_back(buf);

ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions)
<< "Buffer " << op->buffer->name << " is allocated with "
<< it->second.num_physical_dimensions
<< " physical dimensions, but is accessed as having "
<< op->buffer->axis_separators.size() + 1 << " physical dimensions" << std::endl;
}
}

Expand Down Expand Up @@ -530,6 +552,10 @@ class StoragePlanRewriter : public StmtExprMutator {
uint64_t const_nbits{0};
// The storage scope.
StorageScope scope;
// The physical dimensionality of the allocations. Since
// StorageRewrite is applied after StorageFlatten/FlattenBuffer,
// this is the size of `AllocateNode::extents`.  If the pass is moved
// before flattening, this assumption no longer holds.
size_t ndim;
// Allocs that shares this entry.
std::vector<const AllocateNode*> allocs;
// The children of this entry, not including itself.
Expand Down Expand Up @@ -629,8 +655,8 @@ class StoragePlanRewriter : public StmtExprMutator {
// simply use the original allocation.
PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
make_const(DataType::Int(32), 1), e->allocs[0]->extents);
e->new_alloc =
Allocate(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, Evaluate(0));
e->new_alloc = Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
e->allocs[0]->condition, Evaluate(0));
if (IsSpecialTaggedMemory(e->scope)) {
MemoryInfo info = GetMemoryInfo(e->scope.to_string());
uint64_t total_elem = e->const_nbits / e->elem_type.bits();
Expand All @@ -641,8 +667,13 @@ class StoragePlanRewriter : public StmtExprMutator {
// Build a merged allocation
PrimExpr combo_size;
for (const AllocateNode* op : e->allocs) {
PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
make_const(DataType::Int(32), 1), op->extents);
ICHECK_EQ(op->extents.size(), 1)
<< "Buffer var " << op->buffer_var->name_hint
<< " was identified as a re-usable allocation, but has " << op->extents.size()
<< " physical dimensions. "
<< "Currently, only flat 1-d memory spaces should be identified as re-usable "
"allocations.";
PrimExpr sz = op->extents[0];
auto nbits = op->dtype.bits() * op->dtype.lanes();
if (const auto* imm = sz.as<IntImmNode>()) {
if (imm->value > std::numeric_limits<int>::max() / nbits) {
Expand Down Expand Up @@ -790,7 +821,8 @@ class StoragePlanRewriter : public StmtExprMutator {

for (const VarNode* var : it->second.gen) {
ICHECK(alloc_info.count(var));
const AllocateNode* alloc = alloc_info.at(var).alloc;
const AllocEntry& entry = alloc_info.at(var);
const AllocateNode* alloc = entry.alloc;
auto storage_scope = StorageScope::Create(GetPtrStorageScope(GetRef<Var>(var)));
StorageEntry* dst_entry = nullptr;
// inplace detection
Expand Down Expand Up @@ -818,7 +850,8 @@ class StoragePlanRewriter : public StmtExprMutator {
}
}
if (dst_entry == nullptr) {
dst_entry = FindAlloc(alloc, thread_scope_, storage_scope);
dst_entry =
FindAlloc(alloc, thread_scope_, storage_scope, entry.num_physical_dimensions);
}
dst_entry->allocs.emplace_back(alloc);
alloc_map_[var] = dst_entry;
Expand Down Expand Up @@ -871,24 +904,34 @@ class StoragePlanRewriter : public StmtExprMutator {
}

StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope,
const StorageScope& scope) {
const StorageScope& scope, size_t num_physical_dimensions) {
ICHECK(op != nullptr);
// skip plan for local variable,
// compiler can do a better job with register allocation.
const uint64_t match_range = 16;
uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes();
uint64_t const_nbits = static_cast<uint64_t>(op->ConstantAllocationSize() * op_elem_bits);

// If the size of the array isn't known at compile-time, it must
// have its own allocation with size determined at runtime.
bool is_known_size = (const_nbits != 0);

// Currently, only flat memory spaces can be re-used. Packing
// into N-d space (e.g. 2-d texture memory on GPUs) will require
// more in-depth algorithms.
bool is_flat_memory_space = (num_physical_dimensions == 1);

// disable reuse of small arrays, they will be lowered to registers in LLVM
// This rule only applies when using non-special memory
if (scope.tag.length() == 0) {
if (scope.rank >= StorageRank::kWarp || op->dtype.is_handle()) {
return NewAlloc(op, attach_scope, scope, const_nbits);
}
if (const_nbits > 0 && const_nbits <= 32) {
return NewAlloc(op, attach_scope, scope, const_nbits);
}
bool is_small_array =
(scope.tag.length() == 0) && (scope.rank >= StorageRank::kWarp || op->dtype.is_handle() ||
(is_known_size && const_nbits <= 32));

if (is_small_array || !is_flat_memory_space) {
return NewAlloc(op, attach_scope, scope, const_nbits);
}
if (const_nbits != 0) {

if (is_known_size) {
// constant allocation.
auto begin = const_free_map_.lower_bound(const_nbits / match_range);
auto mid = const_free_map_.lower_bound(const_nbits);
Expand Down

0 comments on commit c6a6233

Please sign in to comment.