From 21828b2c36c5e0ae8b6c7de33e7ec7386a845d97 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 9 Mar 2021 06:19:18 +0000 Subject: [PATCH 01/20] light refactor --- src/meta_schedule/utils.h | 19 +- src/tir/schedule/analysis.cc | 11 + src/tir/schedule/analysis.h | 6 + src/tir/schedule/transform.cc | 45 +++ src/tir/schedule/transform.h | 40 +++ src/tir/schedule/utils.h | 41 +++ src/tir/transforms/buffer_flatten.cc | 414 +++++++++++++-------------- 7 files changed, 348 insertions(+), 228 deletions(-) create mode 100644 src/tir/schedule/transform.cc create mode 100644 src/tir/schedule/transform.h diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 9dc9058646..07db40e1b7 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -337,23 +337,8 @@ inline std::ostream& StdCout(int verbose, int setting = 1) { /**************** String Manipulation ****************/ -/*! - * \brief Find all positions that the specific char occurs in the string - * \param str The string to be examined - * \param c The specific char - * \return A list of integers indicating the occurrence position - */ -inline std::vector FindCharPos(const String& str, char c) { - std::vector result; - const char* data = str.data(); - int n = str.length(); - for (int i = 0; i < n; ++i) { - if (data[i] == c) { - result.push_back(i); - } - } - return result; -} +using tir::FindCharPos; +using tir::StartsWith; /**************** Target Hardware Concurrency ****************/ diff --git a/src/tir/schedule/analysis.cc b/src/tir/schedule/analysis.cc index 77e45a5122..96305ca0fc 100644 --- a/src/tir/schedule/analysis.cc +++ b/src/tir/schedule/analysis.cc @@ -54,6 +54,17 @@ bool ContainsVar(const ObjectRef& stmt_or_expr, const std::unordered_set Vars(const ObjectRef& stmt_or_expr) { + std::unordered_set result; + auto f_visit = [&result](const ObjectRef& obj) -> void { + if (const auto* var = obj.as()) { + result.insert(var); + } + }; + PostOrderVisit(stmt_or_expr, f_visit); + return result; +} + bool ValidateBlockBinding(const BlockRealize& realize, const Map& loop_var_ranges) { arith::Analyzer analyzer; Array results = arith::DetectIterMap( diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index c799befb11..c0d00ab788 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -45,6 +45,12 @@ bool ContainsVar(const ObjectRef& stmt_or_expr, const Var& var); * \return A boolean indicating if any var in the list is found in stmt/expr */ bool ContainsVar(const ObjectRef& stmt_or_expr, const std::unordered_set& var); +/*! + * \brief Collect the variables that appear in the specific Stmt or Expr + * \param stmt_or_expr The Stmt or Expr + * \return All variables that appear + */ +std::unordered_set Vars(const ObjectRef& stmt_or_expr); /******** Verification ********/ /*! diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc new file mode 100644 index 0000000000..191d079f60 --- /dev/null +++ b/src/tir/schedule/transform.cc @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace tir { + +Stmt RealizeInitBlock(const Stmt& init, const Array& iter_vars) { + std::vector conditions; + for (const IterVar& var : iter_vars) { + if (var->iter_type == IterVarType::kCommReduce) { + conditions.push_back(equal(var->var, var->dom->min)); + } + } + int n = conditions.size(); + // Handle the case where there is no condition + if (n == 0) { + return init; + } + // Concate the conditions with logical and (&&) + PrimExpr cond = conditions[0]; + for (int i = 1; i < n; ++i) { + cond = logical_and(cond, conditions[i]); + } + return IfThenElse(cond, init); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h new file mode 100644 index 0000000000..22e7c1822b --- /dev/null +++ b/src/tir/schedule/transform.h @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_TRANSFORM_H_ +#define TVM_TIR_SCHEDULE_TRANSFORM_H_ + +#include + +#include + +namespace tvm { +namespace tir { + +/*! + * \brief Transform the init block into actual computation + * \param init The init block + * \param iter_vars The block variables + * \return The actual computation + */ +Stmt RealizeInitBlock(const Stmt& init, const Array& iter_vars); + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_TRANSFORM_H_ diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 57fac1707b..5ddf3941e9 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -35,6 +35,7 @@ #include #include "./analysis.h" +#include "./transform.h" namespace tvm { namespace tir { @@ -445,6 +446,46 @@ static DefaultReducer default_reducers[4] = { } // namespace default_reducer +/**************** String ****************/ + +/*! + * \brief Find all positions that the specific char occurs in the string + * \param str The string to be examined + * \param c The specific char + * \return A list of integers indicating the occurrence position + */ +inline std::vector FindCharPos(const String& str, char c) { + std::vector result; + const char* data = str.data(); + int n = str.length(); + for (int i = 0; i < n; ++i) { + if (data[i] == c) { + result.push_back(i); + } + } + return result; +} + +inline bool StartsWith(const String& str, const String& prefix) { + int n = prefix.size(); + if (static_cast(str.size()) < n) { + return false; + } + const char* data = str.data(); + return std::equal(data, data + n, prefix.data()); +} + +inline bool StartsWith(const String& str, const char* prefix) { + int n = strlen(prefix); + if (static_cast(str.size()) < n) { + return false; + } + const char* data = str.data(); + return std::equal(data, data + n, prefix); +} + +/**************** Loop extents ****************/ + inline int64_t GetLoopIntExtent(const ForNode* loop) { const auto* int_extent = loop->extent.as(); return int_extent ? int_extent->value : -1; diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index 37a5bfc76b..715901c7f5 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -32,31 +32,23 @@ #include #include +#include "../schedule/utils.h" + namespace tvm { namespace tir { -/*! - * \brief Transform block with init into actual computation - */ -class ReductionTransformer : public StmtExprMutator { +class ReductionTransformer : public StmtMutator { public: - ReductionTransformer() = default; - - Stmt VisitStmt_(const BlockNode* op) override { - Block res = Downcast(StmtMutator::VisitStmt_(op)); - if (op->init) { - PrimExpr condition = Bool(true); - for (const auto& var : res->iter_vars) { - if (var->iter_type == IterVarType::kCommReduce) { - condition = And(condition, EQ(var, var->dom->min)); - } - } - Stmt init = op->init.value(); - if (!is_one(condition)) init = IfThenElse(condition, init); - res.CopyOnWrite()->body = SeqStmt::Flatten(init, op->body); - res.CopyOnWrite()->init = NullOpt; + Stmt VisitStmt_(const BlockNode* block) override { + if (!block->init.defined()) { + return StmtMutator::VisitStmt_(block); } - return std::move(res); + Stmt init = RealizeInitBlock(block->init.value(), block->iter_vars); + Stmt body = VisitStmt(block->body); + ObjectPtr new_block = make_object(*block); + new_block->init = NullOpt; + new_block->body = SeqStmt::Flatten(init, body); + return Stmt(std::move(new_block)); } }; @@ -66,83 +58,81 @@ class ReductionTransformer : public StmtExprMutator { */ class LCADetector : public StmtExprVisitor { public: - explicit LCADetector(const Map& func_args) { - for (const auto& x : func_args) { - arg_buffers_.insert(x.second); - buffers_lca_[x.second] = NullValue(); + explicit LCADetector(const std::unordered_set& arg_buffers) + : arg_buffers_(arg_buffers) { + for (const BufferNode* buffer : arg_buffers) { + buffers_lca_.emplace(GetRef(buffer), ObjectRef(nullptr)); } } - // Update parent and depth information for each AST node - void VisitStmt_(const ForNode* op) final { Stmt n = GetRef(op); - ast_scopes_info_[n] = ScopeInfo{scope_, depth_}; - ++depth_; - std::swap(scope_, n); + ast_scopes_info_[n] = MakeScope(); + ancestor_loops_.push_back(op); StmtExprVisitor::VisitStmt_(op); - std::swap(scope_, n); - --depth_; + ancestor_loops_.pop_back(); } - // Update LCA when visiting BufferLoad and BufferStore - template - void VisitBuffer(T op) { - Buffer buffer = op->buffer; - ObjectRef n = GetRef(op); - ast_scopes_info_[n] = ScopeInfo{scope_, depth_}; - // No need to update LCA if the buffer is in the func args (function input/output buffer) - if (arg_buffers_.count(buffer)) return; + void VisitBuffer(const Buffer& buffer, const ObjectRef& n) { + ast_scopes_info_[n] = MakeScope(); + if (arg_buffers_.count(buffer.get())) { + return; + } if (buffers_lca_.count(buffer)) { - buffers_lca_[buffer] = LowestCommonAncestor(GetRef(op), buffers_lca_[buffer]); + buffers_lca_[buffer] = LowestCommonAncestor(n, buffers_lca_[buffer]); } else { - buffers_lca_[buffer] = GetRef(op); + buffers_lca_[buffer] = n; } } void VisitExpr_(const BufferLoadNode* op) final { - VisitBuffer(op); + VisitBuffer(op->buffer, GetRef(op)); StmtExprVisitor::VisitExpr_(op); } + void VisitStmt_(const BufferStoreNode* op) final { - VisitBuffer(op); + VisitBuffer(op->buffer, GetRef(op)); StmtExprVisitor::VisitStmt_(op); } /*! \brief The map from Buffer to its LCA Stmt/Expr */ std::unordered_map buffers_lca_; - /*! \brief The Buffer in function args */ - std::unordered_set arg_buffers_; private: /*! \brief The AST node information for querying LCA */ struct ScopeInfo { // The parent loop node - Stmt parent_scope; + const StmtNode* parent_scope; // The scope depth in the AST - size_t depth; + int depth; }; + ScopeInfo MakeScope() { + int n = ancestor_loops_.size(); + return (n == 0) ? ScopeInfo{nullptr, 0} : ScopeInfo{ancestor_loops_.back(), n}; + } + + /*! \brief The Buffer in function args */ + const std::unordered_set& arg_buffers_; /*! \brief The current scope initializing with Null */ - Stmt scope_{NullValue()}; - /*! \brief The current DFS depth */ - size_t depth_{0}; + std::vector ancestor_loops_; /*! \brief The parent and depth info of each Loop/BufferLoad/BufferStore Node */ std::unordered_map ast_scopes_info_; ObjectRef LowestCommonAncestor(ObjectRef lhs, ObjectRef rhs) { - if (!lhs.defined() || !rhs.defined()) return NullValue(); + ICHECK(lhs.defined()); + ICHECK(rhs.defined()); ICHECK(ast_scopes_info_.count(lhs)); ICHECK(ast_scopes_info_.count(rhs)); while (ast_scopes_info_[lhs].depth > ast_scopes_info_[rhs].depth) { - lhs = ast_scopes_info_[lhs].parent_scope; + lhs = GetRef(ast_scopes_info_[lhs].parent_scope); } while (ast_scopes_info_[lhs].depth < ast_scopes_info_[rhs].depth) { - rhs = ast_scopes_info_[rhs].parent_scope; + rhs = GetRef(ast_scopes_info_[rhs].parent_scope); } while (!lhs.same_as(rhs)) { - lhs = ast_scopes_info_[lhs].parent_scope; - rhs = ast_scopes_info_[rhs].parent_scope; + lhs = GetRef(ast_scopes_info_[lhs].parent_scope); + rhs = GetRef(ast_scopes_info_[rhs].parent_scope); } return lhs; } @@ -151,54 +141,52 @@ class LCADetector : public StmtExprVisitor { /*! * \brief Gather the used region of each buffers. */ -class RegionGatherer : public StmtExprVisitor { +class RegionGatherer : public StmtVisitor { public: RegionGatherer( const std::unordered_map& buffers_lca, const Map& func_args) : buffers_lca_(buffers_lca) { for (const auto& arg : func_args) { + const Buffer& buffer = arg.second; std::vector region; - for (const auto& size : arg.second->shape) { - region.push_back(arith::IntSet::FromRange(Range::FromMinExtent(0, size))); + for (const PrimExpr& size : buffer->shape) { + region.push_back(IntSetFromMinExtent(0, size)); } - buffers_region_[arg.second] = region; + buffers_region_[buffer] = region; } } void VisitStmt_(const ForNode* op) final { - auto loop = GetRef(op); + For loop = GetRef(op); loop_stack_.push_back(loop); - if (op->annotations.empty() && is_one(op->extent)) { + if (!op->thread_binding.defined() && op->annotations.empty() && is_one(op->extent)) { unit_loops_[op->loop_var.get()] = op->min; } - StmtExprVisitor::VisitStmt_(op); + StmtVisitor::VisitStmt_(op); loop_stack_.pop_back(); } - void VisitStmt_(const BlockRealizeNode* op) final { - const auto* block_op = op->block.as(); - for (size_t i = 0; i < block_op->iter_vars.size(); ++i) { - const auto& iter = block_op->iter_vars[i]; - const auto& v = op->binding_values[i]; - block_var_[iter->var.get()] = Substitute(Substitute(v, block_var_), unit_loops_); + void VisitStmt_(const BlockRealizeNode* realize) final { + const auto* block = realize->block.as(); + CHECK(!block->init.defined()); + for (size_t i = 0; i < block->iter_vars.size(); ++i) { + const IterVar& iter = block->iter_vars[i]; + const PrimExpr& v = realize->binding_values[i]; + block_var_[iter->var.get()] = ReplaceBlockVar(v); } - StmtExprVisitor::VisitStmt_(op); - } - - void VisitStmt_(const BlockNode* op) final { - for (const auto& buffer_region : op->reads) { - VisitBufferRegion(buffer_region); + for (const BufferRegion& buffer_region : block->reads) { + UnionWith(&buffers_region_.at(buffer_region->buffer), GatherRegion(buffer_region)); } - for (const auto& buffer_region : op->writes) { - VisitBufferRegion(buffer_region); + for (const BufferRegion& buffer_region : block->writes) { + UnionWith(&buffers_region_.at(buffer_region->buffer), GatherRegion(buffer_region)); } - for (const auto& alloc_buf : op->alloc_buffers) { - std::vector empty_region(alloc_buf->shape.size(), arith::IntSet::Nothing()); + for (const Buffer& alloc_buf : block->alloc_buffers) { // Initialize the buffer region with empty region. - buffers_region_[alloc_buf] = empty_region; + buffers_region_[alloc_buf] = + std::vector(alloc_buf->shape.size(), arith::IntSet::Nothing()); } - StmtExprVisitor::VisitStmt_(op); + VisitStmt(block->body); } /*! \brief The used region of each Buffer */ @@ -215,14 +203,12 @@ class RegionGatherer : public StmtExprVisitor { /*! \brief The loops from the current node up to the root */ std::vector loop_stack_; - void VisitBufferRegion(const BufferRegion& buffer_region) { - auto it = buffers_region_.find(buffer_region->buffer); - ICHECK(it != buffers_region_.end()); - const auto& region = GatherRegion(buffer_region); - auto& buffer_new_region = it->second; - ICHECK_EQ(buffer_new_region.size(), region.size()); + static void UnionWith(std::vector* buffer_new_region, + const std::vector& region) { + ICHECK_EQ(buffer_new_region->size(), region.size()); for (size_t i = 0; i < region.size(); ++i) { - buffer_new_region[i] = arith::Union({buffer_new_region[i], region[i]}); + arith::IntSet& int_set = buffer_new_region->at(i); + int_set = arith::Union({int_set, region[i]}); } } @@ -231,29 +217,35 @@ class RegionGatherer : public StmtExprVisitor { */ std::vector GatherRegion(const BufferRegion& buffer_region) { std::unordered_map dom_map; - auto it = buffers_lca_.find(buffer_region->buffer); - ICHECK(it != buffers_lca_.end()); - const auto& lca = it->second; + const ObjectRef& lca = buffers_lca_.at(buffer_region->buffer); // Every loop will be relaxed if the lca is the root bool need_relax = !lca.defined(); - for (size_t i = 0; i < loop_stack_.size(); ++i) { - const For& loop = loop_stack_[i]; + for (const For& loop : loop_stack_) { const VarNode* var = loop->loop_var.get(); if (need_relax || (buffer_region->buffer->scope == "shared" && IsThreadBinded(loop))) { - dom_map[var] = arith::IntSet::FromRange(Range::FromMinExtent(loop->min, loop->extent)); + dom_map[var] = IntSetFromMinExtent(loop->min, loop->extent); + } + if (loop.same_as(lca)) { + need_relax = true; } - if (loop.same_as(lca)) need_relax = true; } std::vector region; - for (const auto& range : buffer_region->region) { - Range r = - Range::FromMinExtent(Substitute(Substitute(range->min, block_var_), unit_loops_), - Substitute(Substitute(range->extent, block_var_), unit_loops_)); - region.push_back(arith::EvalSet(r, dom_map)); + for (const Range& range : buffer_region->region) { + PrimExpr min = ReplaceBlockVar(range->min); + PrimExpr extent = ReplaceBlockVar(range->extent); + region.push_back(arith::EvalSet(Range::FromMinExtent(min, extent), dom_map)); } return region; } + PrimExpr ReplaceBlockVar(const PrimExpr& expr) const { + return Substitute(Substitute(expr, block_var_), unit_loops_); + } + + static arith::IntSet IntSetFromMinExtent(const PrimExpr& min, const PrimExpr& extent) { + return arith::IntSet::FromRange(Range::FromMinExtent(min, extent)); + } + static bool IsThreadBinded(const For& loop) { if (loop->kind != ForKind::kThreadBinding || !loop->thread_binding.defined()) return false; std::string thread_tag = loop->thread_binding.value()->thread_tag; @@ -272,18 +264,13 @@ class BufferFlattener : public StmtExprMutator { const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& buffers_region, const std::unordered_map& buffers_lca, - const std::unordered_set& arg_buffers) + const std::unordered_set& arg_buffers) : buffers_region_(buffers_region), block_var_(block_var), unit_loops_(unit_loops), buffers_lca_(buffers_lca), arg_buffers_(arg_buffers) {} - Stmt VisitStmt(const Stmt& stmt) override { - Stmt body = StmtMutator::VisitStmt(stmt); - return body; - } - Stmt VisitStmt_(const SeqStmtNode* op) final { Array seq; for (const Stmt& stmt : op->seq) { @@ -295,7 +282,7 @@ class BufferFlattener : public StmtExprMutator { for (const Buffer& buffer : double_buffer) { ObjectRef lca = buffers_lca_.at(buffer); if (lca.defined() && lca.same_as(parent_scope_)) { - body = AttrStmt(buffer->data, tir::attr::double_buffer_scope, 1, body); + body = AttrStmt(buffer->data, attr::double_buffer_scope, 1, body); } else { double_buffer_.insert(buffer); } @@ -307,74 +294,77 @@ class BufferFlattener : public StmtExprMutator { return SeqStmt(seq); } - Stmt VisitStmt_(const BlockRealizeNode* op) final { + Stmt VisitStmt_(const BlockRealizeNode* realize) final { // Handle allocations - const auto* block_op = op->block.as(); - Stmt old_stmt = GetRef(block_op); - ICHECK(block_op != nullptr); - for (size_t i = block_op->alloc_buffers.size(); i > 0; --i) { - const auto& buffer = block_op->alloc_buffers[i - 1]; - const std::string name = std::string(buffer->name); - if (name.substr(0, 18) == "normal_reduce_temp" || name.substr(0, 11) == "reduce_temp") { + const auto* block = realize->block.get(); + Block old_block = realize->block; + int n_alloc_buffer = block->alloc_buffers.size(); + int n_iter_var = block->iter_vars.size(); + // Step 1. Figure out `pending_allocate_` + for (int i = n_alloc_buffer - 1; i >= 0; --i) { + // Why the order + const Buffer& buffer = block->alloc_buffers[i]; + if (StartsWith(buffer->name, "normal_reduce_temp") || + StartsWith(buffer->name, "reduce_temp")) { continue; } if (buffers_lca_.at(buffer).defined()) { - pending_allocate_[buffer] = block_op->alloc_buffers[i - 1]; + pending_allocate_[buffer] = buffer; } } - for (size_t i = 0; i < block_op->iter_vars.size(); ++i) { - const IterVar& block_var = block_op->iter_vars[i]; - const PrimExpr& binding_value = op->binding_values[i]; - ICHECK(block_var.as()); - ICHECK(binding_value.as()); - - if (block_var->iter_type == kCommReduce) { - PreOrderVisit(binding_value, [this](const ObjectRef& node) { - if (const auto* var = node.as()) { - this->reduction_relative_.insert(GetRef(var)); - return false; - } - return true; - }); + for (int i = 0; i < n_iter_var; ++i) { + const IterVar& block_var = block->iter_vars[i]; + const PrimExpr& binding_value = realize->binding_values[i]; + if (block_var->iter_type != kCommReduce) { + continue; + } + std::unordered_set vars = Vars(binding_value); + for (const VarNode* var : vars) { + this->reduction_relative_.insert(GetRef(var)); } } - // visit body - Stmt parent_scope = op->block; + // Step 2. Visit the body + Stmt parent_scope = realize->block; std::swap(parent_scope, parent_scope_); - Stmt stmt = StmtExprMutator::VisitStmt_(op); + BlockRealize new_stmt = Downcast(StmtExprMutator::VisitStmt_(realize)); std::swap(parent_scope, parent_scope_); - op = stmt.as(); - ICHECK(op != nullptr); - block_op = op->block.as(); - ICHECK(block_op != nullptr); - Stmt body = block_op->body; - // Handle block predicate - if (!is_one(op->predicate)) { - body = IfThenElse(op->predicate, body); + // Reset `realize` and `block` + realize = new_stmt.get(); + block = realize->block.get(); + // Step 3. Transform the `predicate` to if-then-else + Stmt body = block->body; + if (!is_one(realize->predicate)) { + body = IfThenElse(realize->predicate, body); } - - for (const auto& anno : block_op->annotations) { - if (anno.first == tir::attr::double_buffer_scope && is_one(Downcast(anno.second))) { - ICHECK_EQ(block_op->writes.size(), 1); - double_buffer_.insert(block_op->writes[0]->buffer); + // Step 4. Pick out blocks that writes with double buffering + for (const auto& ann : block->annotations) { + const String& ann_key = ann.first; + const ObjectRef& ann_value = ann.second; + if (ann_key == attr::double_buffer_scope) { + if (is_one(Downcast(ann_value))) { + ICHECK_EQ(block->writes.size(), 1); + double_buffer_.insert(block->writes[0]->buffer); + } } } - - for (size_t i = block_op->alloc_buffers.size(); i > 0; --i) { - const auto& alloc_buf = block_op->alloc_buffers[i - 1]; - const std::string name = std::string(alloc_buf->name); - if (name.substr(0, 18) == "normal_reduce_temp" || name.substr(0, 11) == "reduce_temp") { + // Step 5. Add allocation and storage scope + for (int i = n_alloc_buffer - 1; i >= 0; --i) { + const Buffer& alloc_buf = block->alloc_buffers[i]; + if (StartsWith(alloc_buf->name, "normal_reduce_temp") || + StartsWith(alloc_buf->name, "reduce_temp")) { continue; } - if (!buffers_lca_.at(alloc_buf).defined() || buffers_lca_.at(alloc_buf).same_as(old_stmt)) { + if (!buffers_lca_.at(alloc_buf).defined() || buffers_lca_.at(alloc_buf).same_as(old_block)) { PrimExpr extents = 1; - for (const auto& extent : buffers_region_.at(alloc_buf)) { + for (const arith::IntSet& extent : buffers_region_.at(alloc_buf)) { extents *= extent.max() - extent.min() + 1; } body = Allocate(alloc_buf->data, alloc_buf->dtype, {extents}, const_true(), body); - // Change empty scope into global - std::string scope = alloc_buf->scope.empty() ? "global" : alloc_buf->scope; + String scope = alloc_buf->scope; + if (scope.empty()) { + scope = "global"; + } body = AttrStmt(alloc_buf->data, attr::storage_scope, StringImm(scope), body); } } @@ -382,93 +372,89 @@ class BufferFlattener : public StmtExprMutator { return body; } - PrimExpr VisitExpr_(const VarNode* op) final { - // Replace the block var with its value - auto it = block_var_.find(op); - if (it != block_var_.end()) { - return Substitute(it->second, unit_loops_); - } else { - return Substitute(GetRef(op), unit_loops_); - } - } - Stmt VisitStmt_(const ForNode* op) final { Stmt old_stmt = GetRef(op); std::swap(old_stmt, parent_scope_); - Stmt stmt = StmtExprMutator::VisitStmt_(op); + For stmt = Downcast(StmtExprMutator::VisitStmt_(op)); std::swap(old_stmt, parent_scope_); + op = stmt.get(); - op = stmt.as(); - ICHECK(op != nullptr); - - ForKind kind = op->kind; - if (op->kind == ForKind::kThreadBinding) kind = ForKind::kSerial; - + std::vector removed_buffers; + // Add buffer allocation Stmt body = op->body; for (auto it = pending_allocate_.begin(); it != pending_allocate_.end();) { - if (old_stmt.same_as(buffers_lca_.at(it->first))) { + const Buffer& alloc_buf = it->first; + if (old_stmt.same_as(buffers_lca_.at(alloc_buf))) { PrimExpr extents = 1; - const auto& alloc_buf = it->second; - for (const auto& extent : buffers_region_.at(alloc_buf)) { + for (const arith::IntSet& extent : buffers_region_.at(alloc_buf)) { extents *= extent.max() - extent.min() + 1; } body = Allocate(alloc_buf->data, alloc_buf->dtype, {extents}, const_true(), body); // Change empty scope into global - std::string scope = alloc_buf->scope.empty() ? "global" : alloc_buf->scope; + String scope = alloc_buf->scope.empty() ? "global" : alloc_buf->scope; body = AttrStmt(alloc_buf->data, attr::storage_scope, StringImm(scope), body); - pending_allocate_.erase(it++); + removed_buffers.push_back(alloc_buf); + ++it; } else { - it++; + ++it; } } + for (const Buffer& buffer : removed_buffers) { + pending_allocate_.erase(buffer); + } - Stmt for_stmt; if (op->kind == ForKind::kThreadBinding) { ICHECK(op->thread_binding.defined()); String thread_tag = op->thread_binding.value()->thread_tag; if (!reduction_relative_.count(op->loop_var)) { - for_stmt = AttrStmt(IterVar(Range(op->min, op->extent), op->loop_var, - IterVarType::kThreadIndex, thread_tag), - thread_tag == "vthread" ? attr::virtual_thread : attr::thread_extent, - op->extent, body); - } else { - for_stmt = body; + IterVar iter_var(/*dom=*/Range(op->min, op->extent), + /*var=*/op->loop_var, + /*iter_type=*/IterVarType::kThreadIndex, + /*thread_tag=*/thread_tag); + String attr_key = thread_tag == "vthread" ? attr::virtual_thread : attr::thread_extent; + body = AttrStmt(iter_var, attr_key, op->extent, body); } } else if (is_one(op->extent) && op->annotations.empty()) { return body; } else { - for_stmt = For(op->loop_var, op->min, op->extent, op->kind, body); + body = For(op->loop_var, op->min, op->extent, op->kind, body); } - for (const auto& annotation : op->annotations) { - if (attr::IsPragmaKey(annotation.first)) { - for_stmt = AttrStmt(op->loop_var, annotation.first, Downcast(annotation.second), - for_stmt); + const String& ann_key = annotation.first; + const ObjectRef& ann_value = annotation.second; + if (attr::IsPragmaKey(ann_key)) { + body = AttrStmt(op->loop_var, ann_key, Downcast(ann_value), body); } } - - return for_stmt; + return body; } - Stmt VisitStmt_(const AttrStmtNode* op) final { return StmtMutator::VisitStmt_(op); } - Stmt VisitStmt_(const BufferStoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - ICHECK(op != nullptr); - auto begins = ComputeRelativeIndices(op->buffer, op->indices); + BufferStore stmt = Downcast(StmtExprMutator::VisitStmt_(op)); + op = stmt.get(); + std::vector begins = ComputeRelativeIndices(op->buffer, op->indices); Buffer new_buffer = ReshapeBuffer(op->buffer, this->buffers_region_.at(op->buffer)); return new_buffer.vstore(begins, op->value); } PrimExpr VisitExpr_(const BufferLoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - auto begins = ComputeRelativeIndices(op->buffer, op->indices); + BufferLoad expr = Downcast(StmtExprMutator::VisitExpr_(op)); + op = expr.get(); + std::vector begins = ComputeRelativeIndices(op->buffer, op->indices); Buffer new_buffer = ReshapeBuffer(op->buffer, this->buffers_region_.at(op->buffer)); return new_buffer.vload(begins, op->dtype); } + PrimExpr VisitExpr_(const VarNode* op) final { + // Replace the block var with its value + auto it = block_var_.find(op); + if (it != block_var_.end()) { + return Substitute(it->second, unit_loops_); + } else { + return Substitute(GetRef(op), unit_loops_); + } + } + PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::get_elem_offset())) { ICHECK_EQ(op->args.size(), 1); @@ -487,7 +473,7 @@ class BufferFlattener : public StmtExprMutator { const std::unordered_map& block_var_; const std::unordered_map& unit_loops_; const std::unordered_map& buffers_lca_; - const std::unordered_set& arg_buffers_; + const std::unordered_set& arg_buffers_; std::unordered_map pending_allocate_; std::unordered_set reduction_relative_; @@ -498,14 +484,16 @@ class BufferFlattener : public StmtExprMutator { * \brief Create a buffer with alternative shape */ Buffer ReshapeBuffer(const Buffer& buffer, const std::vector& region) { - if (arg_buffers_.count(buffer)) return buffer; - auto n = runtime::make_object(*(buffer.operator->())); + if (arg_buffers_.count(buffer.get())) { + return buffer; + } Array shape; - for (const auto& i : region) { + for (const arith::IntSet& i : region) { shape.push_back(i.max() - i.min() + 1); } + ObjectPtr n = make_object(*buffer.get()); n->shape = std::move(shape); - return Buffer(n); + return Buffer(std::move(n)); } /*! @@ -514,12 +502,10 @@ class BufferFlattener : public StmtExprMutator { */ std::vector ComputeRelativeIndices(const Buffer& buffer, const Array& indices) { - auto it = buffers_region_.find(buffer); - ICHECK(it != buffers_region_.end()); - const auto& region = it->second; + const std::vector& region = buffers_region_.at(buffer); std::vector new_indices; for (size_t i = 0; i < region.size(); ++i) { - if (arg_buffers_.count(buffer)) { + if (arg_buffers_.count(buffer.get())) { new_indices.push_back(indices[i]); } else { new_indices.push_back(indices[i] - region[i].min()); @@ -539,8 +525,14 @@ PrimFunc BufferFlatten(PrimFunc f) { ReductionTransformer reduction_transformer; fptr->body = reduction_transformer(fptr->body); + std::unordered_set arg_buffers; + for (const auto& kv : fptr->buffer_map) { + const Buffer& buffer = kv.second; + arg_buffers.insert(buffer.get()); + } + // Find the LCA of each Buffer access - LCADetector lca_detector(fptr->buffer_map); + LCADetector lca_detector(arg_buffers); lca_detector(fptr->body); // Recalculate the buffer region @@ -550,7 +542,7 @@ PrimFunc BufferFlatten(PrimFunc f) { // Transform BufferLoad/BufferStore into Load/Store BufferFlattener flattener(region_gatherer.block_var_, region_gatherer.unit_loops_, region_gatherer.buffers_region_, lca_detector.buffers_lca_, - lca_detector.arg_buffers_); + arg_buffers); fptr->body = flattener(fptr->body); return f; From e3c60ba1bbfc57b8c3cb062b0f13e2819a600d31 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 11 Mar 2021 07:12:08 +0000 Subject: [PATCH 02/20] ... --- src/tir/transforms/buffer_flatten.cc | 257 +++++++++++++++------------ 1 file changed, 139 insertions(+), 118 deletions(-) diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index 715901c7f5..99e7e4dd35 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -37,6 +37,44 @@ namespace tvm { namespace tir { +using NDIntSet = std::vector; + +void UnionWith(NDIntSet* lhs, const NDIntSet& rhs) { + ICHECK_EQ(lhs->size(), rhs.size()); + int ndim = rhs.size(); + for (int i = 0; i < ndim; ++i) { + arith::IntSet& int_set = lhs->at(i); + int_set = arith::Union({int_set, rhs.at(i)}); + } +} + +arith::IntSet IntSetFromMinExtent(const PrimExpr& min, const PrimExpr& extent) { + return arith::IntSet::FromRange(Range::FromMinExtent(min, extent)); +} + +NDIntSet NDIntSetFromShape(const Array& shape) { + NDIntSet result; + for (const PrimExpr& extent : shape) { + result.push_back(IntSetFromMinExtent(Integer(0), extent)); + } + return result; +} + +bool IsThreadBinded(const ForNode* loop) { + if (loop->kind != ForKind::kThreadBinding) { + return false; + } + ICHECK(loop->thread_binding.defined()); + std::string thread_tag = loop->thread_binding.value()->thread_tag; + if (StartsWith(thread_tag, "threadIdx")) { + return true; + } + if (StartsWith(thread_tag, "vthread")) { + return true; + } + return false; +} + class ReductionTransformer : public StmtMutator { public: Stmt VisitStmt_(const BlockNode* block) override { @@ -58,84 +96,86 @@ class ReductionTransformer : public StmtMutator { */ class LCADetector : public StmtExprVisitor { public: - explicit LCADetector(const std::unordered_set& arg_buffers) - : arg_buffers_(arg_buffers) { - for (const BufferNode* buffer : arg_buffers) { - buffers_lca_.emplace(GetRef(buffer), ObjectRef(nullptr)); + static Map> Detect(const PrimFunc& func) { + LCADetector detector; + // Buffers, who appear as arguments, do not have allocation sites + for (const auto& kv : func->buffer_map) { + const Buffer& buffer = kv.second; + detector.buffers_lca_.emplace(buffer.get(), nullptr); } + detector(func->body); + // Prepare the return + Map> buffer_lca; + for (const auto& kv : detector.buffers_lca_) { + buffer_lca.Set(GetRef(kv.first), GetRef>(kv.second)); + } + return buffer_lca; } + private: void VisitStmt_(const ForNode* op) final { - Stmt n = GetRef(op); - ast_scopes_info_[n] = MakeScope(); + int n = ancestor_loops_.size(); + for_info_.emplace(op, ForInfo{ancestor_loops_.back(), n}); ancestor_loops_.push_back(op); StmtExprVisitor::VisitStmt_(op); ancestor_loops_.pop_back(); } - void VisitBuffer(const Buffer& buffer, const ObjectRef& n) { - ast_scopes_info_[n] = MakeScope(); - if (arg_buffers_.count(buffer.get())) { - return; - } - if (buffers_lca_.count(buffer)) { - buffers_lca_[buffer] = LowestCommonAncestor(n, buffers_lca_[buffer]); - } else { - buffers_lca_[buffer] = n; - } - } - void VisitExpr_(const BufferLoadNode* op) final { - VisitBuffer(op->buffer, GetRef(op)); + CalcBufferLCA(op->buffer.get()); StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const BufferStoreNode* op) final { - VisitBuffer(op->buffer, GetRef(op)); + CalcBufferLCA(op->buffer.get()); StmtExprVisitor::VisitStmt_(op); } - /*! \brief The map from Buffer to its LCA Stmt/Expr */ - std::unordered_map buffers_lca_; + void CalcBufferLCA(const BufferNode* buffer) { + const ForNode*& lca = buffers_lca_[buffer]; + lca = LowestCommonAncestor(lca, ancestor_loops_.back()); + } + + const ForNode* LowestCommonAncestor(const ForNode* lhs, const ForNode* rhs) const { + while (lhs != nullptr && rhs != nullptr && lhs != rhs) { + auto it_l = for_info_.find(lhs); + auto it_r = for_info_.find(rhs); + ICHECK(it_l != for_info_.end()); + ICHECK(it_r != for_info_.end()); + const ForInfo& l = it_l->second; + const ForInfo& r = it_r->second; + if (l.depth == r.depth) { + lhs = l.parent_loop; + rhs = r.parent_loop; + } else if (l.depth < r.depth) { + rhs = r.parent_loop; + } else { + lhs = l.parent_loop; + } + } + if (lhs == nullptr) { + return rhs; + } + if (rhs == nullptr) { + return lhs; + } + return lhs; + } - private: /*! \brief The AST node information for querying LCA */ - struct ScopeInfo { + struct ForInfo { // The parent loop node - const StmtNode* parent_scope; + const ForNode* parent_loop; // The scope depth in the AST int depth; }; - ScopeInfo MakeScope() { - int n = ancestor_loops_.size(); - return (n == 0) ? ScopeInfo{nullptr, 0} : ScopeInfo{ancestor_loops_.back(), n}; - } - - /*! \brief The Buffer in function args */ - const std::unordered_set& arg_buffers_; /*! \brief The current scope initializing with Null */ - std::vector ancestor_loops_; + std::vector ancestor_loops_ = {nullptr}; /*! \brief The parent and depth info of each Loop/BufferLoad/BufferStore Node */ - std::unordered_map ast_scopes_info_; - - ObjectRef LowestCommonAncestor(ObjectRef lhs, ObjectRef rhs) { - ICHECK(lhs.defined()); - ICHECK(rhs.defined()); - ICHECK(ast_scopes_info_.count(lhs)); - ICHECK(ast_scopes_info_.count(rhs)); - while (ast_scopes_info_[lhs].depth > ast_scopes_info_[rhs].depth) { - lhs = GetRef(ast_scopes_info_[lhs].parent_scope); - } - while (ast_scopes_info_[lhs].depth < ast_scopes_info_[rhs].depth) { - rhs = GetRef(ast_scopes_info_[rhs].parent_scope); - } - while (!lhs.same_as(rhs)) { - lhs = GetRef(ast_scopes_info_[lhs].parent_scope); - rhs = GetRef(ast_scopes_info_[rhs].parent_scope); - } - return lhs; - } + std::unordered_map for_info_ = {}; + /*! \brief The map from Buffer to its LCA Stmt/Expr */ + std::unordered_map buffers_lca_ = {}; }; /*! @@ -143,93 +183,81 @@ class LCADetector : public StmtExprVisitor { */ class RegionGatherer : public StmtVisitor { public: - RegionGatherer( - const std::unordered_map& buffers_lca, - const Map& func_args) + RegionGatherer(const Map>& buffers_lca, const Map& func_args) : buffers_lca_(buffers_lca) { for (const auto& arg : func_args) { const Buffer& buffer = arg.second; - std::vector region; - for (const PrimExpr& size : buffer->shape) { - region.push_back(IntSetFromMinExtent(0, size)); - } - buffers_region_[buffer] = region; + buffers_region_[buffer] = NDIntSetFromShape(buffer->shape); } } void VisitStmt_(const ForNode* op) final { - For loop = GetRef(op); - loop_stack_.push_back(loop); + ancestor_loops_.push_back(op); if (!op->thread_binding.defined() && op->annotations.empty() && is_one(op->extent)) { unit_loops_[op->loop_var.get()] = op->min; } StmtVisitor::VisitStmt_(op); - loop_stack_.pop_back(); + ancestor_loops_.pop_back(); } void VisitStmt_(const BlockRealizeNode* realize) final { const auto* block = realize->block.as(); CHECK(!block->init.defined()); - for (size_t i = 0; i < block->iter_vars.size(); ++i) { + // Update the mapping from block vars to loop vars so that we can substitute them + CHECK_EQ(block->iter_vars.size(), realize->binding_values.size()); + int n_block_vars = block->iter_vars.size(); + for (int i = 0; i < n_block_vars; ++i) { const IterVar& iter = block->iter_vars[i]; const PrimExpr& v = realize->binding_values[i]; block_var_[iter->var.get()] = ReplaceBlockVar(v); } - for (const BufferRegion& buffer_region : block->reads) { - UnionWith(&buffers_region_.at(buffer_region->buffer), GatherRegion(buffer_region)); + for (const BufferRegion& read_region : block->reads) { + NDIntSet& alloc_region = buffers_region_.at(read_region->buffer); + UnionWith(&alloc_region, GatherRegion(read_region)); } - for (const BufferRegion& buffer_region : block->writes) { - UnionWith(&buffers_region_.at(buffer_region->buffer), GatherRegion(buffer_region)); + for (const BufferRegion& write_region : block->writes) { + NDIntSet& alloc_region = buffers_region_.at(write_region->buffer); + UnionWith(&alloc_region, GatherRegion(write_region)); } for (const Buffer& alloc_buf : block->alloc_buffers) { // Initialize the buffer region with empty region. - buffers_region_[alloc_buf] = - std::vector(alloc_buf->shape.size(), arith::IntSet::Nothing()); + // TODO + buffers_region_[alloc_buf] = NDIntSet(alloc_buf->shape.size(), arith::IntSet::Nothing()); } VisitStmt(block->body); } /*! \brief The used region of each Buffer */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> - buffers_region_; + std::unordered_map buffers_region_; /*! \brief The map from block vars to the expr value */ std::unordered_map block_var_; /*! \brief The map from unit loop vars to the expr value */ std::unordered_map unit_loops_; private: - const std::unordered_map& buffers_lca_; - - /*! \brief The loops from the current node up to the root */ - std::vector loop_stack_; - - static void UnionWith(std::vector* buffer_new_region, - const std::vector& region) { - ICHECK_EQ(buffer_new_region->size(), region.size()); - for (size_t i = 0; i < region.size(); ++i) { - arith::IntSet& int_set = buffer_new_region->at(i); - int_set = arith::Union({int_set, region[i]}); - } + PrimExpr ReplaceBlockVar(const PrimExpr& expr) const { + return Substitute(Substitute(expr, block_var_), unit_loops_); } /*! * \brief Gather used buffer region */ - std::vector GatherRegion(const BufferRegion& buffer_region) { + NDIntSet GatherRegion(const BufferRegion& buffer_region) const { std::unordered_map dom_map; - const ObjectRef& lca = buffers_lca_.at(buffer_region->buffer); + const Optional& lca = buffers_lca_.at(buffer_region->buffer); // Every loop will be relaxed if the lca is the root bool need_relax = !lca.defined(); - for (const For& loop : loop_stack_) { - const VarNode* var = loop->loop_var.get(); + for (const ForNode* loop : ancestor_loops_) { + const VarNode* loop_var = loop->loop_var.get(); + // TODO if (need_relax || (buffer_region->buffer->scope == "shared" && IsThreadBinded(loop))) { - dom_map[var] = IntSetFromMinExtent(loop->min, loop->extent); + dom_map[loop_var] = IntSetFromMinExtent(loop->min, loop->extent); } - if (loop.same_as(lca)) { + if (loop == lca.get()) { need_relax = true; } } - std::vector region; + NDIntSet region; for (const Range& range : buffer_region->region) { PrimExpr min = ReplaceBlockVar(range->min); PrimExpr extent = ReplaceBlockVar(range->extent); @@ -238,19 +266,10 @@ class RegionGatherer : public StmtVisitor { return region; } - PrimExpr ReplaceBlockVar(const PrimExpr& expr) const { - return Substitute(Substitute(expr, block_var_), unit_loops_); - } - - static arith::IntSet IntSetFromMinExtent(const PrimExpr& min, const PrimExpr& extent) { - return arith::IntSet::FromRange(Range::FromMinExtent(min, extent)); - } - - static bool IsThreadBinded(const For& loop) { - if (loop->kind != ForKind::kThreadBinding || !loop->thread_binding.defined()) return false; - std::string thread_tag = loop->thread_binding.value()->thread_tag; - return (thread_tag.substr(0, 9) == "threadIdx" || thread_tag.substr(0, 7) == "vthread"); - } + /*! \brief The map from Buffer to its LCA Stmt/Expr */ + const Map>& buffers_lca_; + /*! \brief The loops from the current node up to the root */ + std::vector ancestor_loops_; }; /*! @@ -261,9 +280,8 @@ class BufferFlattener : public StmtExprMutator { BufferFlattener( const std::unordered_map& block_var, const std::unordered_map& unit_loops, - const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& - buffers_region, - const std::unordered_map& buffers_lca, + const std::unordered_map& buffers_region, + const Map>& buffers_lca, const std::unordered_set& arg_buffers) : buffers_region_(buffers_region), block_var_(block_var), @@ -468,11 +486,10 @@ class BufferFlattener : public StmtExprMutator { } private: - const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& - buffers_region_; + const std::unordered_map& buffers_region_; const std::unordered_map& block_var_; const std::unordered_map& unit_loops_; - const std::unordered_map& buffers_lca_; + const Map>& buffers_lca_; const std::unordered_set& arg_buffers_; std::unordered_map pending_allocate_; @@ -483,7 +500,7 @@ class BufferFlattener : public StmtExprMutator { /*! * \brief Create a buffer with alternative shape */ - Buffer ReshapeBuffer(const Buffer& buffer, const std::vector& region) { + Buffer ReshapeBuffer(const Buffer& buffer, const NDIntSet& region) { if (arg_buffers_.count(buffer.get())) { return buffer; } @@ -502,7 +519,7 @@ class BufferFlattener : public StmtExprMutator { */ std::vector ComputeRelativeIndices(const Buffer& buffer, const Array& indices) { - const std::vector& region = buffers_region_.at(buffer); + const NDIntSet& region = buffers_region_.at(buffer); std::vector new_indices; for (size_t i = 0; i < region.size(); ++i) { if (arg_buffers_.count(buffer.get())) { @@ -516,7 +533,7 @@ class BufferFlattener : public StmtExprMutator { }; PrimFunc BufferFlatten(PrimFunc f) { - auto fptr = f.CopyOnWrite(); + tvm::tir::PrimFuncNode* fptr = f.CopyOnWrite(); // Check memory and execution hierarchy VerifyExecScope(f); @@ -532,17 +549,21 @@ PrimFunc BufferFlatten(PrimFunc f) { } // Find the LCA of each Buffer access - LCADetector lca_detector(arg_buffers); - lca_detector(fptr->body); + // LCADetector lca_detector(arg_buffers); + // lca_detector(fptr->body); + + Map> buffer_lca = LCADetector::Detect(f); + // for (const auto& kv : lca_detector.buffers_lca_) { + // buffer_lca.Set(GetRef(kv.first), GetRef(kv.second)); + // } // Recalculate the buffer region - RegionGatherer region_gatherer(lca_detector.buffers_lca_, fptr->buffer_map); + RegionGatherer region_gatherer(buffer_lca, fptr->buffer_map); region_gatherer(fptr->body); // Transform BufferLoad/BufferStore into Load/Store BufferFlattener flattener(region_gatherer.block_var_, region_gatherer.unit_loops_, - region_gatherer.buffers_region_, lca_detector.buffers_lca_, - arg_buffers); + region_gatherer.buffers_region_, buffer_lca, arg_buffers); fptr->body = flattener(fptr->body); return f; From f0702101ed3775652e1a9ee707f0a7ef5bf2c546 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 11 Mar 2021 09:13:08 +0000 Subject: [PATCH 03/20] ... --- src/tir/transforms/buffer_flatten.cc | 177 +++++++++++++++------------ 1 file changed, 97 insertions(+), 80 deletions(-) diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index 99e7e4dd35..62706d7a08 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -39,6 +39,10 @@ namespace tir { using NDIntSet = std::vector; +arith::IntSet IntSetFromMinExtent(const PrimExpr& min, const PrimExpr& extent) { + return arith::IntSet::FromRange(Range::FromMinExtent(min, extent)); +} + void UnionWith(NDIntSet* lhs, const NDIntSet& rhs) { ICHECK_EQ(lhs->size(), rhs.size()); int ndim = rhs.size(); @@ -48,8 +52,12 @@ void UnionWith(NDIntSet* lhs, const NDIntSet& rhs) { } } -arith::IntSet IntSetFromMinExtent(const PrimExpr& min, const PrimExpr& extent) { - return arith::IntSet::FromRange(Range::FromMinExtent(min, extent)); +PrimExpr NDIntSetArea(const NDIntSet& nd_int_set) { + PrimExpr area = 1; + for (const arith::IntSet& int_set : nd_int_set) { + area = area * (int_set.max() - int_set.min() + 1); + } + return area; } NDIntSet NDIntSetFromShape(const Array& shape) { @@ -75,6 +83,25 @@ bool IsThreadBinded(const ForNode* loop) { return false; } +bool IsReduceTempBuffer(const Buffer& buffer) { + return StartsWith(buffer->name, "normal_reduce_temp") || // + StartsWith(buffer->name, "reduce_temp"); +} + +String NormalizeStorageScope(const String& s) { + if (s.empty()) { + return "global"; + } + return s; +} + +Stmt MakeAllocStmt(const Buffer& buffer, const PrimExpr& area, Stmt body) { + body = Allocate(buffer->data, buffer->dtype, {area}, const_true(), body); + body = AttrStmt(buffer->data, attr::storage_scope, + StringImm(NormalizeStorageScope(buffer->scope)), body); + return body; +} + class ReductionTransformer : public StmtMutator { public: Stmt VisitStmt_(const BlockNode* block) override { @@ -182,6 +209,8 @@ class LCADetector : public StmtExprVisitor { * \brief Gather the used region of each buffers. */ class RegionGatherer : public StmtVisitor { + using VarDomain = std::unordered_map; + public: RegionGatherer(const Map>& buffers_lca, const Map& func_args) : buffers_lca_(buffers_lca) { @@ -212,12 +241,18 @@ class RegionGatherer : public StmtVisitor { block_var_[iter->var.get()] = ReplaceBlockVar(v); } for (const BufferRegion& read_region : block->reads) { - NDIntSet& alloc_region = buffers_region_.at(read_region->buffer); - UnionWith(&alloc_region, GatherRegion(read_region)); + const Buffer& buffer = read_region->buffer; + VarDomain dom_map = LoopVarDomain(buffer); + NDIntSet region = AsRegion(read_region->region, dom_map); + NDIntSet& alloc_region = buffers_region_.at(buffer); + UnionWith(&alloc_region, region); } for (const BufferRegion& write_region : block->writes) { - NDIntSet& alloc_region = buffers_region_.at(write_region->buffer); - UnionWith(&alloc_region, GatherRegion(write_region)); + const Buffer& buffer = write_region->buffer; + VarDomain dom_map = LoopVarDomain(buffer); + NDIntSet region = AsRegion(write_region->region, dom_map); + NDIntSet& alloc_region = buffers_region_.at(buffer); + UnionWith(&alloc_region, region); } for (const Buffer& alloc_buf : block->alloc_buffers) { // Initialize the buffer region with empty region. @@ -239,26 +274,28 @@ class RegionGatherer : public StmtVisitor { return Substitute(Substitute(expr, block_var_), unit_loops_); } - /*! - * \brief Gather used buffer region - */ - NDIntSet GatherRegion(const BufferRegion& buffer_region) const { - std::unordered_map dom_map; - const Optional& lca = buffers_lca_.at(buffer_region->buffer); + VarDomain LoopVarDomain(const Buffer& buffer) const { + VarDomain dom_map; + const Optional& lca = this->buffers_lca_.at(buffer); // Every loop will be relaxed if the lca is the root bool need_relax = !lca.defined(); - for (const ForNode* loop : ancestor_loops_) { + for (const ForNode* loop : this->ancestor_loops_) { const VarNode* loop_var = loop->loop_var.get(); // TODO - if (need_relax || (buffer_region->buffer->scope == "shared" && IsThreadBinded(loop))) { + if (need_relax || (buffer->scope == "shared" && IsThreadBinded(loop))) { dom_map[loop_var] = IntSetFromMinExtent(loop->min, loop->extent); } if (loop == lca.get()) { need_relax = true; } } + return dom_map; + } + + NDIntSet AsRegion(const Array& buffer_region, const VarDomain& dom_map) const { NDIntSet region; - for (const Range& range : buffer_region->region) { + region.reserve(buffer_region.size()); + for (const Range& range : buffer_region) { PrimExpr min = ReplaceBlockVar(range->min); PrimExpr extent = ReplaceBlockVar(range->extent); region.push_back(arith::EvalSet(Range::FromMinExtent(min, extent), dom_map)); @@ -277,7 +314,7 @@ class RegionGatherer : public StmtVisitor { */ class BufferFlattener : public StmtExprMutator { public: - BufferFlattener( + explicit BufferFlattener( const std::unordered_map& block_var, const std::unordered_map& unit_loops, const std::unordered_map& buffers_region, @@ -315,33 +352,29 @@ class BufferFlattener : public StmtExprMutator { Stmt VisitStmt_(const BlockRealizeNode* realize) final { // Handle allocations const auto* block = realize->block.get(); - Block old_block = realize->block; - int n_alloc_buffer = block->alloc_buffers.size(); - int n_iter_var = block->iter_vars.size(); - // Step 1. Figure out `pending_allocate_` - for (int i = n_alloc_buffer - 1; i >= 0; --i) { - // Why the order - const Buffer& buffer = block->alloc_buffers[i]; - if (StartsWith(buffer->name, "normal_reduce_temp") || - StartsWith(buffer->name, "reduce_temp")) { + // Step 1. Add non-root block allocations into `pending_allocate_` + for (const Buffer& buffer : block->alloc_buffers) { + if (IsReduceTempBuffer(buffer)) { continue; } if (buffers_lca_.at(buffer).defined()) { - pending_allocate_[buffer] = buffer; + pending_allocate_.insert(buffer.get()); } } - for (int i = 0; i < n_iter_var; ++i) { + // Step 2. Add reduction loop vars + CHECK_EQ(block->iter_vars.size(), realize->binding_values.size()); + int n_block_vars = block->iter_vars.size(); + for (int i = 0; i < n_block_vars; ++i) { const IterVar& block_var = block->iter_vars[i]; const PrimExpr& binding_value = realize->binding_values[i]; - if (block_var->iter_type != kCommReduce) { - continue; - } - std::unordered_set vars = Vars(binding_value); - for (const VarNode* var : vars) { - this->reduction_relative_.insert(GetRef(var)); + if (block_var->iter_type == kCommReduce) { + std::unordered_set vars = Vars(binding_value); + for (const VarNode* var : vars) { + this->reduction_relative_.insert(GetRef(var)); + } } } - // Step 2. Visit the body + // Step 3. Visit the body Stmt parent_scope = realize->block; std::swap(parent_scope, parent_scope_); BlockRealize new_stmt = Downcast(StmtExprMutator::VisitStmt_(realize)); @@ -349,12 +382,12 @@ class BufferFlattener : public StmtExprMutator { // Reset `realize` and `block` realize = new_stmt.get(); block = realize->block.get(); - // Step 3. Transform the `predicate` to if-then-else + // Step 4. Transform the `predicate` to if-then-else Stmt body = block->body; if (!is_one(realize->predicate)) { body = IfThenElse(realize->predicate, body); } - // Step 4. Pick out blocks that writes with double buffering + // Step 5. Pick out blocks that writes with double buffering for (const auto& ann : block->annotations) { const String& ann_key = ann.first; const ObjectRef& ann_value = ann.second; @@ -365,63 +398,44 @@ class BufferFlattener : public StmtExprMutator { } } } - // Step 5. Add allocation and storage scope - for (int i = n_alloc_buffer - 1; i >= 0; --i) { - const Buffer& alloc_buf = block->alloc_buffers[i]; - if (StartsWith(alloc_buf->name, "normal_reduce_temp") || - StartsWith(alloc_buf->name, "reduce_temp")) { + // Step 6. Add root block allocations + for (const Buffer& buffer : block->alloc_buffers) { + if (IsReduceTempBuffer(buffer)) { continue; } - if (!buffers_lca_.at(alloc_buf).defined() || buffers_lca_.at(alloc_buf).same_as(old_block)) { - PrimExpr extents = 1; - for (const arith::IntSet& extent : buffers_region_.at(alloc_buf)) { - extents *= extent.max() - extent.min() + 1; - } - body = Allocate(alloc_buf->data, alloc_buf->dtype, {extents}, const_true(), body); - // Change empty scope into global - String scope = alloc_buf->scope; - if (scope.empty()) { - scope = "global"; - } - body = AttrStmt(alloc_buf->data, attr::storage_scope, StringImm(scope), body); + if (!buffers_lca_.at(buffer).defined()) { + const NDIntSet& region = buffers_region_.at(buffer); + body = MakeAllocStmt(buffer, NDIntSetArea(region), body); } } - return body; } Stmt VisitStmt_(const ForNode* op) final { - Stmt old_stmt = GetRef(op); - std::swap(old_stmt, parent_scope_); + // Step 1. Find the buffer that can be allocated under the current loop + std::vector alloc_buffers; + for (const BufferNode* buffer : pending_allocate_) { + const Optional alloc_site = buffers_lca_.at(GetRef(buffer)); + if (op == alloc_site.get()) { + alloc_buffers.push_back(buffer); + } + } + // Step 2. Visit recursively + Stmt parent_scope = GetRef(op); + std::swap(parent_scope, parent_scope_); For stmt = Downcast(StmtExprMutator::VisitStmt_(op)); - std::swap(old_stmt, parent_scope_); + std::swap(parent_scope, parent_scope_); op = stmt.get(); - - std::vector removed_buffers; - // Add buffer allocation + // Step 3. Add buffer allocation Stmt body = op->body; - for (auto it = pending_allocate_.begin(); it != pending_allocate_.end();) { - const Buffer& alloc_buf = it->first; - if (old_stmt.same_as(buffers_lca_.at(alloc_buf))) { - PrimExpr extents = 1; - for (const arith::IntSet& extent : buffers_region_.at(alloc_buf)) { - extents *= extent.max() - extent.min() + 1; - } - body = Allocate(alloc_buf->data, alloc_buf->dtype, {extents}, const_true(), body); - // Change empty scope into global - String scope = alloc_buf->scope.empty() ? "global" : alloc_buf->scope; - body = AttrStmt(alloc_buf->data, attr::storage_scope, StringImm(scope), body); - removed_buffers.push_back(alloc_buf); - ++it; - } else { - ++it; - } - } - for (const Buffer& buffer : removed_buffers) { + for (const BufferNode* buffer : alloc_buffers) { + const NDIntSet& region = buffers_region_.at(GetRef(buffer)); + body = MakeAllocStmt(GetRef(buffer), NDIntSetArea(region), body); pending_allocate_.erase(buffer); } - + // Step 4. Add the for loop accordingly if (op->kind == ForKind::kThreadBinding) { + // Case 1. Thread binding ICHECK(op->thread_binding.defined()); String thread_tag = op->thread_binding.value()->thread_tag; if (!reduction_relative_.count(op->loop_var)) { @@ -433,10 +447,13 @@ class BufferFlattener : public StmtExprMutator { body = AttrStmt(iter_var, attr_key, op->extent, body); } } else if (is_one(op->extent) && op->annotations.empty()) { + // Case 2. Handle unit loop return body; } else { + // Case 3. An ordinary loop body = For(op->loop_var, op->min, op->extent, op->kind, body); } + // Step 5. Handle annotations for (const auto& annotation : op->annotations) { const String& ann_key = annotation.first; const ObjectRef& ann_value = annotation.second; @@ -492,7 +509,7 @@ class BufferFlattener : public StmtExprMutator { const Map>& buffers_lca_; const std::unordered_set& arg_buffers_; - std::unordered_map pending_allocate_; + std::unordered_set pending_allocate_; std::unordered_set reduction_relative_; std::unordered_set double_buffer_; Stmt parent_scope_; From c02e404dfeed1060156543a6190fc71a66dced3f Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 13 Mar 2021 02:35:19 +0000 Subject: [PATCH 04/20] ... --- src/tir/transforms/buffer_flatten.cc | 100 ++++++++++++++------------- 1 file changed, 52 insertions(+), 48 deletions(-) diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index 62706d7a08..dce41bc5fd 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -254,10 +254,9 @@ class RegionGatherer : public StmtVisitor { NDIntSet& alloc_region = buffers_region_.at(buffer); UnionWith(&alloc_region, region); } - for (const Buffer& alloc_buf : block->alloc_buffers) { + for (const Buffer& buffer : block->alloc_buffers) { // Initialize the buffer region with empty region. - // TODO - buffers_region_[alloc_buf] = NDIntSet(alloc_buf->shape.size(), arith::IntSet::Nothing()); + buffers_region_[buffer] = NDIntSet(buffer->shape.size(), arith::IntSet::Nothing()); } VisitStmt(block->body); } @@ -328,24 +327,23 @@ class BufferFlattener : public StmtExprMutator { Stmt VisitStmt_(const SeqStmtNode* op) final { Array seq; + seq.reserve(op->seq.size()); for (const Stmt& stmt : op->seq) { - std::unordered_set double_buffer; + std::unordered_set double_buffer; std::swap(double_buffer, double_buffer_); Stmt body = VisitStmt(stmt); std::swap(double_buffer, double_buffer_); - - for (const Buffer& buffer : double_buffer) { - ObjectRef lca = buffers_lca_.at(buffer); - if (lca.defined() && lca.same_as(parent_scope_)) { + const StmtNode* parent_scope = parent_scopes_.back(); + for (const BufferNode* buffer : double_buffer) { + const Object* lca = buffers_lca_.at(GetRef(buffer)).get(); + if (lca != nullptr && lca == parent_scope) { body = AttrStmt(buffer->data, attr::double_buffer_scope, 1, body); } else { double_buffer_.insert(buffer); } } - seq.push_back(body); } - return SeqStmt(seq); } @@ -370,18 +368,15 @@ class BufferFlattener : public StmtExprMutator { if (block_var->iter_type == kCommReduce) { std::unordered_set vars = Vars(binding_value); for (const VarNode* var : vars) { - this->reduction_relative_.insert(GetRef(var)); + this->reduction_loop_vars_.insert(var); } } } // Step 3. Visit the body - Stmt parent_scope = realize->block; - std::swap(parent_scope, parent_scope_); - BlockRealize new_stmt = Downcast(StmtExprMutator::VisitStmt_(realize)); - std::swap(parent_scope, parent_scope_); - // Reset `realize` and `block` - realize = new_stmt.get(); - block = realize->block.get(); + parent_scopes_.push_back(realize->block.get()); + Block new_block = Downcast(this->VisitStmt(realize->block)); + block = new_block.get(); + parent_scopes_.pop_back(); // Step 4. Transform the `predicate` to if-then-else Stmt body = block->body; if (!is_one(realize->predicate)) { @@ -394,7 +389,8 @@ class BufferFlattener : public StmtExprMutator { if (ann_key == attr::double_buffer_scope) { if (is_one(Downcast(ann_value))) { ICHECK_EQ(block->writes.size(), 1); - double_buffer_.insert(block->writes[0]->buffer); + const BufferRegion& write = block->writes[0]; + double_buffer_.insert(write->buffer.get()); } } } @@ -421,13 +417,12 @@ class BufferFlattener : public StmtExprMutator { } } // Step 2. Visit recursively - Stmt parent_scope = GetRef(op); - std::swap(parent_scope, parent_scope_); - For stmt = Downcast(StmtExprMutator::VisitStmt_(op)); - std::swap(parent_scope, parent_scope_); - op = stmt.get(); + parent_scopes_.push_back(op); + Stmt body = this->VisitStmt(op->body); + PrimExpr min = this->VisitExpr(op->min); + PrimExpr extent = this->VisitExpr(op->extent); + parent_scopes_.pop_back(); // Step 3. Add buffer allocation - Stmt body = op->body; for (const BufferNode* buffer : alloc_buffers) { const NDIntSet& region = buffers_region_.at(GetRef(buffer)); body = MakeAllocStmt(GetRef(buffer), NDIntSetArea(region), body); @@ -438,20 +433,20 @@ class BufferFlattener : public StmtExprMutator { // Case 1. Thread binding ICHECK(op->thread_binding.defined()); String thread_tag = op->thread_binding.value()->thread_tag; - if (!reduction_relative_.count(op->loop_var)) { - IterVar iter_var(/*dom=*/Range(op->min, op->extent), + if (!reduction_loop_vars_.count(op->loop_var.get())) { + IterVar iter_var(/*dom=*/Range(min, extent), /*var=*/op->loop_var, /*iter_type=*/IterVarType::kThreadIndex, /*thread_tag=*/thread_tag); String attr_key = thread_tag == "vthread" ? attr::virtual_thread : attr::thread_extent; - body = AttrStmt(iter_var, attr_key, op->extent, body); + body = AttrStmt(iter_var, attr_key, extent, body); } - } else if (is_one(op->extent) && op->annotations.empty()) { + } else if (is_one(extent) && op->annotations.empty()) { // Case 2. Handle unit loop return body; } else { // Case 3. An ordinary loop - body = For(op->loop_var, op->min, op->extent, op->kind, body); + body = For(op->loop_var, min, extent, op->kind, body); } // Step 5. Handle annotations for (const auto& annotation : op->annotations) { @@ -465,18 +460,19 @@ class BufferFlattener : public StmtExprMutator { } Stmt VisitStmt_(const BufferStoreNode* op) final { - BufferStore stmt = Downcast(StmtExprMutator::VisitStmt_(op)); - op = stmt.get(); - std::vector begins = ComputeRelativeIndices(op->buffer, op->indices); - Buffer new_buffer = ReshapeBuffer(op->buffer, this->buffers_region_.at(op->buffer)); - return new_buffer.vstore(begins, op->value); + const Buffer& buffer = op->buffer; + std::vector indices = VisitIndices(op->indices); + PrimExpr value = this->VisitExpr(op->value); + std::vector begins = ComputeRelativeIndices(buffer, indices); + Buffer new_buffer = ReshapeBuffer(buffer, this->buffers_region_.at(buffer)); + return new_buffer.vstore(begins, value); } PrimExpr VisitExpr_(const BufferLoadNode* op) final { - BufferLoad expr = Downcast(StmtExprMutator::VisitExpr_(op)); - op = expr.get(); - std::vector begins = ComputeRelativeIndices(op->buffer, op->indices); - Buffer new_buffer = ReshapeBuffer(op->buffer, this->buffers_region_.at(op->buffer)); + const Buffer& buffer = op->buffer; + std::vector indices = VisitIndices(op->indices); + std::vector begins = ComputeRelativeIndices(buffer, indices); + Buffer new_buffer = ReshapeBuffer(buffer, this->buffers_region_.at(buffer)); return new_buffer.vload(begins, op->dtype); } @@ -492,14 +488,14 @@ class BufferFlattener : public StmtExprMutator { PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::get_elem_offset())) { + // Handle `get_elem_offset` ICHECK_EQ(op->args.size(), 1); - const auto* buffer_load = op->args[0].as(); - ICHECK(buffer_load != nullptr); - Load load = Downcast(VisitExpr(op->args[0])); + const PrimExpr& arg = op->args[0]; + ICHECK(arg->IsInstance()); + Load load = Downcast(VisitExpr(arg)); return load->index; - } else { - return StmtExprMutator::VisitExpr_(op); } + return StmtExprMutator::VisitExpr_(op); } private: @@ -510,9 +506,9 @@ class BufferFlattener : public StmtExprMutator { const std::unordered_set& arg_buffers_; std::unordered_set pending_allocate_; - std::unordered_set reduction_relative_; - std::unordered_set double_buffer_; - Stmt parent_scope_; + std::unordered_set reduction_loop_vars_; + std::unordered_set double_buffer_; + std::vector parent_scopes_; /*! * \brief Create a buffer with alternative shape @@ -547,6 +543,15 @@ class BufferFlattener : public StmtExprMutator { } return new_indices; } + + std::vector VisitIndices(const Array& indices) { + std::vector result; + result.reserve(indices.size()); + for (const PrimExpr& index : indices) { + result.push_back(this->VisitExpr(index)); + } + return result; + } }; PrimFunc BufferFlatten(PrimFunc f) { @@ -582,7 +587,6 @@ PrimFunc BufferFlatten(PrimFunc f) { BufferFlattener flattener(region_gatherer.block_var_, region_gatherer.unit_loops_, region_gatherer.buffers_region_, buffer_lca, arg_buffers); fptr->body = flattener(fptr->body); - return f; } From 074a453af7b099d7bfab41aa59e027be65f69f9c Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 13 Mar 2021 03:13:44 +0000 Subject: [PATCH 05/20] ... --- src/tir/transforms/buffer_flatten.cc | 48 ++++++++++------------------ 1 file changed, 17 insertions(+), 31 deletions(-) diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index dce41bc5fd..c6d08a18c3 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -212,9 +212,9 @@ class RegionGatherer : public StmtVisitor { using VarDomain = std::unordered_map; public: - RegionGatherer(const Map>& buffers_lca, const Map& func_args) + RegionGatherer(const Map>& buffers_lca, const PrimFunc& f) : buffers_lca_(buffers_lca) { - for (const auto& arg : func_args) { + for (const auto& arg : f->buffer_map) { const Buffer& buffer = arg.second; buffers_region_[buffer] = NDIntSetFromShape(buffer->shape); } @@ -317,13 +317,18 @@ class BufferFlattener : public StmtExprMutator { const std::unordered_map& block_var, const std::unordered_map& unit_loops, const std::unordered_map& buffers_region, - const Map>& buffers_lca, - const std::unordered_set& arg_buffers) + const Map>& buffers_lca, const PrimFunc& func) : buffers_region_(buffers_region), block_var_(block_var), unit_loops_(unit_loops), buffers_lca_(buffers_lca), - arg_buffers_(arg_buffers) {} + arg_buffers_{} { + arg_buffers_.reserve(func->buffer_map.size()); + for (const auto& kv : func->buffer_map) { + const Buffer& buffer = kv.second; + arg_buffers_.insert(buffer.get()); + } + } Stmt VisitStmt_(const SeqStmtNode* op) final { Array seq; @@ -503,8 +508,7 @@ class BufferFlattener : public StmtExprMutator { const std::unordered_map& block_var_; const std::unordered_map& unit_loops_; const Map>& buffers_lca_; - const std::unordered_set& arg_buffers_; - + std::unordered_set arg_buffers_; std::unordered_set pending_allocate_; std::unordered_set reduction_loop_vars_; std::unordered_set double_buffer_; @@ -556,36 +560,18 @@ class BufferFlattener : public StmtExprMutator { PrimFunc BufferFlatten(PrimFunc f) { tvm::tir::PrimFuncNode* fptr = f.CopyOnWrite(); - - // Check memory and execution hierarchy + // Step 0. Check memory and execution hierarchy VerifyExecScope(f); - - // Transform the reduction calls to BufferStore + // Step 1.Transform the reduction calls to BufferStore ReductionTransformer reduction_transformer; fptr->body = reduction_transformer(fptr->body); - - std::unordered_set arg_buffers; - for (const auto& kv : fptr->buffer_map) { - const Buffer& buffer = kv.second; - arg_buffers.insert(buffer.get()); - } - - // Find the LCA of each Buffer access - // LCADetector lca_detector(arg_buffers); - // lca_detector(fptr->body); - + // Step 2. Recalculate the buffer region Map> buffer_lca = LCADetector::Detect(f); - // for (const auto& kv : lca_detector.buffers_lca_) { - // buffer_lca.Set(GetRef(kv.first), GetRef(kv.second)); - // } - - // Recalculate the buffer region - RegionGatherer region_gatherer(buffer_lca, fptr->buffer_map); + RegionGatherer region_gatherer(buffer_lca, f); region_gatherer(fptr->body); - - // Transform BufferLoad/BufferStore into Load/Store + // Step 3. Transform BufferLoad/BufferStore into Load/Store BufferFlattener flattener(region_gatherer.block_var_, region_gatherer.unit_loops_, - region_gatherer.buffers_region_, buffer_lca, arg_buffers); + region_gatherer.buffers_region_, buffer_lca, f); fptr->body = flattener(fptr->body); return f; } From a53ec1e16743ee7e4c72af99b41b3a7baeaf8400 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 13 Mar 2021 03:17:47 +0000 Subject: [PATCH 06/20] ... --- src/tir/transforms/buffer_flatten.cc | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index c6d08a18c3..2791a72f28 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -338,10 +338,10 @@ class BufferFlattener : public StmtExprMutator { std::swap(double_buffer, double_buffer_); Stmt body = VisitStmt(stmt); std::swap(double_buffer, double_buffer_); - const StmtNode* parent_scope = parent_scopes_.back(); + const ForNode* loop = ancestor_loops_.back(); for (const BufferNode* buffer : double_buffer) { const Object* lca = buffers_lca_.at(GetRef(buffer)).get(); - if (lca != nullptr && lca == parent_scope) { + if (lca != nullptr && loop == lca) { body = AttrStmt(buffer->data, attr::double_buffer_scope, 1, body); } else { double_buffer_.insert(buffer); @@ -378,10 +378,8 @@ class BufferFlattener : public StmtExprMutator { } } // Step 3. Visit the body - parent_scopes_.push_back(realize->block.get()); Block new_block = Downcast(this->VisitStmt(realize->block)); block = new_block.get(); - parent_scopes_.pop_back(); // Step 4. Transform the `predicate` to if-then-else Stmt body = block->body; if (!is_one(realize->predicate)) { @@ -422,11 +420,11 @@ class BufferFlattener : public StmtExprMutator { } } // Step 2. Visit recursively - parent_scopes_.push_back(op); + ancestor_loops_.push_back(op); Stmt body = this->VisitStmt(op->body); PrimExpr min = this->VisitExpr(op->min); PrimExpr extent = this->VisitExpr(op->extent); - parent_scopes_.pop_back(); + ancestor_loops_.pop_back(); // Step 3. Add buffer allocation for (const BufferNode* buffer : alloc_buffers) { const NDIntSet& region = buffers_region_.at(GetRef(buffer)); @@ -512,7 +510,7 @@ class BufferFlattener : public StmtExprMutator { std::unordered_set pending_allocate_; std::unordered_set reduction_loop_vars_; std::unordered_set double_buffer_; - std::vector parent_scopes_; + std::vector ancestor_loops_; /*! * \brief Create a buffer with alternative shape From 7bccf62eefa990c5621e1c0beb0ebc3b8c6901fe Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 13 Mar 2021 03:28:32 +0000 Subject: [PATCH 07/20] ... --- src/tir/transforms/buffer_flatten.cc | 38 ++++++++++++++-------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index 2791a72f28..4a2065a8a9 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -220,6 +220,14 @@ class RegionGatherer : public StmtVisitor { } } + /*! \brief The used region of each Buffer */ + std::unordered_map buffers_region_; + /*! \brief The map from block vars to the expr value */ + std::unordered_map block_var_; + /*! \brief The map from unit loop vars to the expr value */ + std::unordered_map unit_loops_; + + private: void VisitStmt_(const ForNode* op) final { ancestor_loops_.push_back(op); if (!op->thread_binding.defined() && op->annotations.empty() && is_one(op->extent)) { @@ -261,14 +269,6 @@ class RegionGatherer : public StmtVisitor { VisitStmt(block->body); } - /*! \brief The used region of each Buffer */ - std::unordered_map buffers_region_; - /*! \brief The map from block vars to the expr value */ - std::unordered_map block_var_; - /*! \brief The map from unit loop vars to the expr value */ - std::unordered_map unit_loops_; - - private: PrimExpr ReplaceBlockVar(const PrimExpr& expr) const { return Substitute(Substitute(expr, block_var_), unit_loops_); } @@ -330,6 +330,7 @@ class BufferFlattener : public StmtExprMutator { } } + private: Stmt VisitStmt_(const SeqStmtNode* op) final { Array seq; seq.reserve(op->seq.size()); @@ -501,17 +502,6 @@ class BufferFlattener : public StmtExprMutator { return StmtExprMutator::VisitExpr_(op); } - private: - const std::unordered_map& buffers_region_; - const std::unordered_map& block_var_; - const std::unordered_map& unit_loops_; - const Map>& buffers_lca_; - std::unordered_set arg_buffers_; - std::unordered_set pending_allocate_; - std::unordered_set reduction_loop_vars_; - std::unordered_set double_buffer_; - std::vector ancestor_loops_; - /*! * \brief Create a buffer with alternative shape */ @@ -554,6 +544,16 @@ class BufferFlattener : public StmtExprMutator { } return result; } + + const std::unordered_map& buffers_region_; + const std::unordered_map& block_var_; + const std::unordered_map& unit_loops_; + const Map>& buffers_lca_; + std::unordered_set arg_buffers_; + std::unordered_set pending_allocate_; + std::unordered_set reduction_loop_vars_; + std::unordered_set double_buffer_; + std::vector ancestor_loops_; }; PrimFunc BufferFlatten(PrimFunc f) { From b6566b71c315c9d7ccc5a0a61abaf1842d297bc8 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 13 Mar 2021 03:39:33 +0000 Subject: [PATCH 08/20] ... --- src/tir/transforms/buffer_flatten.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index 4a2065a8a9..48775198e5 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -230,7 +230,7 @@ class RegionGatherer : public StmtVisitor { private: void VisitStmt_(const ForNode* op) final { ancestor_loops_.push_back(op); - if (!op->thread_binding.defined() && op->annotations.empty() && is_one(op->extent)) { + if (is_one(op->extent)) { unit_loops_[op->loop_var.get()] = op->min; } StmtVisitor::VisitStmt_(op); From 4fb8f5b4fbab1eef58872131b79dff67ba9951c0 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 13 Mar 2021 07:46:34 +0000 Subject: [PATCH 09/20] ... --- src/tir/transforms/buffer_flatten.cc | 122 ++++++++++++++------------- 1 file changed, 64 insertions(+), 58 deletions(-) diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index 48775198e5..145bbe2fbb 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -208,7 +208,7 @@ class LCADetector : public StmtExprVisitor { /*! * \brief Gather the used region of each buffers. */ -class RegionGatherer : public StmtVisitor { +class RegionGatherer : public StmtExprMutator { using VarDomain = std::unordered_map; public: @@ -228,45 +228,81 @@ class RegionGatherer : public StmtVisitor { std::unordered_map unit_loops_; private: - void VisitStmt_(const ForNode* op) final { + Stmt VisitStmt_(const ForNode* op) final { ancestor_loops_.push_back(op); - if (is_one(op->extent)) { - unit_loops_[op->loop_var.get()] = op->min; + PrimExpr min = this->VisitExpr(op->min); + PrimExpr extent = this->VisitExpr(op->extent); + if (is_one(extent)) { + unit_loops_[op->loop_var.get()] = min; } - StmtVisitor::VisitStmt_(op); + For result(/*loop_var=*/op->loop_var, // + /*min=*/min, // + /*extent=*/extent, // + /*kind=*/op->kind, // + /*body=*/this->VisitStmt(op->body), // + /*thread_binding=*/op->thread_binding, // + /*annotations=*/op->annotations); ancestor_loops_.pop_back(); + return result; } - void VisitStmt_(const BlockRealizeNode* realize) final { + Stmt VisitStmt_(const BlockRealizeNode* realize) final { const auto* block = realize->block.as(); - CHECK(!block->init.defined()); + ICHECK(!block->init.defined()); // Update the mapping from block vars to loop vars so that we can substitute them - CHECK_EQ(block->iter_vars.size(), realize->binding_values.size()); + ICHECK_EQ(block->iter_vars.size(), realize->binding_values.size()); int n_block_vars = block->iter_vars.size(); for (int i = 0; i < n_block_vars; ++i) { const IterVar& iter = block->iter_vars[i]; const PrimExpr& v = realize->binding_values[i]; - block_var_[iter->var.get()] = ReplaceBlockVar(v); - } - for (const BufferRegion& read_region : block->reads) { - const Buffer& buffer = read_region->buffer; - VarDomain dom_map = LoopVarDomain(buffer); - NDIntSet region = AsRegion(read_region->region, dom_map); - NDIntSet& alloc_region = buffers_region_.at(buffer); - UnionWith(&alloc_region, region); - } - for (const BufferRegion& write_region : block->writes) { - const Buffer& buffer = write_region->buffer; - VarDomain dom_map = LoopVarDomain(buffer); - NDIntSet region = AsRegion(write_region->region, dom_map); - NDIntSet& alloc_region = buffers_region_.at(buffer); - UnionWith(&alloc_region, region); + block_var_[iter->var.get()] = this->VisitExpr(v); + } + for (const BufferRegion& buffer_region : block->reads) { + UpdateBufferRegion(buffer_region); } + for (const BufferRegion& buffer_region : block->writes) { + UpdateBufferRegion(buffer_region); + } + // Initialize the buffer region with empty region. for (const Buffer& buffer : block->alloc_buffers) { - // Initialize the buffer region with empty region. buffers_region_[buffer] = NDIntSet(buffer->shape.size(), arith::IntSet::Nothing()); } - VisitStmt(block->body); + ObjectPtr new_block = make_object(*block); + new_block->iter_vars = {}; + new_block->body = this->VisitStmt(new_block->body); + return BlockRealize(/*values=*/{}, // + /*predicate=*/this->VisitExpr(realize->predicate), // + /*block=*/Block(std::move(new_block))); + } + + PrimExpr VisitExpr_(const VarNode* var) final { + { + auto it = block_var_.find(var); + if (it != block_var_.end()) { + return it->second; + } + } + { + auto it = unit_loops_.find(var); + if (it != unit_loops_.end()) { + return it->second; + } + } + return GetRef(var); + } + + void UpdateBufferRegion(const BufferRegion& buffer_region) { + const Buffer& buffer = buffer_region->buffer; + VarDomain dom_map = LoopVarDomain(buffer); + NDIntSet region; + region.reserve(buffer_region->region.size()); + for (const Range& range : buffer_region->region) { + PrimExpr min = this->VisitExpr(range->min); + PrimExpr extent = this->VisitExpr(range->extent); + region.push_back(arith::EvalSet(Range::FromMinExtent(min, extent), dom_map)); + } + NDIntSet& alloc_region = buffers_region_.at(buffer); + UnionWith(&alloc_region, region); } PrimExpr ReplaceBlockVar(const PrimExpr& expr) const { @@ -291,17 +327,6 @@ class RegionGatherer : public StmtVisitor { return dom_map; } - NDIntSet AsRegion(const Array& buffer_region, const VarDomain& dom_map) const { - NDIntSet region; - region.reserve(buffer_region.size()); - for (const Range& range : buffer_region) { - PrimExpr min = ReplaceBlockVar(range->min); - PrimExpr extent = ReplaceBlockVar(range->extent); - region.push_back(arith::EvalSet(Range::FromMinExtent(min, extent), dom_map)); - } - return region; - } - /*! \brief The map from Buffer to its LCA Stmt/Expr */ const Map>& buffers_lca_; /*! \brief The loops from the current node up to the root */ @@ -314,15 +339,9 @@ class RegionGatherer : public StmtVisitor { class BufferFlattener : public StmtExprMutator { public: explicit BufferFlattener( - const std::unordered_map& block_var, - const std::unordered_map& unit_loops, const std::unordered_map& buffers_region, const Map>& buffers_lca, const PrimFunc& func) - : buffers_region_(buffers_region), - block_var_(block_var), - unit_loops_(unit_loops), - buffers_lca_(buffers_lca), - arg_buffers_{} { + : buffers_region_(buffers_region), buffers_lca_(buffers_lca), arg_buffers_{} { arg_buffers_.reserve(func->buffer_map.size()); for (const auto& kv : func->buffer_map) { const Buffer& buffer = kv.second; @@ -480,16 +499,6 @@ class BufferFlattener : public StmtExprMutator { return new_buffer.vload(begins, op->dtype); } - PrimExpr VisitExpr_(const VarNode* op) final { - // Replace the block var with its value - auto it = block_var_.find(op); - if (it != block_var_.end()) { - return Substitute(it->second, unit_loops_); - } else { - return Substitute(GetRef(op), unit_loops_); - } - } - PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::get_elem_offset())) { // Handle `get_elem_offset` @@ -546,8 +555,6 @@ class BufferFlattener : public StmtExprMutator { } const std::unordered_map& buffers_region_; - const std::unordered_map& block_var_; - const std::unordered_map& unit_loops_; const Map>& buffers_lca_; std::unordered_set arg_buffers_; std::unordered_set pending_allocate_; @@ -566,10 +573,9 @@ PrimFunc BufferFlatten(PrimFunc f) { // Step 2. Recalculate the buffer region Map> buffer_lca = LCADetector::Detect(f); RegionGatherer region_gatherer(buffer_lca, f); - region_gatherer(fptr->body); + fptr->body = region_gatherer(fptr->body); // Step 3. Transform BufferLoad/BufferStore into Load/Store - BufferFlattener flattener(region_gatherer.block_var_, region_gatherer.unit_loops_, - region_gatherer.buffers_region_, buffer_lca, f); + BufferFlattener flattener(region_gatherer.buffers_region_, buffer_lca, f); fptr->body = flattener(fptr->body); return f; } From 1071d62ecb9fd890c2794a562a569a7d7b0ce9fb Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 13 Mar 2021 08:09:08 +0000 Subject: [PATCH 10/20] ... --- src/tir/transforms/buffer_flatten.cc | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index 145bbe2fbb..7c131bccd1 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -227,6 +227,8 @@ class RegionGatherer : public StmtExprMutator { /*! \brief The map from unit loop vars to the expr value */ std::unordered_map unit_loops_; + std::unordered_map loop_mapping; + private: Stmt VisitStmt_(const ForNode* op) final { ancestor_loops_.push_back(op); @@ -243,6 +245,7 @@ class RegionGatherer : public StmtExprMutator { /*thread_binding=*/op->thread_binding, // /*annotations=*/op->annotations); ancestor_loops_.pop_back(); + loop_mapping.emplace(op, result.get()); return result; } @@ -567,6 +570,7 @@ PrimFunc BufferFlatten(PrimFunc f) { tvm::tir::PrimFuncNode* fptr = f.CopyOnWrite(); // Step 0. Check memory and execution hierarchy VerifyExecScope(f); + LOG(INFO) << "\n" << Repr(f); // Step 1.Transform the reduction calls to BufferStore ReductionTransformer reduction_transformer; fptr->body = reduction_transformer(fptr->body); @@ -574,9 +578,18 @@ PrimFunc BufferFlatten(PrimFunc f) { Map> buffer_lca = LCADetector::Detect(f); RegionGatherer region_gatherer(buffer_lca, f); fptr->body = region_gatherer(fptr->body); + MapNode* buf = buffer_lca.CopyOnWrite(); + for (auto& kv : *buf) { + const ForNode* loop = static_cast(kv.second.get()); + if (loop != nullptr && region_gatherer.loop_mapping.count(loop)) { + kv.second = GetRef(region_gatherer.loop_mapping.at(loop)); + } + } + LOG(INFO) << "\n" << Repr(f); // Step 3. Transform BufferLoad/BufferStore into Load/Store BufferFlattener flattener(region_gatherer.buffers_region_, buffer_lca, f); fptr->body = flattener(fptr->body); + LOG(INFO) << "\n" << Repr(f); return f; } From eec9943cf99b861b9c0d5e04d452a3f09fb12b52 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 13 Mar 2021 08:43:28 +0000 Subject: [PATCH 11/20] move var substitute to region gather --- src/tir/transforms/buffer_flatten.cc | 96 +++++++++++++--------------- 1 file changed, 46 insertions(+), 50 deletions(-) diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index 7c131bccd1..b563e513e2 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -102,6 +102,17 @@ Stmt MakeAllocStmt(const Buffer& buffer, const PrimExpr& area, Stmt body) { return body; } +Stmt MakeLaunchThread(const PrimExpr& min, const PrimExpr& extent, const Var& var, + const String& thread_tag, Stmt body) { + IterVar iter_var(/*dom=*/Range::FromMinExtent(min, extent), + /*var=*/var, + /*iter_type=*/IterVarType::kThreadIndex, + /*thread_tag=*/thread_tag); + String attr_key = thread_tag == "vthread" ? attr::virtual_thread : attr::thread_extent; + body = AttrStmt(iter_var, attr_key, extent, body); + return body; +} + class ReductionTransformer : public StmtMutator { public: Stmt VisitStmt_(const BlockNode* block) override { @@ -223,9 +234,7 @@ class RegionGatherer : public StmtExprMutator { /*! \brief The used region of each Buffer */ std::unordered_map buffers_region_; /*! \brief The map from block vars to the expr value */ - std::unordered_map block_var_; - /*! \brief The map from unit loop vars to the expr value */ - std::unordered_map unit_loops_; + std::unordered_map var_substitutes_; std::unordered_map loop_mapping; @@ -235,7 +244,7 @@ class RegionGatherer : public StmtExprMutator { PrimExpr min = this->VisitExpr(op->min); PrimExpr extent = this->VisitExpr(op->extent); if (is_one(extent)) { - unit_loops_[op->loop_var.get()] = min; + var_substitutes_[op->loop_var.get()] = min; } For result(/*loop_var=*/op->loop_var, // /*min=*/min, // @@ -258,58 +267,53 @@ class RegionGatherer : public StmtExprMutator { for (int i = 0; i < n_block_vars; ++i) { const IterVar& iter = block->iter_vars[i]; const PrimExpr& v = realize->binding_values[i]; - block_var_[iter->var.get()] = this->VisitExpr(v); - } - for (const BufferRegion& buffer_region : block->reads) { - UpdateBufferRegion(buffer_region); - } - for (const BufferRegion& buffer_region : block->writes) { - UpdateBufferRegion(buffer_region); + var_substitutes_[iter->var.get()] = this->VisitExpr(v); } + Array reads = UpdateBufferRegions(block->reads); + Array writes = UpdateBufferRegions(block->writes); // Initialize the buffer region with empty region. for (const Buffer& buffer : block->alloc_buffers) { buffers_region_[buffer] = NDIntSet(buffer->shape.size(), arith::IntSet::Nothing()); } ObjectPtr new_block = make_object(*block); - new_block->iter_vars = {}; + new_block->reads = std::move(reads); + new_block->writes = std::move(writes); new_block->body = this->VisitStmt(new_block->body); - return BlockRealize(/*values=*/{}, // + return BlockRealize(/*values=*/realize->binding_values, // /*predicate=*/this->VisitExpr(realize->predicate), // /*block=*/Block(std::move(new_block))); } PrimExpr VisitExpr_(const VarNode* var) final { - { - auto it = block_var_.find(var); - if (it != block_var_.end()) { - return it->second; - } - } - { - auto it = unit_loops_.find(var); - if (it != unit_loops_.end()) { - return it->second; - } + auto it = var_substitutes_.find(var); + if (it != var_substitutes_.end()) { + return it->second; } return GetRef(var); } - void UpdateBufferRegion(const BufferRegion& buffer_region) { - const Buffer& buffer = buffer_region->buffer; - VarDomain dom_map = LoopVarDomain(buffer); - NDIntSet region; - region.reserve(buffer_region->region.size()); - for (const Range& range : buffer_region->region) { - PrimExpr min = this->VisitExpr(range->min); - PrimExpr extent = this->VisitExpr(range->extent); - region.push_back(arith::EvalSet(Range::FromMinExtent(min, extent), dom_map)); + Array UpdateBufferRegions(const Array& buffer_regions) { + Array result; + result.reserve(buffer_regions.size()); + for (const BufferRegion& buffer_region : buffer_regions) { + const Buffer& buffer = buffer_region->buffer; + VarDomain dom_map = LoopVarDomain(buffer); + int ndim = buffer_region->region.size(); + Array region; + NDIntSet int_set; + region.reserve(ndim); + int_set.reserve(ndim); + for (const Range& range : buffer_region->region) { + Range new_range = + Range::FromMinExtent(this->VisitExpr(range->min), this->VisitExpr(range->extent)); + region.push_back(new_range); + int_set.push_back(arith::EvalSet(new_range, dom_map)); + } + NDIntSet& alloc_region = buffers_region_.at(buffer); + UnionWith(&alloc_region, int_set); + result.push_back(BufferRegion(buffer_region->buffer, region)); } - NDIntSet& alloc_region = buffers_region_.at(buffer); - UnionWith(&alloc_region, region); - } - - PrimExpr ReplaceBlockVar(const PrimExpr& expr) const { - return Substitute(Substitute(expr, block_var_), unit_loops_); + return result; } VarDomain LoopVarDomain(const Buffer& buffer) const { @@ -457,15 +461,10 @@ class BufferFlattener : public StmtExprMutator { // Step 4. Add the for loop accordingly if (op->kind == ForKind::kThreadBinding) { // Case 1. Thread binding - ICHECK(op->thread_binding.defined()); - String thread_tag = op->thread_binding.value()->thread_tag; if (!reduction_loop_vars_.count(op->loop_var.get())) { - IterVar iter_var(/*dom=*/Range(min, extent), - /*var=*/op->loop_var, - /*iter_type=*/IterVarType::kThreadIndex, - /*thread_tag=*/thread_tag); - String attr_key = thread_tag == "vthread" ? attr::virtual_thread : attr::thread_extent; - body = AttrStmt(iter_var, attr_key, extent, body); + ICHECK(op->thread_binding.defined()); + String thread_tag = op->thread_binding.value()->thread_tag; + body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body); } } else if (is_one(extent) && op->annotations.empty()) { // Case 2. Handle unit loop @@ -570,7 +569,6 @@ PrimFunc BufferFlatten(PrimFunc f) { tvm::tir::PrimFuncNode* fptr = f.CopyOnWrite(); // Step 0. Check memory and execution hierarchy VerifyExecScope(f); - LOG(INFO) << "\n" << Repr(f); // Step 1.Transform the reduction calls to BufferStore ReductionTransformer reduction_transformer; fptr->body = reduction_transformer(fptr->body); @@ -585,11 +583,9 @@ PrimFunc BufferFlatten(PrimFunc f) { kv.second = GetRef(region_gatherer.loop_mapping.at(loop)); } } - LOG(INFO) << "\n" << Repr(f); // Step 3. Transform BufferLoad/BufferStore into Load/Store BufferFlattener flattener(region_gatherer.buffers_region_, buffer_lca, f); fptr->body = flattener(fptr->body); - LOG(INFO) << "\n" << Repr(f); return f; } From 9c80f2a02a9f388468164e98fb4e31c7c7f8c8dc Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 13 Mar 2021 21:52:57 +0000 Subject: [PATCH 12/20] ... --- src/tir/transforms/buffer_flatten.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index b563e513e2..b090e59de8 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -233,8 +233,6 @@ class RegionGatherer : public StmtExprMutator { /*! \brief The used region of each Buffer */ std::unordered_map buffers_region_; - /*! \brief The map from block vars to the expr value */ - std::unordered_map var_substitutes_; std::unordered_map loop_mapping; @@ -338,6 +336,8 @@ class RegionGatherer : public StmtExprMutator { const Map>& buffers_lca_; /*! \brief The loops from the current node up to the root */ std::vector ancestor_loops_; + /*! \brief The map from block vars to the expr value */ + std::unordered_map var_substitutes_; }; /*! From 38bb7253a2f41106333bbcf8562a8877074c93fe Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 14 Mar 2021 07:07:16 +0000 Subject: [PATCH 13/20] ... --- src/tir/transforms/buffer_flatten.cc | 449 +++++++++++++++------------ 1 file changed, 244 insertions(+), 205 deletions(-) diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index b090e59de8..557a737cc1 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -60,6 +60,19 @@ PrimExpr NDIntSetArea(const NDIntSet& nd_int_set) { return area; } +Buffer NDIntSet2Buffer(const BufferNode* buffer, const NDIntSet& nd_int_set) { + Integer one(1); + Array shape; + shape.reserve(nd_int_set.size()); + for (const arith::IntSet& int_set : nd_int_set) { + PrimExpr extent = int_set.max() - int_set.min() + one; + shape.push_back(extent); + } + ObjectPtr new_buffer = make_object(*buffer); + new_buffer->shape = std::move(shape); + return Buffer(std::move(new_buffer)); +} + NDIntSet NDIntSetFromShape(const Array& shape) { NDIntSet result; for (const PrimExpr& extent : shape) { @@ -68,7 +81,7 @@ NDIntSet NDIntSetFromShape(const Array& shape) { return result; } -bool IsThreadBinded(const ForNode* loop) { +bool IsThreadBound(const ForNode* loop) { if (loop->kind != ForKind::kThreadBinding) { return false; } @@ -113,6 +126,14 @@ Stmt MakeLaunchThread(const PrimExpr& min, const PrimExpr& extent, const Var& va return body; } +PrimExpr BufferArea(const Buffer& buffer) { + PrimExpr area = Integer(1); + for (const PrimExpr& dim : buffer->shape) { + area = area * dim; + } + return area; +} + class ReductionTransformer : public StmtMutator { public: Stmt VisitStmt_(const BlockNode* block) override { @@ -216,6 +237,63 @@ class LCADetector : public StmtExprVisitor { std::unordered_map buffers_lca_ = {}; }; +class BufferAccessUpdater : public StmtExprMutator { + public: + static Stmt Update( + const std::unordered_map>& buffers_offsets, + const std::unordered_map& buffer_allocated, Stmt body) { + BufferAccessUpdater updater(buffers_offsets, buffer_allocated); + return updater.VisitStmt(body); + } + + private: + explicit BufferAccessUpdater( + const std::unordered_map>& buffers_offsets, + const std::unordered_map& buffer_allocated) + : buffers_offsets_(buffers_offsets), buffer_allocated_(buffer_allocated) {} + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + op = store.get(); + const BufferNode* old_buffer = op->buffer.get(); + const BufferNode* new_buffer = FindNewBuffer(old_buffer); + Array begins = ComputeRelativeIndices(old_buffer, op->indices); + return GetRef(new_buffer).vstore(begins, op->value); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + op = load.get(); + const BufferNode* old_buffer = op->buffer.get(); + const BufferNode* new_buffer = FindNewBuffer(old_buffer); + Array begins = ComputeRelativeIndices(old_buffer, op->indices); + return GetRef(new_buffer).vload(begins, op->dtype); + } + + const BufferNode* FindNewBuffer(const BufferNode* buffer) const { + auto it = buffer_allocated_.find(buffer); + ICHECK(it != buffer_allocated_.end()); + return it->second; + } + + Array ComputeRelativeIndices(const BufferNode* buffer, const Array& indices) { + auto it = buffers_offsets_.find(buffer); + ICHECK(it != buffers_offsets_.end()); + const std::vector& offsets = it->second; + ICHECK_EQ(offsets.size(), indices.size()); + int ndim = offsets.size(); + Array result; + result.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + result.push_back(indices[i] - offsets[i]); + } + return result; + } + + const std::unordered_map>& buffers_offsets_; + const std::unordered_map& buffer_allocated_; +}; + /*! * \brief Gather the used region of each buffers. */ @@ -223,63 +301,158 @@ class RegionGatherer : public StmtExprMutator { using VarDomain = std::unordered_map; public: - RegionGatherer(const Map>& buffers_lca, const PrimFunc& f) - : buffers_lca_(buffers_lca) { + Stmt Gather(const Map>& buffers_lca, const PrimFunc& f) { + for (const auto& kv : buffers_lca) { + const BufferNode* buffer = kv.first.get(); + const ForNode* loop = static_cast(kv.second.get()); + this->buffers_lca_.emplace(buffer, loop); + this->buffer_alloc_[loop].push_back(buffer); + } for (const auto& arg : f->buffer_map) { - const Buffer& buffer = arg.second; - buffers_region_[buffer] = NDIntSetFromShape(buffer->shape); + const BufferNode* buffer = arg.second.get(); + int ndim = buffer->shape.size(); + buffers_region_.emplace(buffer, NDIntSetFromShape(buffer->shape)); + buffer_allocated_.emplace(buffer, buffer); + buffer_offsets_.emplace(buffer, std::vector(ndim, Integer(0))); } + return BufferAccessUpdater::Update(buffer_offsets_, buffer_allocated_, + this->VisitStmt(f->body)); } + public: /*! \brief The used region of each Buffer */ - std::unordered_map buffers_region_; - - std::unordered_map loop_mapping; + std::unordered_map buffers_region_; + std::unordered_map buffer_allocated_; + std::unordered_map> buffer_offsets_; private: - Stmt VisitStmt_(const ForNode* op) final { - ancestor_loops_.push_back(op); - PrimExpr min = this->VisitExpr(op->min); - PrimExpr extent = this->VisitExpr(op->extent); + Array AllocBufers(const ForNode* loop) { + auto it = buffer_alloc_.find(loop); + if (it == buffer_alloc_.end()) { + return {}; + } + const std::vector& buffers = it->second; + Array result; + result.reserve(buffers.size()); + for (const BufferNode* buffer : buffers) { + auto it = buffers_region_.find(buffer); + ICHECK(it != buffers_region_.end()); + const NDIntSet& nd_int_set = it->second; + Buffer allocated = NDIntSet2Buffer(buffer, nd_int_set); + buffer_allocated_.emplace(buffer, allocated.get()); + result.push_back(allocated); + std::vector offsets; + offsets.reserve(nd_int_set.size()); + for (const arith::IntSet& int_set : nd_int_set) { + offsets.push_back(int_set.min()); + } + buffer_offsets_.emplace(buffer, std::move(offsets)); + } + return result; + } + + Stmt VisitStmt_(const ForNode* loop) final { + // Step 1. Handle block vars in `min` and `extent` + PrimExpr min = this->VisitExpr(loop->min); + PrimExpr extent = this->VisitExpr(loop->extent); + // Step 2. Handle unit loops if (is_one(extent)) { - var_substitutes_[op->loop_var.get()] = min; + var_substitutes_[loop->loop_var.get()] = min; } - For result(/*loop_var=*/op->loop_var, // - /*min=*/min, // - /*extent=*/extent, // - /*kind=*/op->kind, // - /*body=*/this->VisitStmt(op->body), // - /*thread_binding=*/op->thread_binding, // - /*annotations=*/op->annotations); + // Step 3. Visit recursively + ancestor_loops_.push_back(loop); + Stmt body = this->VisitStmt(loop->body); ancestor_loops_.pop_back(); - loop_mapping.emplace(op, result.get()); - return result; + // Step 4. Add allocation + Array alloc_buffers = AllocBufers(loop); + if (!alloc_buffers.empty()) { + body = BlockRealize(/*binding_values=*/{}, + /*predicate=*/const_true(), + /*block=*/ + Block(/*iter_vars=*/{}, // + /*reads=*/{}, // + /*writes=*/{}, // + /*alloc_buffers=*/std::move(alloc_buffers), // + /*annotations=*/{}, // + /*match_buffers=*/{}, // + /*exec_scope=*/"", // + /*name_hint=*/"alloc", // + /*body=*/std::move(body), // + /*init=*/NullOpt)); + } + // Step 5. Make the new loop + if (loop->kind == ForKind::kThreadBinding && reduction_loop_vars_.count(loop->loop_var.get())) { + // do nothing, because the loop is going to be removed + } else { + body = For(/*loop_var=*/loop->loop_var, + /*min=*/min, + /*extent=*/extent, + /*kind=*/loop->kind, + /*body=*/std::move(body), + /*thread_binding=*/loop->thread_binding, + /*annotations=*/loop->annotations); + } + return body; } Stmt VisitStmt_(const BlockRealizeNode* realize) final { - const auto* block = realize->block.as(); + const auto* block = realize->block.get(); ICHECK(!block->init.defined()); - // Update the mapping from block vars to loop vars so that we can substitute them + // Step 1. Update "block vars => loop vars" for substitution, add reduction loop vars ICHECK_EQ(block->iter_vars.size(), realize->binding_values.size()); - int n_block_vars = block->iter_vars.size(); - for (int i = 0; i < n_block_vars; ++i) { - const IterVar& iter = block->iter_vars[i]; - const PrimExpr& v = realize->binding_values[i]; - var_substitutes_[iter->var.get()] = this->VisitExpr(v); + for (int i = 0, n = block->iter_vars.size(); i < n; ++i) { + const IterVar& block_var = block->iter_vars[i]; + PrimExpr v = this->VisitExpr(realize->binding_values[i]); + var_substitutes_.emplace(block_var->var.get(), v); + if (block_var->iter_type == kCommReduce) { + for (const VarNode* var : Vars(v)) { + this->reduction_loop_vars_.insert(var); + } + } } - Array reads = UpdateBufferRegions(block->reads); - Array writes = UpdateBufferRegions(block->writes); - // Initialize the buffer region with empty region. + // Step 2. Initialize the buffer region with empty region for (const Buffer& buffer : block->alloc_buffers) { - buffers_region_[buffer] = NDIntSet(buffer->shape.size(), arith::IntSet::Nothing()); + buffers_region_.emplace(buffer.get(), + NDIntSet(buffer->shape.size(), arith::IntSet::Nothing())); } - ObjectPtr new_block = make_object(*block); - new_block->reads = std::move(reads); - new_block->writes = std::move(writes); - new_block->body = this->VisitStmt(new_block->body); - return BlockRealize(/*values=*/realize->binding_values, // - /*predicate=*/this->VisitExpr(realize->predicate), // - /*block=*/Block(std::move(new_block))); + // Step 3. Visit recursively + ++block_nest_depth_; + Stmt body = this->VisitStmt(block->body); + --block_nest_depth_; + // Step 4. Update the read/write buffer regions + Array reads = VisitBufferRegions(block->reads); + Array writes = VisitBufferRegions(block->writes); + // Step 5. Handle predicate + PrimExpr predicate = this->VisitExpr(realize->predicate); + // Step 6. Root allocation + Array alloc_buffers = (block_nest_depth_ == 0) ? AllocBufers(nullptr) : Array{}; + // Step 7. Create new blocks + return BlockRealize(/*binding_values=*/{}, + /*predicate=*/std::move(predicate), + /*block=*/ + Block(/*iter_vars=*/{}, // + /*reads=*/std::move(reads), // + /*writes=*/std::move(writes), // + /*alloc_buffers=*/std::move(alloc_buffers), // + /*annotations=*/block->annotations, // + /*match_buffers=*/block->match_buffers, // + /*exec_scope=*/block->exec_scope, // + /*name_hint=*/block->name_hint, // + /*body=*/std::move(body), // + /*init=*/NullOpt)); + } + + PrimExpr VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::get_elem_offset())) { + // Handle `get_elem_offset` + ICHECK_EQ(op->args.size(), 1); + PrimExpr arg = op->args[0]; + ICHECK(arg->IsInstance()); + arg = this->VisitExpr(arg); + const auto* load = TVM_TYPE_AS(load, arg, LoadNode); + return load->index; + } + return StmtExprMutator::VisitExpr_(op); } PrimExpr VisitExpr_(const VarNode* var) final { @@ -290,7 +463,7 @@ class RegionGatherer : public StmtExprMutator { return GetRef(var); } - Array UpdateBufferRegions(const Array& buffer_regions) { + Array VisitBufferRegions(const Array& buffer_regions) { Array result; result.reserve(buffer_regions.size()); for (const BufferRegion& buffer_region : buffer_regions) { @@ -307,7 +480,9 @@ class RegionGatherer : public StmtExprMutator { region.push_back(new_range); int_set.push_back(arith::EvalSet(new_range, dom_map)); } - NDIntSet& alloc_region = buffers_region_.at(buffer); + auto it = buffers_region_.find(buffer.get()); + ICHECK(it != buffers_region_.end()); + NDIntSet& alloc_region = it->second; UnionWith(&alloc_region, int_set); result.push_back(BufferRegion(buffer_region->buffer, region)); } @@ -315,25 +490,30 @@ class RegionGatherer : public StmtExprMutator { } VarDomain LoopVarDomain(const Buffer& buffer) const { - VarDomain dom_map; - const Optional& lca = this->buffers_lca_.at(buffer); + auto it = this->buffers_lca_.find(buffer.get()); + ICHECK(it != this->buffers_lca_.end()); + const ForNode* lca = it->second; // Every loop will be relaxed if the lca is the root - bool need_relax = !lca.defined(); + VarDomain dom_map; + bool need_relax = (lca == nullptr); for (const ForNode* loop : this->ancestor_loops_) { const VarNode* loop_var = loop->loop_var.get(); // TODO - if (need_relax || (buffer->scope == "shared" && IsThreadBinded(loop))) { + if (need_relax || (buffer->scope == "shared" && IsThreadBound(loop))) { dom_map[loop_var] = IntSetFromMinExtent(loop->min, loop->extent); } - if (loop == lca.get()) { + if (loop == lca) { need_relax = true; } } return dom_map; } + int block_nest_depth_ = 0; /*! \brief The map from Buffer to its LCA Stmt/Expr */ - const Map>& buffers_lca_; + std::unordered_map buffers_lca_; + std::unordered_map> buffer_alloc_; + std::unordered_set reduction_loop_vars_; /*! \brief The loops from the current node up to the root */ std::vector ancestor_loops_; /*! \brief The map from block vars to the expr value */ @@ -344,18 +524,6 @@ class RegionGatherer : public StmtExprMutator { * \brief Transform multi-dimension BufferLoad/BufferStore into one-dimension Load/Store */ class BufferFlattener : public StmtExprMutator { - public: - explicit BufferFlattener( - const std::unordered_map& buffers_region, - const Map>& buffers_lca, const PrimFunc& func) - : buffers_region_(buffers_region), buffers_lca_(buffers_lca), arg_buffers_{} { - arg_buffers_.reserve(func->buffer_map.size()); - for (const auto& kv : func->buffer_map) { - const Buffer& buffer = kv.second; - arg_buffers_.insert(buffer.get()); - } - } - private: Stmt VisitStmt_(const SeqStmtNode* op) final { Array seq; @@ -367,12 +535,13 @@ class BufferFlattener : public StmtExprMutator { std::swap(double_buffer, double_buffer_); const ForNode* loop = ancestor_loops_.back(); for (const BufferNode* buffer : double_buffer) { - const Object* lca = buffers_lca_.at(GetRef(buffer)).get(); - if (lca != nullptr && loop == lca) { - body = AttrStmt(buffer->data, attr::double_buffer_scope, 1, body); - } else { - double_buffer_.insert(buffer); - } + // TODO + // const Object* lca = buffers_lca_.at(GetRef(buffer)).get(); + // if (lca != nullptr && loop == lca) { + // body = AttrStmt(buffer->data, attr::double_buffer_scope, 1, body); + // } else { + // double_buffer_.insert(buffer); + // } } seq.push_back(body); } @@ -380,33 +549,9 @@ class BufferFlattener : public StmtExprMutator { } Stmt VisitStmt_(const BlockRealizeNode* realize) final { - // Handle allocations - const auto* block = realize->block.get(); - // Step 1. Add non-root block allocations into `pending_allocate_` - for (const Buffer& buffer : block->alloc_buffers) { - if (IsReduceTempBuffer(buffer)) { - continue; - } - if (buffers_lca_.at(buffer).defined()) { - pending_allocate_.insert(buffer.get()); - } - } - // Step 2. Add reduction loop vars - CHECK_EQ(block->iter_vars.size(), realize->binding_values.size()); - int n_block_vars = block->iter_vars.size(); - for (int i = 0; i < n_block_vars; ++i) { - const IterVar& block_var = block->iter_vars[i]; - const PrimExpr& binding_value = realize->binding_values[i]; - if (block_var->iter_type == kCommReduce) { - std::unordered_set vars = Vars(binding_value); - for (const VarNode* var : vars) { - this->reduction_loop_vars_.insert(var); - } - } - } // Step 3. Visit the body Block new_block = Downcast(this->VisitStmt(realize->block)); - block = new_block.get(); + const BlockNode* block = new_block.get(); // Step 4. Transform the `predicate` to if-then-else Stmt body = block->body; if (!is_one(realize->predicate)) { @@ -424,48 +569,26 @@ class BufferFlattener : public StmtExprMutator { } } } - // Step 6. Add root block allocations + // Step 6. Handle allocations for (const Buffer& buffer : block->alloc_buffers) { - if (IsReduceTempBuffer(buffer)) { - continue; - } - if (!buffers_lca_.at(buffer).defined()) { - const NDIntSet& region = buffers_region_.at(buffer); - body = MakeAllocStmt(buffer, NDIntSetArea(region), body); - } + body = MakeAllocStmt(buffer, BufferArea(buffer), body); } return body; } Stmt VisitStmt_(const ForNode* op) final { - // Step 1. Find the buffer that can be allocated under the current loop - std::vector alloc_buffers; - for (const BufferNode* buffer : pending_allocate_) { - const Optional alloc_site = buffers_lca_.at(GetRef(buffer)); - if (op == alloc_site.get()) { - alloc_buffers.push_back(buffer); - } - } // Step 2. Visit recursively ancestor_loops_.push_back(op); Stmt body = this->VisitStmt(op->body); PrimExpr min = this->VisitExpr(op->min); PrimExpr extent = this->VisitExpr(op->extent); ancestor_loops_.pop_back(); - // Step 3. Add buffer allocation - for (const BufferNode* buffer : alloc_buffers) { - const NDIntSet& region = buffers_region_.at(GetRef(buffer)); - body = MakeAllocStmt(GetRef(buffer), NDIntSetArea(region), body); - pending_allocate_.erase(buffer); - } // Step 4. Add the for loop accordingly if (op->kind == ForKind::kThreadBinding) { // Case 1. Thread binding - if (!reduction_loop_vars_.count(op->loop_var.get())) { - ICHECK(op->thread_binding.defined()); - String thread_tag = op->thread_binding.value()->thread_tag; - body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body); - } + ICHECK(op->thread_binding.defined()); + String thread_tag = op->thread_binding.value()->thread_tag; + body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body); } else if (is_one(extent) && op->annotations.empty()) { // Case 2. Handle unit loop return body; @@ -484,83 +607,6 @@ class BufferFlattener : public StmtExprMutator { return body; } - Stmt VisitStmt_(const BufferStoreNode* op) final { - const Buffer& buffer = op->buffer; - std::vector indices = VisitIndices(op->indices); - PrimExpr value = this->VisitExpr(op->value); - std::vector begins = ComputeRelativeIndices(buffer, indices); - Buffer new_buffer = ReshapeBuffer(buffer, this->buffers_region_.at(buffer)); - return new_buffer.vstore(begins, value); - } - - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - const Buffer& buffer = op->buffer; - std::vector indices = VisitIndices(op->indices); - std::vector begins = ComputeRelativeIndices(buffer, indices); - Buffer new_buffer = ReshapeBuffer(buffer, this->buffers_region_.at(buffer)); - return new_buffer.vload(begins, op->dtype); - } - - PrimExpr VisitExpr_(const CallNode* op) final { - if (op->op.same_as(builtin::get_elem_offset())) { - // Handle `get_elem_offset` - ICHECK_EQ(op->args.size(), 1); - const PrimExpr& arg = op->args[0]; - ICHECK(arg->IsInstance()); - Load load = Downcast(VisitExpr(arg)); - return load->index; - } - return StmtExprMutator::VisitExpr_(op); - } - - /*! - * \brief Create a buffer with alternative shape - */ - Buffer ReshapeBuffer(const Buffer& buffer, const NDIntSet& region) { - if (arg_buffers_.count(buffer.get())) { - return buffer; - } - Array shape; - for (const arith::IntSet& i : region) { - shape.push_back(i.max() - i.min() + 1); - } - ObjectPtr n = make_object(*buffer.get()); - n->shape = std::move(shape); - return Buffer(std::move(n)); - } - - /*! - * \brief Transform indices from the absolute indices to relative indices - * \note T can be BufferLoad or BufferStore - */ - std::vector ComputeRelativeIndices(const Buffer& buffer, - const Array& indices) { - const NDIntSet& region = buffers_region_.at(buffer); - std::vector new_indices; - for (size_t i = 0; i < region.size(); ++i) { - if (arg_buffers_.count(buffer.get())) { - new_indices.push_back(indices[i]); - } else { - new_indices.push_back(indices[i] - region[i].min()); - } - } - return new_indices; - } - - std::vector VisitIndices(const Array& indices) { - std::vector result; - result.reserve(indices.size()); - for (const PrimExpr& index : indices) { - result.push_back(this->VisitExpr(index)); - } - return result; - } - - const std::unordered_map& buffers_region_; - const Map>& buffers_lca_; - std::unordered_set arg_buffers_; - std::unordered_set pending_allocate_; - std::unordered_set reduction_loop_vars_; std::unordered_set double_buffer_; std::vector ancestor_loops_; }; @@ -574,17 +620,10 @@ PrimFunc BufferFlatten(PrimFunc f) { fptr->body = reduction_transformer(fptr->body); // Step 2. Recalculate the buffer region Map> buffer_lca = LCADetector::Detect(f); - RegionGatherer region_gatherer(buffer_lca, f); - fptr->body = region_gatherer(fptr->body); - MapNode* buf = buffer_lca.CopyOnWrite(); - for (auto& kv : *buf) { - const ForNode* loop = static_cast(kv.second.get()); - if (loop != nullptr && region_gatherer.loop_mapping.count(loop)) { - kv.second = GetRef(region_gatherer.loop_mapping.at(loop)); - } - } + RegionGatherer region_gatherer; + fptr->body = region_gatherer.Gather(buffer_lca, f); // Step 3. Transform BufferLoad/BufferStore into Load/Store - BufferFlattener flattener(region_gatherer.buffers_region_, buffer_lca, f); + BufferFlattener flattener; fptr->body = flattener(fptr->body); return f; } From 7d907d28111f236f7bc98fc2b5559903717663d8 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 15 Mar 2021 21:48:01 +0000 Subject: [PATCH 14/20] ... --- src/tir/transforms/buffer_flatten.cc | 327 ++++++++++++++------------- 1 file changed, 169 insertions(+), 158 deletions(-) diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index 557a737cc1..f7d0a52b23 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -43,7 +43,7 @@ arith::IntSet IntSetFromMinExtent(const PrimExpr& min, const PrimExpr& extent) { return arith::IntSet::FromRange(Range::FromMinExtent(min, extent)); } -void UnionWith(NDIntSet* lhs, const NDIntSet& rhs) { +void NDIntSetUnionWith(NDIntSet* lhs, const NDIntSet& rhs) { ICHECK_EQ(lhs->size(), rhs.size()); int ndim = rhs.size(); for (int i = 0; i < ndim; ++i) { @@ -52,25 +52,16 @@ void UnionWith(NDIntSet* lhs, const NDIntSet& rhs) { } } -PrimExpr NDIntSetArea(const NDIntSet& nd_int_set) { - PrimExpr area = 1; - for (const arith::IntSet& int_set : nd_int_set) { - area = area * (int_set.max() - int_set.min() + 1); - } - return area; -} - -Buffer NDIntSet2Buffer(const BufferNode* buffer, const NDIntSet& nd_int_set) { +Array NDIntSet2Region(const NDIntSet& nd_int_set) { Integer one(1); - Array shape; - shape.reserve(nd_int_set.size()); + Array result; + result.reserve(nd_int_set.size()); for (const arith::IntSet& int_set : nd_int_set) { - PrimExpr extent = int_set.max() - int_set.min() + one; - shape.push_back(extent); + PrimExpr min = int_set.min(); + PrimExpr max = int_set.max(); + result.push_back(Range(/*begin=*/min, /*end=*/max + one)); } - ObjectPtr new_buffer = make_object(*buffer); - new_buffer->shape = std::move(shape); - return Buffer(std::move(new_buffer)); + return result; } NDIntSet NDIntSetFromShape(const Array& shape) { @@ -81,7 +72,11 @@ NDIntSet NDIntSetFromShape(const Array& shape) { return result; } -bool IsThreadBound(const ForNode* loop) { +NDIntSet NDIntSetEmpty(int ndim) { + return std::vector(ndim, arith::IntSet::Nothing()); +} + +bool IsThreadBound(const For& loop) { if (loop->kind != ForKind::kThreadBinding) { return false; } @@ -160,12 +155,12 @@ class LCADetector : public StmtExprVisitor { // Buffers, who appear as arguments, do not have allocation sites for (const auto& kv : func->buffer_map) { const Buffer& buffer = kv.second; - detector.buffers_lca_.emplace(buffer.get(), nullptr); + detector.buffer_lca_.emplace(buffer.get(), nullptr); } detector(func->body); // Prepare the return Map> buffer_lca; - for (const auto& kv : detector.buffers_lca_) { + for (const auto& kv : detector.buffer_lca_) { buffer_lca.Set(GetRef(kv.first), GetRef>(kv.second)); } return buffer_lca; @@ -191,7 +186,7 @@ class LCADetector : public StmtExprVisitor { } void CalcBufferLCA(const BufferNode* buffer) { - const ForNode*& lca = buffers_lca_[buffer]; + const ForNode*& lca = buffer_lca_[buffer]; lca = LowestCommonAncestor(lca, ancestor_loops_.back()); } @@ -234,122 +229,103 @@ class LCADetector : public StmtExprVisitor { /*! \brief The parent and depth info of each Loop/BufferLoad/BufferStore Node */ std::unordered_map for_info_ = {}; /*! \brief The map from Buffer to its LCA Stmt/Expr */ - std::unordered_map buffers_lca_ = {}; + std::unordered_map buffer_lca_ = {}; }; -class BufferAccessUpdater : public StmtExprMutator { +class BufferAccessRewriter : public StmtExprMutator { public: - static Stmt Update( - const std::unordered_map>& buffers_offsets, - const std::unordered_map& buffer_allocated, Stmt body) { - BufferAccessUpdater updater(buffers_offsets, buffer_allocated); - return updater.VisitStmt(body); + using FRewriteBufferAccess = std::function>( + const Buffer& buffer, const Array& indices)>; + + static Stmt Rewrite(Stmt stmt, const FRewriteBufferAccess& f_rewrite) { + BufferAccessRewriter rewriter(f_rewrite); + return rewriter.VisitStmt(stmt); } private: - explicit BufferAccessUpdater( - const std::unordered_map>& buffers_offsets, - const std::unordered_map& buffer_allocated) - : buffers_offsets_(buffers_offsets), buffer_allocated_(buffer_allocated) {} + explicit BufferAccessRewriter(const FRewriteBufferAccess& f_rewrite) : f_rewrite_(f_rewrite) {} Stmt VisitStmt_(const BufferStoreNode* op) final { BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); op = store.get(); - const BufferNode* old_buffer = op->buffer.get(); - const BufferNode* new_buffer = FindNewBuffer(old_buffer); - Array begins = ComputeRelativeIndices(old_buffer, op->indices); - return GetRef(new_buffer).vstore(begins, op->value); + Buffer new_buffer{nullptr}; + Array new_indices{nullptr}; + std::tie(new_buffer, new_indices) = f_rewrite_(op->buffer, op->indices); + return new_buffer.vstore(new_indices, op->value); } PrimExpr VisitExpr_(const BufferLoadNode* op) final { BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); op = load.get(); - const BufferNode* old_buffer = op->buffer.get(); - const BufferNode* new_buffer = FindNewBuffer(old_buffer); - Array begins = ComputeRelativeIndices(old_buffer, op->indices); - return GetRef(new_buffer).vload(begins, op->dtype); - } - - const BufferNode* FindNewBuffer(const BufferNode* buffer) const { - auto it = buffer_allocated_.find(buffer); - ICHECK(it != buffer_allocated_.end()); - return it->second; - } - - Array ComputeRelativeIndices(const BufferNode* buffer, const Array& indices) { - auto it = buffers_offsets_.find(buffer); - ICHECK(it != buffers_offsets_.end()); - const std::vector& offsets = it->second; - ICHECK_EQ(offsets.size(), indices.size()); - int ndim = offsets.size(); - Array result; - result.reserve(ndim); - for (int i = 0; i < ndim; ++i) { - result.push_back(indices[i] - offsets[i]); - } - return result; + Buffer new_buffer{nullptr}; + Array new_indices{nullptr}; + std::tie(new_buffer, new_indices) = f_rewrite_(op->buffer, op->indices); + return new_buffer.vload(new_indices, op->dtype); } - const std::unordered_map>& buffers_offsets_; - const std::unordered_map& buffer_allocated_; + const FRewriteBufferAccess& f_rewrite_; }; /*! * \brief Gather the used region of each buffers. */ class RegionGatherer : public StmtExprMutator { - using VarDomain = std::unordered_map; + template + using SMap = std::unordered_map; + template + using SSet = std::unordered_set; + + struct BufferInfo { + NDIntSet accessed_region; + Optional alloc_site; + Array region; + Buffer new_buffer; + + explicit BufferInfo(int ndim, Optional alloc_site) + : accessed_region(NDIntSetEmpty(ndim)), // + alloc_site(std::move(alloc_site)), + region{nullptr}, + new_buffer{nullptr} {} + }; public: - Stmt Gather(const Map>& buffers_lca, const PrimFunc& f) { - for (const auto& kv : buffers_lca) { - const BufferNode* buffer = kv.first.get(); - const ForNode* loop = static_cast(kv.second.get()); - this->buffers_lca_.emplace(buffer, loop); - this->buffer_alloc_[loop].push_back(buffer); - } - for (const auto& arg : f->buffer_map) { - const BufferNode* buffer = arg.second.get(); + static Stmt Gather(const PrimFunc& f) { + Map> buffer_lca = LCADetector::Detect(f); + SMap buffer_info; + SMap, Array> loop_allocs; + buffer_info.reserve(buffer_lca.size()); + loop_allocs.reserve(buffer_lca.size()); + for (const auto& kv : buffer_lca) { + const Buffer& buffer = kv.first; + const Optional& alloc_site = kv.second; int ndim = buffer->shape.size(); - buffers_region_.emplace(buffer, NDIntSetFromShape(buffer->shape)); - buffer_allocated_.emplace(buffer, buffer); - buffer_offsets_.emplace(buffer, std::vector(ndim, Integer(0))); + buffer_info.emplace(buffer, BufferInfo(ndim, alloc_site)); + loop_allocs[alloc_site].push_back(buffer); } - return BufferAccessUpdater::Update(buffer_offsets_, buffer_allocated_, - this->VisitStmt(f->body)); + for (const auto& kv : f->buffer_map) { + const Buffer& buffer = kv.second; + ICHECK(buffer_info.count(buffer)); + BufferInfo& info = buffer_info.at(buffer); + info.accessed_region = NDIntSetFromShape(buffer->shape); + } + RegionGatherer gatherer(std::move(buffer_info), std::move(loop_allocs)); + return BufferAccessRewriter::Rewrite( + /*stmt=*/gatherer.VisitStmt(f->body), + /*f_rewrite=*/std::bind(&RegionGatherer::RewriteBufferAccess, // + &gatherer, // + std::placeholders::_1, // + std::placeholders::_2)); } - public: - /*! \brief The used region of each Buffer */ - std::unordered_map buffers_region_; - std::unordered_map buffer_allocated_; - std::unordered_map> buffer_offsets_; - private: - Array AllocBufers(const ForNode* loop) { - auto it = buffer_alloc_.find(loop); - if (it == buffer_alloc_.end()) { - return {}; - } - const std::vector& buffers = it->second; - Array result; - result.reserve(buffers.size()); - for (const BufferNode* buffer : buffers) { - auto it = buffers_region_.find(buffer); - ICHECK(it != buffers_region_.end()); - const NDIntSet& nd_int_set = it->second; - Buffer allocated = NDIntSet2Buffer(buffer, nd_int_set); - buffer_allocated_.emplace(buffer, allocated.get()); - result.push_back(allocated); - std::vector offsets; - offsets.reserve(nd_int_set.size()); - for (const arith::IntSet& int_set : nd_int_set) { - offsets.push_back(int_set.min()); - } - buffer_offsets_.emplace(buffer, std::move(offsets)); - } - return result; - } + explicit RegionGatherer(SMap buffer_info, + SMap, Array> loop_allocs) + : block_nest_depth_(0), + buffer_info_(std::move(buffer_info)), + loop_allocs_(std::move(loop_allocs)), + ancestor_loops_{}, + var_substitutes_{}, + reduction_loop_vars_{} {} Stmt VisitStmt_(const ForNode* loop) final { // Step 1. Handle block vars in `min` and `extent` @@ -357,14 +333,14 @@ class RegionGatherer : public StmtExprMutator { PrimExpr extent = this->VisitExpr(loop->extent); // Step 2. Handle unit loops if (is_one(extent)) { - var_substitutes_[loop->loop_var.get()] = min; + var_substitutes_[loop->loop_var] = min; } // Step 3. Visit recursively - ancestor_loops_.push_back(loop); + ancestor_loops_.push_back(GetRef(loop)); Stmt body = this->VisitStmt(loop->body); ancestor_loops_.pop_back(); // Step 4. Add allocation - Array alloc_buffers = AllocBufers(loop); + Array alloc_buffers = AllocBufferUnderLoop(GetRef(loop)); if (!alloc_buffers.empty()) { body = BlockRealize(/*binding_values=*/{}, /*predicate=*/const_true(), @@ -381,7 +357,7 @@ class RegionGatherer : public StmtExprMutator { /*init=*/NullOpt)); } // Step 5. Make the new loop - if (loop->kind == ForKind::kThreadBinding && reduction_loop_vars_.count(loop->loop_var.get())) { + if (loop->kind == ForKind::kThreadBinding && reduction_loop_vars_.count(loop->loop_var)) { // do nothing, because the loop is going to be removed } else { body = For(/*loop_var=*/loop->loop_var, @@ -401,32 +377,28 @@ class RegionGatherer : public StmtExprMutator { // Step 1. Update "block vars => loop vars" for substitution, add reduction loop vars ICHECK_EQ(block->iter_vars.size(), realize->binding_values.size()); for (int i = 0, n = block->iter_vars.size(); i < n; ++i) { - const IterVar& block_var = block->iter_vars[i]; + IterVar block_var = block->iter_vars[i]; PrimExpr v = this->VisitExpr(realize->binding_values[i]); - var_substitutes_.emplace(block_var->var.get(), v); + var_substitutes_.emplace(block_var->var, v); if (block_var->iter_type == kCommReduce) { for (const VarNode* var : Vars(v)) { - this->reduction_loop_vars_.insert(var); + this->reduction_loop_vars_.insert(GetRef(var)); } } } - // Step 2. Initialize the buffer region with empty region - for (const Buffer& buffer : block->alloc_buffers) { - buffers_region_.emplace(buffer.get(), - NDIntSet(buffer->shape.size(), arith::IntSet::Nothing())); - } - // Step 3. Visit recursively + // Step 2. Visit recursively ++block_nest_depth_; Stmt body = this->VisitStmt(block->body); --block_nest_depth_; - // Step 4. Update the read/write buffer regions + // Step 3. Update the read/write buffer regions Array reads = VisitBufferRegions(block->reads); Array writes = VisitBufferRegions(block->writes); - // Step 5. Handle predicate + // Step 4. Handle predicate PrimExpr predicate = this->VisitExpr(realize->predicate); - // Step 6. Root allocation - Array alloc_buffers = (block_nest_depth_ == 0) ? AllocBufers(nullptr) : Array{}; - // Step 7. Create new blocks + // Step 5. Root allocation + Array alloc_buffers = + (block_nest_depth_ == 0) ? AllocBufferUnderLoop(NullOpt) : Array{}; + // Step 6. Create new blocks return BlockRealize(/*binding_values=*/{}, /*predicate=*/std::move(predicate), /*block=*/ @@ -456,7 +428,7 @@ class RegionGatherer : public StmtExprMutator { } PrimExpr VisitExpr_(const VarNode* var) final { - auto it = var_substitutes_.find(var); + auto it = var_substitutes_.find(GetRef(var)); if (it != var_substitutes_.end()) { return it->second; } @@ -468,7 +440,24 @@ class RegionGatherer : public StmtExprMutator { result.reserve(buffer_regions.size()); for (const BufferRegion& buffer_region : buffer_regions) { const Buffer& buffer = buffer_region->buffer; - VarDomain dom_map = LoopVarDomain(buffer); + ICHECK(buffer_info_.count(buffer)); + BufferInfo& info = buffer_info_.at(buffer); + std::unordered_map dom_map; + { + const Object* lca = info.alloc_site.get(); + // Every loop will be relaxed if the lca is the root + bool need_relax = (lca == nullptr); + for (const For& loop : this->ancestor_loops_) { + const VarNode* loop_var = loop->loop_var.get(); + // TODO + if (need_relax || (buffer->scope == "shared" && IsThreadBound(loop))) { + dom_map[loop_var] = IntSetFromMinExtent(loop->min, loop->extent); + } + if (loop.get() == lca) { + need_relax = true; + } + } + } int ndim = buffer_region->region.size(); Array region; NDIntSet int_set; @@ -480,44 +469,68 @@ class RegionGatherer : public StmtExprMutator { region.push_back(new_range); int_set.push_back(arith::EvalSet(new_range, dom_map)); } - auto it = buffers_region_.find(buffer.get()); - ICHECK(it != buffers_region_.end()); - NDIntSet& alloc_region = it->second; - UnionWith(&alloc_region, int_set); + NDIntSetUnionWith(&info.accessed_region, int_set); result.push_back(BufferRegion(buffer_region->buffer, region)); } return result; } - VarDomain LoopVarDomain(const Buffer& buffer) const { - auto it = this->buffers_lca_.find(buffer.get()); - ICHECK(it != this->buffers_lca_.end()); - const ForNode* lca = it->second; - // Every loop will be relaxed if the lca is the root - VarDomain dom_map; - bool need_relax = (lca == nullptr); - for (const ForNode* loop : this->ancestor_loops_) { - const VarNode* loop_var = loop->loop_var.get(); - // TODO - if (need_relax || (buffer->scope == "shared" && IsThreadBound(loop))) { - dom_map[loop_var] = IntSetFromMinExtent(loop->min, loop->extent); - } - if (loop == lca) { - need_relax = true; + Array AllocBufferUnderLoop(const Optional& loop) { + auto it = loop_allocs_.find(loop); + if (it == loop_allocs_.end()) { + return {}; + } + const Array& buffers = it->second; + Array result; + result.reserve(buffers.size()); + for (const Buffer& buffer : buffers) { + ICHECK(buffer_info_.count(buffer)); + BufferInfo& info = buffer_info_.at(buffer); + ICHECK(!info.region.defined()); + ICHECK(!info.new_buffer.defined()); + // Calculate `info.region` + info.region = NDIntSet2Region(info.accessed_region); + // Calculate `info.new_buffer` + Array shape; + shape.reserve(info.region.size()); + for (const Range& range : info.region) { + shape.push_back(range->extent); } + ObjectPtr new_buffer = make_object(*buffer.get()); + new_buffer->shape = std::move(shape); + info.new_buffer = Buffer(std::move(new_buffer)); + } + return result; + } + + std::pair> RewriteBufferAccess(const Buffer& buffer, + const Array& indices) const { + ICHECK(buffer_info_.count(buffer)); + const BufferInfo& info = buffer_info_.at(buffer); + ICHECK(info.new_buffer.defined()); + ICHECK(info.region.defined()); + ICHECK_EQ(indices.size(), info.region.size()); + int ndim = indices.size(); + Array new_indices; + new_indices.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + new_indices.push_back(indices[i] - info.region[i]->min); } - return dom_map; + return std::make_pair(info.new_buffer, std::move(new_indices)); } - int block_nest_depth_ = 0; - /*! \brief The map from Buffer to its LCA Stmt/Expr */ - std::unordered_map buffers_lca_; - std::unordered_map> buffer_alloc_; - std::unordered_set reduction_loop_vars_; + /*! \brief Number of blocks nested in the ancestor during visiting */ + int block_nest_depth_; + /*! \brief Collective information about each buffer */ + SMap buffer_info_; + /*! \brief Buffers allocated at each for loop */ + SMap, Array> loop_allocs_; /*! \brief The loops from the current node up to the root */ - std::vector ancestor_loops_; + std::vector ancestor_loops_; /*! \brief The map from block vars to the expr value */ - std::unordered_map var_substitutes_; + SMap var_substitutes_; + /*! \brief Loop variables that are bound to reduction block vars */ + SSet reduction_loop_vars_; }; /*! @@ -536,7 +549,7 @@ class BufferFlattener : public StmtExprMutator { const ForNode* loop = ancestor_loops_.back(); for (const BufferNode* buffer : double_buffer) { // TODO - // const Object* lca = buffers_lca_.at(GetRef(buffer)).get(); + // const Object* lca = buffer_lca_.at(GetRef(buffer)).get(); // if (lca != nullptr && loop == lca) { // body = AttrStmt(buffer->data, attr::double_buffer_scope, 1, body); // } else { @@ -619,9 +632,7 @@ PrimFunc BufferFlatten(PrimFunc f) { ReductionTransformer reduction_transformer; fptr->body = reduction_transformer(fptr->body); // Step 2. Recalculate the buffer region - Map> buffer_lca = LCADetector::Detect(f); - RegionGatherer region_gatherer; - fptr->body = region_gatherer.Gather(buffer_lca, f); + RegionGatherer::Gather(f); // Step 3. Transform BufferLoad/BufferStore into Load/Store BufferFlattener flattener; fptr->body = flattener(fptr->body); From ca8d6e2490820725f455ccae653cc236e0903b76 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 15 Mar 2021 21:57:49 +0000 Subject: [PATCH 15/20] ... --- src/tir/transforms/buffer_flatten.cc | 38 ++++++++++++++++------------ 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index f7d0a52b23..b70ea15f8f 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -436,43 +436,49 @@ class RegionGatherer : public StmtExprMutator { } Array VisitBufferRegions(const Array& buffer_regions) { - Array result; - result.reserve(buffer_regions.size()); + // Calculate `new_buffer_regions` by recursively visiting min/extent of each range + Array new_buffer_regions; + new_buffer_regions.reserve(buffer_regions.size()); for (const BufferRegion& buffer_region : buffer_regions) { + const Buffer& buffer = buffer_region->buffer; + const Array& region = buffer_region->region; + Array new_region; + new_region.reserve(region.size()); + for (const Range& range : region) { + new_region.push_back(Range::FromMinExtent(/*min=*/this->VisitExpr(range->min), + /*extent=*/this->VisitExpr(range->extent))); + } + new_buffer_regions.push_back(BufferRegion(buffer, new_region)); + } + // Calculate `info.accessed_region` + for (const BufferRegion& buffer_region : new_buffer_regions) { const Buffer& buffer = buffer_region->buffer; ICHECK(buffer_info_.count(buffer)); BufferInfo& info = buffer_info_.at(buffer); std::unordered_map dom_map; { - const Object* lca = info.alloc_site.get(); + const Object* alloc_site = info.alloc_site.get(); // Every loop will be relaxed if the lca is the root - bool need_relax = (lca == nullptr); + bool need_relax = (alloc_site == nullptr); for (const For& loop : this->ancestor_loops_) { const VarNode* loop_var = loop->loop_var.get(); - // TODO if (need_relax || (buffer->scope == "shared" && IsThreadBound(loop))) { + // TODO dom_map[loop_var] = IntSetFromMinExtent(loop->min, loop->extent); } - if (loop.get() == lca) { + if (loop.get() == alloc_site) { need_relax = true; } } } - int ndim = buffer_region->region.size(); - Array region; NDIntSet int_set; - region.reserve(ndim); - int_set.reserve(ndim); + int_set.reserve(buffer_region->region.size()); for (const Range& range : buffer_region->region) { - Range new_range = - Range::FromMinExtent(this->VisitExpr(range->min), this->VisitExpr(range->extent)); - region.push_back(new_range); - int_set.push_back(arith::EvalSet(new_range, dom_map)); + int_set.push_back(arith::EvalSet(range, dom_map)); } NDIntSetUnionWith(&info.accessed_region, int_set); - result.push_back(BufferRegion(buffer_region->buffer, region)); } - return result; + return new_buffer_regions; } Array AllocBufferUnderLoop(const Optional& loop) { From de0f1011ecd4693817b6ca7df9eba74f5331627b Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 15 Mar 2021 23:34:21 +0000 Subject: [PATCH 16/20] add double buffering --- src/tir/transforms/buffer_flatten.cc | 141 ++++++++++++--------------- 1 file changed, 63 insertions(+), 78 deletions(-) diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index b70ea15f8f..ece6e295b6 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -38,6 +38,10 @@ namespace tvm { namespace tir { using NDIntSet = std::vector; +template +using SMap = std::unordered_map; +template +using SSet = std::unordered_set; arith::IntSet IntSetFromMinExtent(const PrimExpr& min, const PrimExpr& extent) { return arith::IntSet::FromRange(Range::FromMinExtent(min, extent)); @@ -96,31 +100,6 @@ bool IsReduceTempBuffer(const Buffer& buffer) { StartsWith(buffer->name, "reduce_temp"); } -String NormalizeStorageScope(const String& s) { - if (s.empty()) { - return "global"; - } - return s; -} - -Stmt MakeAllocStmt(const Buffer& buffer, const PrimExpr& area, Stmt body) { - body = Allocate(buffer->data, buffer->dtype, {area}, const_true(), body); - body = AttrStmt(buffer->data, attr::storage_scope, - StringImm(NormalizeStorageScope(buffer->scope)), body); - return body; -} - -Stmt MakeLaunchThread(const PrimExpr& min, const PrimExpr& extent, const Var& var, - const String& thread_tag, Stmt body) { - IterVar iter_var(/*dom=*/Range::FromMinExtent(min, extent), - /*var=*/var, - /*iter_type=*/IterVarType::kThreadIndex, - /*thread_tag=*/thread_tag); - String attr_key = thread_tag == "vthread" ? attr::virtual_thread : attr::thread_extent; - body = AttrStmt(iter_var, attr_key, extent, body); - return body; -} - PrimExpr BufferArea(const Buffer& buffer) { PrimExpr area = Integer(1); for (const PrimExpr& dim : buffer->shape) { @@ -270,11 +249,6 @@ class BufferAccessRewriter : public StmtExprMutator { * \brief Gather the used region of each buffers. */ class RegionGatherer : public StmtExprMutator { - template - using SMap = std::unordered_map; - template - using SSet = std::unordered_set; - struct BufferInfo { NDIntSet accessed_region; Optional alloc_site; @@ -544,65 +518,36 @@ class RegionGatherer : public StmtExprMutator { */ class BufferFlattener : public StmtExprMutator { private: - Stmt VisitStmt_(const SeqStmtNode* op) final { - Array seq; - seq.reserve(op->seq.size()); - for (const Stmt& stmt : op->seq) { - std::unordered_set double_buffer; - std::swap(double_buffer, double_buffer_); - Stmt body = VisitStmt(stmt); - std::swap(double_buffer, double_buffer_); - const ForNode* loop = ancestor_loops_.back(); - for (const BufferNode* buffer : double_buffer) { - // TODO - // const Object* lca = buffer_lca_.at(GetRef(buffer)).get(); - // if (lca != nullptr && loop == lca) { - // body = AttrStmt(buffer->data, attr::double_buffer_scope, 1, body); - // } else { - // double_buffer_.insert(buffer); - // } - } - seq.push_back(body); - } - return SeqStmt(seq); - } - Stmt VisitStmt_(const BlockRealizeNode* realize) final { - // Step 3. Visit the body + ICHECK(realize->binding_values.empty()); + // Step 1. Visit the body Block new_block = Downcast(this->VisitStmt(realize->block)); + PrimExpr predicate = this->VisitExpr(realize->predicate); const BlockNode* block = new_block.get(); - // Step 4. Transform the `predicate` to if-then-else + // Step 2. Transform the `predicate` to if-then-else Stmt body = block->body; - if (!is_one(realize->predicate)) { - body = IfThenElse(realize->predicate, body); + if (!is_one(predicate)) { + body = IfThenElse(predicate, body); } - // Step 5. Pick out blocks that writes with double buffering - for (const auto& ann : block->annotations) { - const String& ann_key = ann.first; - const ObjectRef& ann_value = ann.second; - if (ann_key == attr::double_buffer_scope) { - if (is_one(Downcast(ann_value))) { - ICHECK_EQ(block->writes.size(), 1); - const BufferRegion& write = block->writes[0]; - double_buffer_.insert(write->buffer.get()); - } - } + // Step 3. Pick out blocks that writes with double buffering + if (IsDoubleBufferScope(block->annotations)) { + ICHECK_EQ(block->writes.size(), 1); + const Buffer& write = block->writes[0]->buffer; + double_buffered_.insert(write); } - // Step 6. Handle allocations + // Step 4. Handle allocations for (const Buffer& buffer : block->alloc_buffers) { - body = MakeAllocStmt(buffer, BufferArea(buffer), body); + body = MakeAllocStmt(buffer, body, double_buffered_.count(buffer)); } return body; } Stmt VisitStmt_(const ForNode* op) final { - // Step 2. Visit recursively - ancestor_loops_.push_back(op); - Stmt body = this->VisitStmt(op->body); + // Step 1. Visit recursively PrimExpr min = this->VisitExpr(op->min); PrimExpr extent = this->VisitExpr(op->extent); - ancestor_loops_.pop_back(); - // Step 4. Add the for loop accordingly + Stmt body = this->VisitStmt(op->body); + // Step 2. Add the for loop accordingly if (op->kind == ForKind::kThreadBinding) { // Case 1. Thread binding ICHECK(op->thread_binding.defined()); @@ -615,7 +560,7 @@ class BufferFlattener : public StmtExprMutator { // Case 3. An ordinary loop body = For(op->loop_var, min, extent, op->kind, body); } - // Step 5. Handle annotations + // Step 3. Handle annotations for (const auto& annotation : op->annotations) { const String& ann_key = annotation.first; const ObjectRef& ann_value = annotation.second; @@ -626,8 +571,48 @@ class BufferFlattener : public StmtExprMutator { return body; } - std::unordered_set double_buffer_; - std::vector ancestor_loops_; + static Stmt MakeAllocStmt(const Buffer& buffer, Stmt body, bool is_double_buffer) { + String storage_scope = buffer->scope; + if (storage_scope.empty()) { + storage_scope = "global"; + } + PrimExpr area = BufferArea(buffer); + body = Allocate(buffer->data, buffer->dtype, {area}, const_true(), body); + body = AttrStmt(buffer->data, attr::storage_scope, StringImm(storage_scope), body); + if (is_double_buffer) { + body = AttrStmt(buffer->data, attr::double_buffer_scope, Integer(1), body); + } + return body; + } + + static Stmt MakeLaunchThread(const PrimExpr& min, const PrimExpr& extent, const Var& var, + const String& thread_tag, Stmt body) { + IterVar iter_var(/*dom=*/Range::FromMinExtent(min, extent), + /*var=*/var, + /*iter_type=*/IterVarType::kThreadIndex, + /*thread_tag=*/thread_tag); + String attr_key = thread_tag == "vthread" ? attr::virtual_thread : attr::thread_extent; + body = AttrStmt(iter_var, attr_key, extent, body); + return body; + } + + static bool IsDoubleBufferScope(const Map& annotations) { + for (const auto& ann : annotations) { + const String& ann_key = ann.first; + const ObjectRef& ann_value = ann.second; + if (ann_key != attr::double_buffer_scope) { + continue; + } + const auto* value = TVM_TYPE_AS(value, ann_value, PrimExprNode); + if (!is_one(GetRef(value))) { + continue; + } + return true; + } + return false; + } + + SSet double_buffered_; }; PrimFunc BufferFlatten(PrimFunc f) { From c731fe5a7c38dfecd9ff205064110be057c9b8f1 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 15 Mar 2021 23:47:09 +0000 Subject: [PATCH 17/20] vload, vstore --- src/tir/schedule/utils.h | 1 + src/tir/transforms/buffer_flatten.cc | 68 ++++++++++++++++------------ 2 files changed, 39 insertions(+), 30 deletions(-) diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 5ddf3941e9..bcecab5075 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -20,6 +20,7 @@ #define TVM_TIR_SCHEDULE_SCHEDULE_COMMON_H_ #include +#include #include #include #include diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index ece6e295b6..cfc791c2ce 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -20,16 +20,7 @@ /*! * \file buffer_flatten.cc */ - -#include -#include -#include -#include #include -#include -#include -#include -#include #include #include "../schedule/utils.h" @@ -37,12 +28,13 @@ namespace tvm { namespace tir { -using NDIntSet = std::vector; template using SMap = std::unordered_map; template using SSet = std::unordered_set; +using NDIntSet = std::vector; + arith::IntSet IntSetFromMinExtent(const PrimExpr& min, const PrimExpr& extent) { return arith::IntSet::FromRange(Range::FromMinExtent(min, extent)); } @@ -110,6 +102,9 @@ PrimExpr BufferArea(const Buffer& buffer) { class ReductionTransformer : public StmtMutator { public: + static Stmt Transform(const PrimFunc& f) { return ReductionTransformer().VisitStmt(f->body); } + + private: Stmt VisitStmt_(const BlockNode* block) override { if (!block->init.defined()) { return StmtMutator::VisitStmt_(block); @@ -230,7 +225,7 @@ class BufferAccessRewriter : public StmtExprMutator { Buffer new_buffer{nullptr}; Array new_indices{nullptr}; std::tie(new_buffer, new_indices) = f_rewrite_(op->buffer, op->indices); - return new_buffer.vstore(new_indices, op->value); + return BufferStore(new_buffer, op->value, new_indices); } PrimExpr VisitExpr_(const BufferLoadNode* op) final { @@ -239,7 +234,7 @@ class BufferAccessRewriter : public StmtExprMutator { Buffer new_buffer{nullptr}; Array new_indices{nullptr}; std::tie(new_buffer, new_indices) = f_rewrite_(op->buffer, op->indices); - return new_buffer.vload(new_indices, op->dtype); + return BufferLoad(new_buffer, new_indices); } const FRewriteBufferAccess& f_rewrite_; @@ -249,19 +244,6 @@ class BufferAccessRewriter : public StmtExprMutator { * \brief Gather the used region of each buffers. */ class RegionGatherer : public StmtExprMutator { - struct BufferInfo { - NDIntSet accessed_region; - Optional alloc_site; - Array region; - Buffer new_buffer; - - explicit BufferInfo(int ndim, Optional alloc_site) - : accessed_region(NDIntSetEmpty(ndim)), // - alloc_site(std::move(alloc_site)), - region{nullptr}, - new_buffer{nullptr} {} - }; - public: static Stmt Gather(const PrimFunc& f) { Map> buffer_lca = LCADetector::Detect(f); @@ -292,6 +274,19 @@ class RegionGatherer : public StmtExprMutator { } private: + struct BufferInfo { + NDIntSet accessed_region; + Optional alloc_site; + Array region; + Buffer new_buffer; + + explicit BufferInfo(int ndim, Optional alloc_site) + : accessed_region(NDIntSetEmpty(ndim)), // + alloc_site(std::move(alloc_site)), + region{nullptr}, + new_buffer{nullptr} {} + }; + explicit RegionGatherer(SMap buffer_info, SMap, Array> loop_allocs) : block_nest_depth_(0), @@ -517,6 +512,9 @@ class RegionGatherer : public StmtExprMutator { * \brief Transform multi-dimension BufferLoad/BufferStore into one-dimension Load/Store */ class BufferFlattener : public StmtExprMutator { + public: + static Stmt Flatten(const PrimFunc& f) { return BufferFlattener().VisitStmt(f->body); } + private: Stmt VisitStmt_(const BlockRealizeNode* realize) final { ICHECK(realize->binding_values.empty()); @@ -571,6 +569,18 @@ class BufferFlattener : public StmtExprMutator { return body; } + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + op = store.get(); + return op->buffer.vstore(op->indices, op->value); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + op = load.get(); + return op->buffer.vload(op->indices, op->dtype); + } + static Stmt MakeAllocStmt(const Buffer& buffer, Stmt body, bool is_double_buffer) { String storage_scope = buffer->scope; if (storage_scope.empty()) { @@ -620,13 +630,11 @@ PrimFunc BufferFlatten(PrimFunc f) { // Step 0. Check memory and execution hierarchy VerifyExecScope(f); // Step 1.Transform the reduction calls to BufferStore - ReductionTransformer reduction_transformer; - fptr->body = reduction_transformer(fptr->body); + fptr->body = ReductionTransformer::Transform(f); // Step 2. Recalculate the buffer region - RegionGatherer::Gather(f); + fptr->body = RegionGatherer::Gather(f); // Step 3. Transform BufferLoad/BufferStore into Load/Store - BufferFlattener flattener; - fptr->body = flattener(fptr->body); + fptr->body = BufferFlattener::Flatten(f); return f; } From 2a89c67aaeed52d3150fcc412573e04e0439e7a7 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 16 Mar 2021 06:28:16 +0000 Subject: [PATCH 18/20] done(?) --- include/tvm/tir/expr.h | 1 + include/tvm/tir/stmt.h | 5 +- src/tir/transforms/buffer_flatten.cc | 162 ++++++++++++++++----------- 3 files changed, 102 insertions(+), 66 deletions(-) diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 04990c522a..9a1a4d85f3 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -639,6 +639,7 @@ class BufferLoad : public PrimExpr { public: TVM_DLL explicit BufferLoad(Buffer buffer, Array indices, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode); }; /*! diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 17e1b402b4..185b30761b 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -942,6 +942,7 @@ class BufferRegion : public ObjectRef { TVM_DLL explicit BufferRegion(Buffer buffer); TVM_DLL explicit BufferRegion(Buffer buffer, Array region); TVM_DEFINE_OBJECT_REF_METHODS(BufferRegion, ObjectRef, BufferRegionNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRegionNode); }; /*! @@ -1058,8 +1059,8 @@ class Block : public Stmt { TVM_DLL explicit Block(Array iter_vars, Array reads, Array writes, Array alloc_buffers, Map annotations, Array match_buffers, - String exec_scope, String name_hint, - Stmt body, Optional init, Span = Span()); + String exec_scope, String name_hint, Stmt body, Optional init, + Span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Block, Stmt, BlockNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockNode); diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index cfc791c2ce..1aab36c4d8 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -208,8 +208,7 @@ class LCADetector : public StmtExprVisitor { class BufferAccessRewriter : public StmtExprMutator { public: - using FRewriteBufferAccess = std::function>( - const Buffer& buffer, const Array& indices)>; + using FRewriteBufferAccess = std::function* indices)>; static Stmt Rewrite(Stmt stmt, const FRewriteBufferAccess& f_rewrite) { BufferAccessRewriter rewriter(f_rewrite); @@ -219,33 +218,47 @@ class BufferAccessRewriter : public StmtExprMutator { private: explicit BufferAccessRewriter(const FRewriteBufferAccess& f_rewrite) : f_rewrite_(f_rewrite) {} - Stmt VisitStmt_(const BufferStoreNode* op) final { - BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); - op = store.get(); - Buffer new_buffer{nullptr}; - Array new_indices{nullptr}; - std::tie(new_buffer, new_indices) = f_rewrite_(op->buffer, op->indices); - return BufferStore(new_buffer, op->value, new_indices); + Stmt VisitStmt_(const BufferStoreNode* _op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_op)); + BufferStoreNode* op = store.CopyOnWrite(); + f_rewrite_(&op->buffer, &op->indices); + return store; } - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); - op = load.get(); - Buffer new_buffer{nullptr}; - Array new_indices{nullptr}; - std::tie(new_buffer, new_indices) = f_rewrite_(op->buffer, op->indices); - return BufferLoad(new_buffer, new_indices); + PrimExpr VisitExpr_(const BufferLoadNode* _op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_op)); + BufferLoadNode* op = load.CopyOnWrite(); + f_rewrite_(&op->buffer, &op->indices); + return load; + } + + Stmt VisitStmt_(const BlockNode* _op) final { + Block block = Downcast(StmtExprMutator::VisitStmt_(_op)); + BlockNode* op = block.CopyOnWrite(); + ArrayNode* reads = op->reads.CopyOnWrite(); + for (int i = 0, n = reads->size(); i < n; ++i) { + BufferRegion buffer_region = Downcast(reads->at(i)); + BufferRegionNode* p = buffer_region.CopyOnWrite(); + f_rewrite_(&p->buffer, nullptr); + } + ArrayNode* writes = op->writes.CopyOnWrite(); + for (int i = 0, n = writes->size(); i < n; ++i) { + BufferRegion buffer_region = Downcast(writes->at(i)); + BufferRegionNode* p = buffer_region.CopyOnWrite(); + f_rewrite_(&p->buffer, nullptr); + } + return block; } const FRewriteBufferAccess& f_rewrite_; }; /*! - * \brief Gather the used region of each buffers. + * \brief Alloc the used region of each buffers. */ -class RegionGatherer : public StmtExprMutator { +class BufferAllocator : public StmtExprMutator { public: - static Stmt Gather(const PrimFunc& f) { + static Stmt Alloc(const PrimFunc& f) { Map> buffer_lca = LCADetector::Detect(f); SMap buffer_info; SMap, Array> loop_allocs; @@ -263,14 +276,20 @@ class RegionGatherer : public StmtExprMutator { ICHECK(buffer_info.count(buffer)); BufferInfo& info = buffer_info.at(buffer); info.accessed_region = NDIntSetFromShape(buffer->shape); + info.alloc_site = NullOpt; + info.region = NDIntSet2Region(info.accessed_region); + info.new_buffer = buffer; + info.is_arg = true; } - RegionGatherer gatherer(std::move(buffer_info), std::move(loop_allocs)); - return BufferAccessRewriter::Rewrite( - /*stmt=*/gatherer.VisitStmt(f->body), - /*f_rewrite=*/std::bind(&RegionGatherer::RewriteBufferAccess, // - &gatherer, // - std::placeholders::_1, // + BufferAllocator alloc(std::move(buffer_info), std::move(loop_allocs)); + Stmt stmt = alloc.VisitStmt(f->body); + stmt = BufferAccessRewriter::Rewrite( + /*stmt=*/std::move(stmt), + /*f_rewrite=*/std::bind(&BufferAllocator::RewriteBufferAccess, + &alloc, // + std::placeholders::_1, // std::placeholders::_2)); + return stmt; } private: @@ -279,16 +298,18 @@ class RegionGatherer : public StmtExprMutator { Optional alloc_site; Array region; Buffer new_buffer; + bool is_arg; explicit BufferInfo(int ndim, Optional alloc_site) - : accessed_region(NDIntSetEmpty(ndim)), // + : accessed_region(NDIntSetEmpty(ndim)), alloc_site(std::move(alloc_site)), region{nullptr}, - new_buffer{nullptr} {} + new_buffer{nullptr}, + is_arg(false) {} }; - explicit RegionGatherer(SMap buffer_info, - SMap, Array> loop_allocs) + explicit BufferAllocator(SMap buffer_info, + SMap, Array> loop_allocs) : block_nest_depth_(0), buffer_info_(std::move(buffer_info)), loop_allocs_(std::move(loop_allocs)), @@ -383,19 +404,6 @@ class RegionGatherer : public StmtExprMutator { /*init=*/NullOpt)); } - PrimExpr VisitExpr_(const CallNode* op) final { - if (op->op.same_as(builtin::get_elem_offset())) { - // Handle `get_elem_offset` - ICHECK_EQ(op->args.size(), 1); - PrimExpr arg = op->args[0]; - ICHECK(arg->IsInstance()); - arg = this->VisitExpr(arg); - const auto* load = TVM_TYPE_AS(load, arg, LoadNode); - return load->index; - } - return StmtExprMutator::VisitExpr_(op); - } - PrimExpr VisitExpr_(const VarNode* var) final { auto it = var_substitutes_.find(GetRef(var)); if (it != var_substitutes_.end()) { @@ -424,6 +432,9 @@ class RegionGatherer : public StmtExprMutator { const Buffer& buffer = buffer_region->buffer; ICHECK(buffer_info_.count(buffer)); BufferInfo& info = buffer_info_.at(buffer); + if (info.is_arg) { + continue; + } std::unordered_map dom_map; { const Object* alloc_site = info.alloc_site.get(); @@ -461,8 +472,14 @@ class RegionGatherer : public StmtExprMutator { for (const Buffer& buffer : buffers) { ICHECK(buffer_info_.count(buffer)); BufferInfo& info = buffer_info_.at(buffer); - ICHECK(!info.region.defined()); - ICHECK(!info.new_buffer.defined()); + if (info.is_arg) { + ICHECK(info.region.defined()); + ICHECK(info.new_buffer.defined()); + continue; + } else { + ICHECK(!info.region.defined()); + ICHECK(!info.new_buffer.defined()); + } // Calculate `info.region` info.region = NDIntSet2Region(info.accessed_region); // Calculate `info.new_buffer` @@ -474,24 +491,30 @@ class RegionGatherer : public StmtExprMutator { ObjectPtr new_buffer = make_object(*buffer.get()); new_buffer->shape = std::move(shape); info.new_buffer = Buffer(std::move(new_buffer)); + result.push_back(info.new_buffer); } return result; } - std::pair> RewriteBufferAccess(const Buffer& buffer, - const Array& indices) const { - ICHECK(buffer_info_.count(buffer)); - const BufferInfo& info = buffer_info_.at(buffer); + void RewriteBufferAccess(Buffer* buffer, Array* indices) const { + ICHECK(buffer_info_.count(*buffer)); + const BufferInfo& info = buffer_info_.at(*buffer); ICHECK(info.new_buffer.defined()); + if (indices == nullptr) { + *buffer = info.new_buffer; + return; + } ICHECK(info.region.defined()); - ICHECK_EQ(indices.size(), info.region.size()); - int ndim = indices.size(); + // TODO: the ndim could be changed by tensorize, more investigation is needed + ICHECK_GE(indices->size(), info.region.size()); + int ndim = info.region.size(); Array new_indices; new_indices.reserve(ndim); for (int i = 0; i < ndim; ++i) { - new_indices.push_back(indices[i] - info.region[i]->min); + new_indices.push_back((*indices)[i] - info.region[i]->min); } - return std::make_pair(info.new_buffer, std::move(new_indices)); + *buffer = info.new_buffer; + *indices = std::move(new_indices); } /*! \brief Number of blocks nested in the ancestor during visiting */ @@ -581,6 +604,19 @@ class BufferFlattener : public StmtExprMutator { return op->buffer.vload(op->indices, op->dtype); } + PrimExpr VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::get_elem_offset())) { + // Handle `get_elem_offset` + ICHECK_EQ(op->args.size(), 1); + PrimExpr arg = op->args[0]; + ICHECK(arg->IsInstance()); + arg = this->VisitExpr(arg); + const auto* load = TVM_TYPE_AS(load, arg, LoadNode); + return load->index; + } + return StmtExprMutator::VisitExpr_(op); + } + static Stmt MakeAllocStmt(const Buffer& buffer, Stmt body, bool is_double_buffer) { String storage_scope = buffer->scope; if (storage_scope.empty()) { @@ -607,17 +643,15 @@ class BufferFlattener : public StmtExprMutator { } static bool IsDoubleBufferScope(const Map& annotations) { - for (const auto& ann : annotations) { - const String& ann_key = ann.first; - const ObjectRef& ann_value = ann.second; + for (const auto& kv : annotations) { + const String& ann_key = kv.first; + const ObjectRef& ann_value = kv.second; if (ann_key != attr::double_buffer_scope) { - continue; - } - const auto* value = TVM_TYPE_AS(value, ann_value, PrimExprNode); - if (!is_one(GetRef(value))) { - continue; + const auto* value = TVM_TYPE_AS(value, ann_value, PrimExprNode); + if (is_one(GetRef(value))) { + return true; + } } - return true; } return false; } @@ -626,13 +660,13 @@ class BufferFlattener : public StmtExprMutator { }; PrimFunc BufferFlatten(PrimFunc f) { - tvm::tir::PrimFuncNode* fptr = f.CopyOnWrite(); + PrimFuncNode* fptr = f.CopyOnWrite(); // Step 0. Check memory and execution hierarchy VerifyExecScope(f); - // Step 1.Transform the reduction calls to BufferStore + // Step 1. Transform the reduction calls to BufferStore fptr->body = ReductionTransformer::Transform(f); // Step 2. Recalculate the buffer region - fptr->body = RegionGatherer::Gather(f); + fptr->body = BufferAllocator::Alloc(f); // Step 3. Transform BufferLoad/BufferStore into Load/Store fptr->body = BufferFlattener::Flatten(f); return f; From 709c35a8faca3d889885bf47baeb8885fa2f7189 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 16 Mar 2021 06:34:11 +0000 Subject: [PATCH 19/20] rule out allocation for reduce temp buffers --- src/tir/transforms/buffer_flatten.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index 1aab36c4d8..90a6e22e6d 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -618,6 +618,9 @@ class BufferFlattener : public StmtExprMutator { } static Stmt MakeAllocStmt(const Buffer& buffer, Stmt body, bool is_double_buffer) { + if (IsReduceTempBuffer(buffer)) { + return body; + } String storage_scope = buffer->scope; if (storage_scope.empty()) { storage_scope = "global"; From a8babf24742ecf26da52310b1d1d89d86e25b564 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 16 Mar 2021 22:04:14 +0000 Subject: [PATCH 20/20] fix --- src/tir/transforms/buffer_flatten.cc | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index 90a6e22e6d..5dec7d6598 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -646,14 +646,10 @@ class BufferFlattener : public StmtExprMutator { } static bool IsDoubleBufferScope(const Map& annotations) { - for (const auto& kv : annotations) { - const String& ann_key = kv.first; - const ObjectRef& ann_value = kv.second; - if (ann_key != attr::double_buffer_scope) { - const auto* value = TVM_TYPE_AS(value, ann_value, PrimExprNode); - if (is_one(GetRef(value))) { - return true; - } + if (Optional ann_value = annotations.Get(attr::double_buffer_scope)) { + const auto* value = TVM_TYPE_AS(value, ann_value, PrimExprNode); + if (is_one(GetRef(value))) { + return true; } } return false;