From 9f0b3a18b56b5fb077807c5e26b8e6c15adc7b8a Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 18 Mar 2021 05:46:02 -0700 Subject: [PATCH] [Refactor] Buffer flatten (#340) --- include/tvm/tir/expr.h | 1 + src/meta_schedule/utils.h | 19 +- src/tir/schedule/analysis.cc | 11 + src/tir/schedule/analysis.h | 6 + src/tir/schedule/transform.cc | 45 ++ src/tir/schedule/transform.h | 40 ++ src/tir/schedule/utils.h | 42 ++ src/tir/transforms/buffer_flatten.cc | 945 +++++++++++++++------------ 8 files changed, 677 insertions(+), 432 deletions(-) create mode 100644 src/tir/schedule/transform.cc create mode 100644 src/tir/schedule/transform.h diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 473c08578d..f08ca028fa 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -639,6 +639,7 @@ class BufferLoad : public PrimExpr { public: TVM_DLL explicit BufferLoad(Buffer buffer, Array indices, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode); }; /*! diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index ce960edcfa..12f3fcad88 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -338,23 +338,8 @@ inline std::ostream& StdCout(int verbose, int setting = 1) { /**************** String Manipulation ****************/ -/*! - * \brief Find all positions that the specific char occurs in the string - * \param str The string to be examined - * \param c The specific char - * \return A list of integers indicating the occurrence position - */ -inline std::vector FindCharPos(const String& str, char c) { - std::vector result; - const char* data = str.data(); - int n = str.length(); - for (int i = 0; i < n; ++i) { - if (data[i] == c) { - result.push_back(i); - } - } - return result; -} +using tir::FindCharPos; +using tir::StartsWith; /**************** Target Hardware Concurrency ****************/ diff --git a/src/tir/schedule/analysis.cc b/src/tir/schedule/analysis.cc index 2579b1b18d..64425df156 100644 --- a/src/tir/schedule/analysis.cc +++ b/src/tir/schedule/analysis.cc @@ -54,6 +54,17 @@ bool ContainsVar(const ObjectRef& stmt_or_expr, const std::unordered_set Vars(const ObjectRef& stmt_or_expr) { + std::unordered_set result; + auto f_visit = [&result](const ObjectRef& obj) -> void { + if (const auto* var = obj.as()) { + result.insert(var); + } + }; + PostOrderVisit(stmt_or_expr, f_visit); + return result; +} + bool ValidateBlockBinding(const BlockRealize& realize, const Map& loop_var_ranges) { arith::Analyzer analyzer; Array results = arith::DetectIterMap( diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index c799befb11..c0d00ab788 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -45,6 +45,12 @@ bool ContainsVar(const ObjectRef& stmt_or_expr, const Var& var); * \return A boolean indicating if any var in the list is found in stmt/expr */ bool ContainsVar(const ObjectRef& stmt_or_expr, const std::unordered_set& var); +/*! + * \brief Collect the variables that appear in the specific Stmt or Expr + * \param stmt_or_expr The Stmt or Expr + * \return All variables that appear + */ +std::unordered_set Vars(const ObjectRef& stmt_or_expr); /******** Verification ********/ /*! diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc new file mode 100644 index 0000000000..191d079f60 --- /dev/null +++ b/src/tir/schedule/transform.cc @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace tir { + +Stmt RealizeInitBlock(const Stmt& init, const Array& iter_vars) { + std::vector conditions; + for (const IterVar& var : iter_vars) { + if (var->iter_type == IterVarType::kCommReduce) { + conditions.push_back(equal(var->var, var->dom->min)); + } + } + int n = conditions.size(); + // Handle the case where there is no condition + if (n == 0) { + return init; + } + // Concate the conditions with logical and (&&) + PrimExpr cond = conditions[0]; + for (int i = 1; i < n; ++i) { + cond = logical_and(cond, conditions[i]); + } + return IfThenElse(cond, init); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h new file mode 100644 index 0000000000..22e7c1822b --- /dev/null +++ b/src/tir/schedule/transform.h @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_TRANSFORM_H_ +#define TVM_TIR_SCHEDULE_TRANSFORM_H_ + +#include + +#include + +namespace tvm { +namespace tir { + +/*! + * \brief Transform the init block into actual computation + * \param init The init block + * \param iter_vars The block variables + * \return The actual computation + */ +Stmt RealizeInitBlock(const Stmt& init, const Array& iter_vars); + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_TRANSFORM_H_ diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index d823f60308..bf2f322328 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -20,6 +20,7 @@ #define TVM_TIR_SCHEDULE_SCHEDULE_COMMON_H_ #include +#include #include #include #include @@ -35,6 +36,7 @@ #include #include "./analysis.h" +#include "./transform.h" namespace tvm { namespace tir { @@ -455,6 +457,46 @@ static DefaultReducer default_reducers[4] = { } // namespace default_reducer +/**************** String ****************/ + +/*! + * \brief Find all positions that the specific char occurs in the string + * \param str The string to be examined + * \param c The specific char + * \return A list of integers indicating the occurrence position + */ +inline std::vector FindCharPos(const String& str, char c) { + std::vector result; + const char* data = str.data(); + int n = str.length(); + for (int i = 0; i < n; ++i) { + if (data[i] == c) { + result.push_back(i); + } + } + return result; +} + +inline bool StartsWith(const String& str, const String& prefix) { + int n = prefix.size(); + if (static_cast(str.size()) < n) { + return false; + } + const char* data = str.data(); + return std::equal(data, data + n, prefix.data()); +} + +inline bool StartsWith(const String& str, const char* prefix) { + int n = strlen(prefix); + if (static_cast(str.size()) < n) { + return false; + } + const char* data = str.data(); + return std::equal(data, data + n, prefix); +} + +/**************** Loop extents ****************/ + inline int64_t GetLoopIntExtent(const ForNode* loop) { const auto* int_extent = loop->extent.as(); return int_extent ? int_extent->value : -1; diff --git a/src/tir/transforms/buffer_flatten.cc b/src/tir/transforms/buffer_flatten.cc index 37a5bfc76b..5dec7d6598 100644 --- a/src/tir/transforms/buffer_flatten.cc +++ b/src/tir/transforms/buffer_flatten.cc @@ -20,43 +20,101 @@ /*! * \file buffer_flatten.cc */ - -#include -#include -#include -#include #include -#include -#include -#include -#include #include +#include "../schedule/utils.h" + namespace tvm { namespace tir { -/*! - * \brief Transform block with init into actual computation - */ -class ReductionTransformer : public StmtExprMutator { +template +using SMap = std::unordered_map; +template +using SSet = std::unordered_set; + +using NDIntSet = std::vector; + +arith::IntSet IntSetFromMinExtent(const PrimExpr& min, const PrimExpr& extent) { + return arith::IntSet::FromRange(Range::FromMinExtent(min, extent)); +} + +void NDIntSetUnionWith(NDIntSet* lhs, const NDIntSet& rhs) { + ICHECK_EQ(lhs->size(), rhs.size()); + int ndim = rhs.size(); + for (int i = 0; i < ndim; ++i) { + arith::IntSet& int_set = lhs->at(i); + int_set = arith::Union({int_set, rhs.at(i)}); + } +} + +Array NDIntSet2Region(const NDIntSet& nd_int_set) { + Integer one(1); + Array result; + result.reserve(nd_int_set.size()); + for (const arith::IntSet& int_set : nd_int_set) { + PrimExpr min = int_set.min(); + PrimExpr max = int_set.max(); + result.push_back(Range(/*begin=*/min, /*end=*/max + one)); + } + return result; +} + +NDIntSet NDIntSetFromShape(const Array& shape) { + NDIntSet result; + for (const PrimExpr& extent : shape) { + result.push_back(IntSetFromMinExtent(Integer(0), extent)); + } + return result; +} + +NDIntSet NDIntSetEmpty(int ndim) { + return std::vector(ndim, arith::IntSet::Nothing()); +} + +bool IsThreadBound(const For& loop) { + if (loop->kind != ForKind::kThreadBinding) { + return false; + } + ICHECK(loop->thread_binding.defined()); + std::string thread_tag = loop->thread_binding.value()->thread_tag; + if (StartsWith(thread_tag, "threadIdx")) { + return true; + } + if (StartsWith(thread_tag, "vthread")) { + return true; + } + return false; +} + +bool IsReduceTempBuffer(const Buffer& buffer) { + return StartsWith(buffer->name, "normal_reduce_temp") || // + StartsWith(buffer->name, "reduce_temp"); +} + +PrimExpr BufferArea(const Buffer& buffer) { + PrimExpr area = Integer(1); + for (const PrimExpr& dim : buffer->shape) { + area = area * dim; + } + return area; +} + +class ReductionTransformer : public StmtMutator { public: - ReductionTransformer() = default; - - Stmt VisitStmt_(const BlockNode* op) override { - Block res = Downcast(StmtMutator::VisitStmt_(op)); - if (op->init) { - PrimExpr condition = Bool(true); - for (const auto& var : res->iter_vars) { - if (var->iter_type == IterVarType::kCommReduce) { - condition = And(condition, EQ(var, var->dom->min)); - } - } - Stmt init = op->init.value(); - if (!is_one(condition)) init = IfThenElse(condition, init); - res.CopyOnWrite()->body = SeqStmt::Flatten(init, op->body); - res.CopyOnWrite()->init = NullOpt; + static Stmt Transform(const PrimFunc& f) { return ReductionTransformer().VisitStmt(f->body); } + + private: + Stmt VisitStmt_(const BlockNode* block) override { + if (!block->init.defined()) { + return StmtMutator::VisitStmt_(block); } - return std::move(res); + Stmt init = RealizeInitBlock(block->init.value(), block->iter_vars); + Stmt body = VisitStmt(block->body); + ObjectPtr new_block = make_object(*block); + new_block->init = NullOpt; + new_block->body = SeqStmt::Flatten(init, body); + return Stmt(std::move(new_block)); } }; @@ -66,493 +124,550 @@ class ReductionTransformer : public StmtExprMutator { */ class LCADetector : public StmtExprVisitor { public: - explicit LCADetector(const Map& func_args) { - for (const auto& x : func_args) { - arg_buffers_.insert(x.second); - buffers_lca_[x.second] = NullValue(); + static Map> Detect(const PrimFunc& func) { + LCADetector detector; + // Buffers, who appear as arguments, do not have allocation sites + for (const auto& kv : func->buffer_map) { + const Buffer& buffer = kv.second; + detector.buffer_lca_.emplace(buffer.get(), nullptr); + } + detector(func->body); + // Prepare the return + Map> buffer_lca; + for (const auto& kv : detector.buffer_lca_) { + buffer_lca.Set(GetRef(kv.first), GetRef>(kv.second)); } + return buffer_lca; } - // Update parent and depth information for each AST node - + private: void VisitStmt_(const ForNode* op) final { - Stmt n = GetRef(op); - ast_scopes_info_[n] = ScopeInfo{scope_, depth_}; - ++depth_; - std::swap(scope_, n); + int n = ancestor_loops_.size(); + for_info_.emplace(op, ForInfo{ancestor_loops_.back(), n}); + ancestor_loops_.push_back(op); StmtExprVisitor::VisitStmt_(op); - std::swap(scope_, n); - --depth_; - } - - // Update LCA when visiting BufferLoad and BufferStore - template - void VisitBuffer(T op) { - Buffer buffer = op->buffer; - ObjectRef n = GetRef(op); - ast_scopes_info_[n] = ScopeInfo{scope_, depth_}; - // No need to update LCA if the buffer is in the func args (function input/output buffer) - if (arg_buffers_.count(buffer)) return; - if (buffers_lca_.count(buffer)) { - buffers_lca_[buffer] = LowestCommonAncestor(GetRef(op), buffers_lca_[buffer]); - } else { - buffers_lca_[buffer] = GetRef(op); - } + ancestor_loops_.pop_back(); } void VisitExpr_(const BufferLoadNode* op) final { - VisitBuffer(op); + CalcBufferLCA(op->buffer.get()); StmtExprVisitor::VisitExpr_(op); } + void VisitStmt_(const BufferStoreNode* op) final { - VisitBuffer(op); + CalcBufferLCA(op->buffer.get()); StmtExprVisitor::VisitStmt_(op); } - /*! \brief The map from Buffer to its LCA Stmt/Expr */ - std::unordered_map buffers_lca_; - /*! \brief The Buffer in function args */ - std::unordered_set arg_buffers_; + void CalcBufferLCA(const BufferNode* buffer) { + const ForNode*& lca = buffer_lca_[buffer]; + lca = LowestCommonAncestor(lca, ancestor_loops_.back()); + } + + const ForNode* LowestCommonAncestor(const ForNode* lhs, const ForNode* rhs) const { + while (lhs != nullptr && rhs != nullptr && lhs != rhs) { + auto it_l = for_info_.find(lhs); + auto it_r = for_info_.find(rhs); + ICHECK(it_l != for_info_.end()); + ICHECK(it_r != for_info_.end()); + const ForInfo& l = it_l->second; + const ForInfo& r = it_r->second; + if (l.depth == r.depth) { + lhs = l.parent_loop; + rhs = r.parent_loop; + } else if (l.depth < r.depth) { + rhs = r.parent_loop; + } else { + lhs = l.parent_loop; + } + } + if (lhs == nullptr) { + return rhs; + } + if (rhs == nullptr) { + return lhs; + } + return lhs; + } - private: /*! \brief The AST node information for querying LCA */ - struct ScopeInfo { + struct ForInfo { // The parent loop node - Stmt parent_scope; + const ForNode* parent_loop; // The scope depth in the AST - size_t depth; + int depth; }; /*! \brief The current scope initializing with Null */ - Stmt scope_{NullValue()}; - /*! \brief The current DFS depth */ - size_t depth_{0}; + std::vector ancestor_loops_ = {nullptr}; /*! \brief The parent and depth info of each Loop/BufferLoad/BufferStore Node */ - std::unordered_map ast_scopes_info_; - - ObjectRef LowestCommonAncestor(ObjectRef lhs, ObjectRef rhs) { - if (!lhs.defined() || !rhs.defined()) return NullValue(); - ICHECK(ast_scopes_info_.count(lhs)); - ICHECK(ast_scopes_info_.count(rhs)); - while (ast_scopes_info_[lhs].depth > ast_scopes_info_[rhs].depth) { - lhs = ast_scopes_info_[lhs].parent_scope; - } - while (ast_scopes_info_[lhs].depth < ast_scopes_info_[rhs].depth) { - rhs = ast_scopes_info_[rhs].parent_scope; - } - while (!lhs.same_as(rhs)) { - lhs = ast_scopes_info_[lhs].parent_scope; - rhs = ast_scopes_info_[rhs].parent_scope; - } - return lhs; - } + std::unordered_map for_info_ = {}; + /*! \brief The map from Buffer to its LCA Stmt/Expr */ + std::unordered_map buffer_lca_ = {}; }; -/*! - * \brief Gather the used region of each buffers. - */ -class RegionGatherer : public StmtExprVisitor { +class BufferAccessRewriter : public StmtExprMutator { public: - RegionGatherer( - const std::unordered_map& buffers_lca, - const Map& func_args) - : buffers_lca_(buffers_lca) { - for (const auto& arg : func_args) { - std::vector region; - for (const auto& size : arg.second->shape) { - region.push_back(arith::IntSet::FromRange(Range::FromMinExtent(0, size))); - } - buffers_region_[arg.second] = region; - } - } + using FRewriteBufferAccess = std::function* indices)>; - void VisitStmt_(const ForNode* op) final { - auto loop = GetRef(op); - loop_stack_.push_back(loop); - if (op->annotations.empty() && is_one(op->extent)) { - unit_loops_[op->loop_var.get()] = op->min; - } - StmtExprVisitor::VisitStmt_(op); - loop_stack_.pop_back(); + static Stmt Rewrite(Stmt stmt, const FRewriteBufferAccess& f_rewrite) { + BufferAccessRewriter rewriter(f_rewrite); + return rewriter.VisitStmt(stmt); } - void VisitStmt_(const BlockRealizeNode* op) final { - const auto* block_op = op->block.as(); - for (size_t i = 0; i < block_op->iter_vars.size(); ++i) { - const auto& iter = block_op->iter_vars[i]; - const auto& v = op->binding_values[i]; - block_var_[iter->var.get()] = Substitute(Substitute(v, block_var_), unit_loops_); - } - StmtExprVisitor::VisitStmt_(op); - } + private: + explicit BufferAccessRewriter(const FRewriteBufferAccess& f_rewrite) : f_rewrite_(f_rewrite) {} - void VisitStmt_(const BlockNode* op) final { - for (const auto& buffer_region : op->reads) { - VisitBufferRegion(buffer_region); - } - for (const auto& buffer_region : op->writes) { - VisitBufferRegion(buffer_region); - } - for (const auto& alloc_buf : op->alloc_buffers) { - std::vector empty_region(alloc_buf->shape.size(), arith::IntSet::Nothing()); - // Initialize the buffer region with empty region. - buffers_region_[alloc_buf] = empty_region; - } - StmtExprVisitor::VisitStmt_(op); + Stmt VisitStmt_(const BufferStoreNode* _op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_op)); + BufferStoreNode* op = store.CopyOnWrite(); + f_rewrite_(&op->buffer, &op->indices); + return store; } - /*! \brief The used region of each Buffer */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> - buffers_region_; - /*! \brief The map from block vars to the expr value */ - std::unordered_map block_var_; - /*! \brief The map from unit loop vars to the expr value */ - std::unordered_map unit_loops_; - - private: - const std::unordered_map& buffers_lca_; - - /*! \brief The loops from the current node up to the root */ - std::vector loop_stack_; - - void VisitBufferRegion(const BufferRegion& buffer_region) { - auto it = buffers_region_.find(buffer_region->buffer); - ICHECK(it != buffers_region_.end()); - const auto& region = GatherRegion(buffer_region); - auto& buffer_new_region = it->second; - ICHECK_EQ(buffer_new_region.size(), region.size()); - for (size_t i = 0; i < region.size(); ++i) { - buffer_new_region[i] = arith::Union({buffer_new_region[i], region[i]}); - } + PrimExpr VisitExpr_(const BufferLoadNode* _op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_op)); + BufferLoadNode* op = load.CopyOnWrite(); + f_rewrite_(&op->buffer, &op->indices); + return load; } - /*! - * \brief Gather used buffer region - */ - std::vector GatherRegion(const BufferRegion& buffer_region) { - std::unordered_map dom_map; - auto it = buffers_lca_.find(buffer_region->buffer); - ICHECK(it != buffers_lca_.end()); - const auto& lca = it->second; - // Every loop will be relaxed if the lca is the root - bool need_relax = !lca.defined(); - for (size_t i = 0; i < loop_stack_.size(); ++i) { - const For& loop = loop_stack_[i]; - const VarNode* var = loop->loop_var.get(); - if (need_relax || (buffer_region->buffer->scope == "shared" && IsThreadBinded(loop))) { - dom_map[var] = arith::IntSet::FromRange(Range::FromMinExtent(loop->min, loop->extent)); - } - if (loop.same_as(lca)) need_relax = true; + Stmt VisitStmt_(const BlockNode* _op) final { + Block block = Downcast(StmtExprMutator::VisitStmt_(_op)); + BlockNode* op = block.CopyOnWrite(); + ArrayNode* reads = op->reads.CopyOnWrite(); + for (int i = 0, n = reads->size(); i < n; ++i) { + BufferRegion buffer_region = Downcast(reads->at(i)); + BufferRegionNode* p = buffer_region.CopyOnWrite(); + f_rewrite_(&p->buffer, nullptr); } - std::vector region; - for (const auto& range : buffer_region->region) { - Range r = - Range::FromMinExtent(Substitute(Substitute(range->min, block_var_), unit_loops_), - Substitute(Substitute(range->extent, block_var_), unit_loops_)); - region.push_back(arith::EvalSet(r, dom_map)); + ArrayNode* writes = op->writes.CopyOnWrite(); + for (int i = 0, n = writes->size(); i < n; ++i) { + BufferRegion buffer_region = Downcast(writes->at(i)); + BufferRegionNode* p = buffer_region.CopyOnWrite(); + f_rewrite_(&p->buffer, nullptr); } - return region; + return block; } - static bool IsThreadBinded(const For& loop) { - if (loop->kind != ForKind::kThreadBinding || !loop->thread_binding.defined()) return false; - std::string thread_tag = loop->thread_binding.value()->thread_tag; - return (thread_tag.substr(0, 9) == "threadIdx" || thread_tag.substr(0, 7) == "vthread"); - } + const FRewriteBufferAccess& f_rewrite_; }; /*! - * \brief Transform multi-dimension BufferLoad/BufferStore into one-dimension Load/Store + * \brief Alloc the used region of each buffers. */ -class BufferFlattener : public StmtExprMutator { +class BufferAllocator : public StmtExprMutator { public: - BufferFlattener( - const std::unordered_map& block_var, - const std::unordered_map& unit_loops, - const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& - buffers_region, - const std::unordered_map& buffers_lca, - const std::unordered_set& arg_buffers) - : buffers_region_(buffers_region), - block_var_(block_var), - unit_loops_(unit_loops), - buffers_lca_(buffers_lca), - arg_buffers_(arg_buffers) {} - - Stmt VisitStmt(const Stmt& stmt) override { - Stmt body = StmtMutator::VisitStmt(stmt); + static Stmt Alloc(const PrimFunc& f) { + Map> buffer_lca = LCADetector::Detect(f); + SMap buffer_info; + SMap, Array> loop_allocs; + buffer_info.reserve(buffer_lca.size()); + loop_allocs.reserve(buffer_lca.size()); + for (const auto& kv : buffer_lca) { + const Buffer& buffer = kv.first; + const Optional& alloc_site = kv.second; + int ndim = buffer->shape.size(); + buffer_info.emplace(buffer, BufferInfo(ndim, alloc_site)); + loop_allocs[alloc_site].push_back(buffer); + } + for (const auto& kv : f->buffer_map) { + const Buffer& buffer = kv.second; + ICHECK(buffer_info.count(buffer)); + BufferInfo& info = buffer_info.at(buffer); + info.accessed_region = NDIntSetFromShape(buffer->shape); + info.alloc_site = NullOpt; + info.region = NDIntSet2Region(info.accessed_region); + info.new_buffer = buffer; + info.is_arg = true; + } + BufferAllocator alloc(std::move(buffer_info), std::move(loop_allocs)); + Stmt stmt = alloc.VisitStmt(f->body); + stmt = BufferAccessRewriter::Rewrite( + /*stmt=*/std::move(stmt), + /*f_rewrite=*/std::bind(&BufferAllocator::RewriteBufferAccess, + &alloc, // + std::placeholders::_1, // + std::placeholders::_2)); + return stmt; + } + + private: + struct BufferInfo { + NDIntSet accessed_region; + Optional alloc_site; + Array region; + Buffer new_buffer; + bool is_arg; + + explicit BufferInfo(int ndim, Optional alloc_site) + : accessed_region(NDIntSetEmpty(ndim)), + alloc_site(std::move(alloc_site)), + region{nullptr}, + new_buffer{nullptr}, + is_arg(false) {} + }; + + explicit BufferAllocator(SMap buffer_info, + SMap, Array> loop_allocs) + : block_nest_depth_(0), + buffer_info_(std::move(buffer_info)), + loop_allocs_(std::move(loop_allocs)), + ancestor_loops_{}, + var_substitutes_{}, + reduction_loop_vars_{} {} + + Stmt VisitStmt_(const ForNode* loop) final { + // Step 1. Handle block vars in `min` and `extent` + PrimExpr min = this->VisitExpr(loop->min); + PrimExpr extent = this->VisitExpr(loop->extent); + // Step 2. Handle unit loops + if (is_one(extent)) { + var_substitutes_[loop->loop_var] = min; + } + // Step 3. Visit recursively + ancestor_loops_.push_back(GetRef(loop)); + Stmt body = this->VisitStmt(loop->body); + ancestor_loops_.pop_back(); + // Step 4. Add allocation + Array alloc_buffers = AllocBufferUnderLoop(GetRef(loop)); + if (!alloc_buffers.empty()) { + body = BlockRealize(/*binding_values=*/{}, + /*predicate=*/const_true(), + /*block=*/ + Block(/*iter_vars=*/{}, // + /*reads=*/{}, // + /*writes=*/{}, // + /*alloc_buffers=*/std::move(alloc_buffers), // + /*annotations=*/{}, // + /*match_buffers=*/{}, // + /*exec_scope=*/"", // + /*name_hint=*/"alloc", // + /*body=*/std::move(body), // + /*init=*/NullOpt)); + } + // Step 5. Make the new loop + if (loop->kind == ForKind::kThreadBinding && reduction_loop_vars_.count(loop->loop_var)) { + // do nothing, because the loop is going to be removed + } else { + body = For(/*loop_var=*/loop->loop_var, + /*min=*/min, + /*extent=*/extent, + /*kind=*/loop->kind, + /*body=*/std::move(body), + /*thread_binding=*/loop->thread_binding, + /*annotations=*/loop->annotations); + } return body; } - Stmt VisitStmt_(const SeqStmtNode* op) final { - Array seq; - for (const Stmt& stmt : op->seq) { - std::unordered_set double_buffer; - std::swap(double_buffer, double_buffer_); - Stmt body = VisitStmt(stmt); - std::swap(double_buffer, double_buffer_); - - for (const Buffer& buffer : double_buffer) { - ObjectRef lca = buffers_lca_.at(buffer); - if (lca.defined() && lca.same_as(parent_scope_)) { - body = AttrStmt(buffer->data, tir::attr::double_buffer_scope, 1, body); - } else { - double_buffer_.insert(buffer); + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + const auto* block = realize->block.get(); + ICHECK(!block->init.defined()); + // Step 1. Update "block vars => loop vars" for substitution, add reduction loop vars + ICHECK_EQ(block->iter_vars.size(), realize->binding_values.size()); + for (int i = 0, n = block->iter_vars.size(); i < n; ++i) { + IterVar block_var = block->iter_vars[i]; + PrimExpr v = this->VisitExpr(realize->binding_values[i]); + var_substitutes_.emplace(block_var->var, v); + if (block_var->iter_type == kCommReduce) { + for (const VarNode* var : Vars(v)) { + this->reduction_loop_vars_.insert(GetRef(var)); } } - - seq.push_back(body); } + // Step 2. Visit recursively + ++block_nest_depth_; + Stmt body = this->VisitStmt(block->body); + --block_nest_depth_; + // Step 3. Update the read/write buffer regions + Array reads = VisitBufferRegions(block->reads); + Array writes = VisitBufferRegions(block->writes); + // Step 4. Handle predicate + PrimExpr predicate = this->VisitExpr(realize->predicate); + // Step 5. Root allocation + Array alloc_buffers = + (block_nest_depth_ == 0) ? AllocBufferUnderLoop(NullOpt) : Array{}; + // Step 6. Create new blocks + return BlockRealize(/*binding_values=*/{}, + /*predicate=*/std::move(predicate), + /*block=*/ + Block(/*iter_vars=*/{}, // + /*reads=*/std::move(reads), // + /*writes=*/std::move(writes), // + /*alloc_buffers=*/std::move(alloc_buffers), // + /*annotations=*/block->annotations, // + /*match_buffers=*/block->match_buffers, // + /*exec_scope=*/block->exec_scope, // + /*name_hint=*/block->name_hint, // + /*body=*/std::move(body), // + /*init=*/NullOpt)); + } - return SeqStmt(seq); + PrimExpr VisitExpr_(const VarNode* var) final { + auto it = var_substitutes_.find(GetRef(var)); + if (it != var_substitutes_.end()) { + return it->second; + } + return GetRef(var); } - Stmt VisitStmt_(const BlockRealizeNode* op) final { - // Handle allocations - const auto* block_op = op->block.as(); - Stmt old_stmt = GetRef(block_op); - ICHECK(block_op != nullptr); - for (size_t i = block_op->alloc_buffers.size(); i > 0; --i) { - const auto& buffer = block_op->alloc_buffers[i - 1]; - const std::string name = std::string(buffer->name); - if (name.substr(0, 18) == "normal_reduce_temp" || name.substr(0, 11) == "reduce_temp") { - continue; - } - if (buffers_lca_.at(buffer).defined()) { - pending_allocate_[buffer] = block_op->alloc_buffers[i - 1]; + Array VisitBufferRegions(const Array& buffer_regions) { + // Calculate `new_buffer_regions` by recursively visiting min/extent of each range + Array new_buffer_regions; + new_buffer_regions.reserve(buffer_regions.size()); + for (const BufferRegion& buffer_region : buffer_regions) { + const Buffer& buffer = buffer_region->buffer; + const Array& region = buffer_region->region; + Array new_region; + new_region.reserve(region.size()); + for (const Range& range : region) { + new_region.push_back(Range::FromMinExtent(/*min=*/this->VisitExpr(range->min), + /*extent=*/this->VisitExpr(range->extent))); } + new_buffer_regions.push_back(BufferRegion(buffer, new_region)); } - for (size_t i = 0; i < block_op->iter_vars.size(); ++i) { - const IterVar& block_var = block_op->iter_vars[i]; - const PrimExpr& binding_value = op->binding_values[i]; - ICHECK(block_var.as()); - ICHECK(binding_value.as()); - - if (block_var->iter_type == kCommReduce) { - PreOrderVisit(binding_value, [this](const ObjectRef& node) { - if (const auto* var = node.as()) { - this->reduction_relative_.insert(GetRef(var)); - return false; + // Calculate `info.accessed_region` + for (const BufferRegion& buffer_region : new_buffer_regions) { + const Buffer& buffer = buffer_region->buffer; + ICHECK(buffer_info_.count(buffer)); + BufferInfo& info = buffer_info_.at(buffer); + if (info.is_arg) { + continue; + } + std::unordered_map dom_map; + { + const Object* alloc_site = info.alloc_site.get(); + // Every loop will be relaxed if the lca is the root + bool need_relax = (alloc_site == nullptr); + for (const For& loop : this->ancestor_loops_) { + const VarNode* loop_var = loop->loop_var.get(); + if (need_relax || (buffer->scope == "shared" && IsThreadBound(loop))) { + // TODO + dom_map[loop_var] = IntSetFromMinExtent(loop->min, loop->extent); + } + if (loop.get() == alloc_site) { + need_relax = true; } - return true; - }); + } } - } - // visit body - Stmt parent_scope = op->block; - std::swap(parent_scope, parent_scope_); - Stmt stmt = StmtExprMutator::VisitStmt_(op); - std::swap(parent_scope, parent_scope_); - op = stmt.as(); - ICHECK(op != nullptr); - block_op = op->block.as(); - ICHECK(block_op != nullptr); - Stmt body = block_op->body; - // Handle block predicate - if (!is_one(op->predicate)) { - body = IfThenElse(op->predicate, body); - } - - for (const auto& anno : block_op->annotations) { - if (anno.first == tir::attr::double_buffer_scope && is_one(Downcast(anno.second))) { - ICHECK_EQ(block_op->writes.size(), 1); - double_buffer_.insert(block_op->writes[0]->buffer); + NDIntSet int_set; + int_set.reserve(buffer_region->region.size()); + for (const Range& range : buffer_region->region) { + int_set.push_back(arith::EvalSet(range, dom_map)); } + NDIntSetUnionWith(&info.accessed_region, int_set); } + return new_buffer_regions; + } - for (size_t i = block_op->alloc_buffers.size(); i > 0; --i) { - const auto& alloc_buf = block_op->alloc_buffers[i - 1]; - const std::string name = std::string(alloc_buf->name); - if (name.substr(0, 18) == "normal_reduce_temp" || name.substr(0, 11) == "reduce_temp") { + Array AllocBufferUnderLoop(const Optional& loop) { + auto it = loop_allocs_.find(loop); + if (it == loop_allocs_.end()) { + return {}; + } + const Array& buffers = it->second; + Array result; + result.reserve(buffers.size()); + for (const Buffer& buffer : buffers) { + ICHECK(buffer_info_.count(buffer)); + BufferInfo& info = buffer_info_.at(buffer); + if (info.is_arg) { + ICHECK(info.region.defined()); + ICHECK(info.new_buffer.defined()); continue; + } else { + ICHECK(!info.region.defined()); + ICHECK(!info.new_buffer.defined()); } - if (!buffers_lca_.at(alloc_buf).defined() || buffers_lca_.at(alloc_buf).same_as(old_stmt)) { - PrimExpr extents = 1; - for (const auto& extent : buffers_region_.at(alloc_buf)) { - extents *= extent.max() - extent.min() + 1; - } - body = Allocate(alloc_buf->data, alloc_buf->dtype, {extents}, const_true(), body); - - // Change empty scope into global - std::string scope = alloc_buf->scope.empty() ? "global" : alloc_buf->scope; - body = AttrStmt(alloc_buf->data, attr::storage_scope, StringImm(scope), body); + // Calculate `info.region` + info.region = NDIntSet2Region(info.accessed_region); + // Calculate `info.new_buffer` + Array shape; + shape.reserve(info.region.size()); + for (const Range& range : info.region) { + shape.push_back(range->extent); } + ObjectPtr new_buffer = make_object(*buffer.get()); + new_buffer->shape = std::move(shape); + info.new_buffer = Buffer(std::move(new_buffer)); + result.push_back(info.new_buffer); } - - return body; + return result; } - PrimExpr VisitExpr_(const VarNode* op) final { - // Replace the block var with its value - auto it = block_var_.find(op); - if (it != block_var_.end()) { - return Substitute(it->second, unit_loops_); - } else { - return Substitute(GetRef(op), unit_loops_); + void RewriteBufferAccess(Buffer* buffer, Array* indices) const { + ICHECK(buffer_info_.count(*buffer)); + const BufferInfo& info = buffer_info_.at(*buffer); + ICHECK(info.new_buffer.defined()); + if (indices == nullptr) { + *buffer = info.new_buffer; + return; + } + ICHECK(info.region.defined()); + // TODO: the ndim could be changed by tensorize, more investigation is needed + ICHECK_GE(indices->size(), info.region.size()); + int ndim = info.region.size(); + Array new_indices; + new_indices.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + new_indices.push_back((*indices)[i] - info.region[i]->min); } + *buffer = info.new_buffer; + *indices = std::move(new_indices); } - Stmt VisitStmt_(const ForNode* op) final { - Stmt old_stmt = GetRef(op); - std::swap(old_stmt, parent_scope_); - Stmt stmt = StmtExprMutator::VisitStmt_(op); - std::swap(old_stmt, parent_scope_); - - op = stmt.as(); - ICHECK(op != nullptr); - - ForKind kind = op->kind; - if (op->kind == ForKind::kThreadBinding) kind = ForKind::kSerial; - - Stmt body = op->body; - for (auto it = pending_allocate_.begin(); it != pending_allocate_.end();) { - if (old_stmt.same_as(buffers_lca_.at(it->first))) { - PrimExpr extents = 1; - const auto& alloc_buf = it->second; - for (const auto& extent : buffers_region_.at(alloc_buf)) { - extents *= extent.max() - extent.min() + 1; - } - body = Allocate(alloc_buf->data, alloc_buf->dtype, {extents}, const_true(), body); - // Change empty scope into global - std::string scope = alloc_buf->scope.empty() ? "global" : alloc_buf->scope; - body = AttrStmt(alloc_buf->data, attr::storage_scope, StringImm(scope), body); - pending_allocate_.erase(it++); - } else { - it++; - } + /*! \brief Number of blocks nested in the ancestor during visiting */ + int block_nest_depth_; + /*! \brief Collective information about each buffer */ + SMap buffer_info_; + /*! \brief Buffers allocated at each for loop */ + SMap, Array> loop_allocs_; + /*! \brief The loops from the current node up to the root */ + std::vector ancestor_loops_; + /*! \brief The map from block vars to the expr value */ + SMap var_substitutes_; + /*! \brief Loop variables that are bound to reduction block vars */ + SSet reduction_loop_vars_; +}; + +/*! + * \brief Transform multi-dimension BufferLoad/BufferStore into one-dimension Load/Store + */ +class BufferFlattener : public StmtExprMutator { + public: + static Stmt Flatten(const PrimFunc& f) { return BufferFlattener().VisitStmt(f->body); } + + private: + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + ICHECK(realize->binding_values.empty()); + // Step 1. Visit the body + Block new_block = Downcast(this->VisitStmt(realize->block)); + PrimExpr predicate = this->VisitExpr(realize->predicate); + const BlockNode* block = new_block.get(); + // Step 2. Transform the `predicate` to if-then-else + Stmt body = block->body; + if (!is_one(predicate)) { + body = IfThenElse(predicate, body); } + // Step 3. Pick out blocks that writes with double buffering + if (IsDoubleBufferScope(block->annotations)) { + ICHECK_EQ(block->writes.size(), 1); + const Buffer& write = block->writes[0]->buffer; + double_buffered_.insert(write); + } + // Step 4. Handle allocations + for (const Buffer& buffer : block->alloc_buffers) { + body = MakeAllocStmt(buffer, body, double_buffered_.count(buffer)); + } + return body; + } - Stmt for_stmt; + Stmt VisitStmt_(const ForNode* op) final { + // Step 1. Visit recursively + PrimExpr min = this->VisitExpr(op->min); + PrimExpr extent = this->VisitExpr(op->extent); + Stmt body = this->VisitStmt(op->body); + // Step 2. Add the for loop accordingly if (op->kind == ForKind::kThreadBinding) { + // Case 1. Thread binding ICHECK(op->thread_binding.defined()); String thread_tag = op->thread_binding.value()->thread_tag; - if (!reduction_relative_.count(op->loop_var)) { - for_stmt = AttrStmt(IterVar(Range(op->min, op->extent), op->loop_var, - IterVarType::kThreadIndex, thread_tag), - thread_tag == "vthread" ? attr::virtual_thread : attr::thread_extent, - op->extent, body); - } else { - for_stmt = body; - } - } else if (is_one(op->extent) && op->annotations.empty()) { + body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body); + } else if (is_one(extent) && op->annotations.empty()) { + // Case 2. Handle unit loop return body; } else { - for_stmt = For(op->loop_var, op->min, op->extent, op->kind, body); + // Case 3. An ordinary loop + body = For(op->loop_var, min, extent, op->kind, body); } - + // Step 3. Handle annotations for (const auto& annotation : op->annotations) { - if (attr::IsPragmaKey(annotation.first)) { - for_stmt = AttrStmt(op->loop_var, annotation.first, Downcast(annotation.second), - for_stmt); + const String& ann_key = annotation.first; + const ObjectRef& ann_value = annotation.second; + if (attr::IsPragmaKey(ann_key)) { + body = AttrStmt(op->loop_var, ann_key, Downcast(ann_value), body); } } - - return for_stmt; + return body; } - Stmt VisitStmt_(const AttrStmtNode* op) final { return StmtMutator::VisitStmt_(op); } - Stmt VisitStmt_(const BufferStoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - ICHECK(op != nullptr); - auto begins = ComputeRelativeIndices(op->buffer, op->indices); - Buffer new_buffer = ReshapeBuffer(op->buffer, this->buffers_region_.at(op->buffer)); - return new_buffer.vstore(begins, op->value); + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + op = store.get(); + return op->buffer.vstore(op->indices, op->value); } PrimExpr VisitExpr_(const BufferLoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - auto begins = ComputeRelativeIndices(op->buffer, op->indices); - Buffer new_buffer = ReshapeBuffer(op->buffer, this->buffers_region_.at(op->buffer)); - return new_buffer.vload(begins, op->dtype); + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + op = load.get(); + return op->buffer.vload(op->indices, op->dtype); } PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::get_elem_offset())) { + // Handle `get_elem_offset` ICHECK_EQ(op->args.size(), 1); - const auto* buffer_load = op->args[0].as(); - ICHECK(buffer_load != nullptr); - Load load = Downcast(VisitExpr(op->args[0])); + PrimExpr arg = op->args[0]; + ICHECK(arg->IsInstance()); + arg = this->VisitExpr(arg); + const auto* load = TVM_TYPE_AS(load, arg, LoadNode); return load->index; - } else { - return StmtExprMutator::VisitExpr_(op); } + return StmtExprMutator::VisitExpr_(op); } - private: - const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& - buffers_region_; - const std::unordered_map& block_var_; - const std::unordered_map& unit_loops_; - const std::unordered_map& buffers_lca_; - const std::unordered_set& arg_buffers_; - - std::unordered_map pending_allocate_; - std::unordered_set reduction_relative_; - std::unordered_set double_buffer_; - Stmt parent_scope_; - - /*! - * \brief Create a buffer with alternative shape - */ - Buffer ReshapeBuffer(const Buffer& buffer, const std::vector& region) { - if (arg_buffers_.count(buffer)) return buffer; - auto n = runtime::make_object(*(buffer.operator->())); - Array shape; - for (const auto& i : region) { - shape.push_back(i.max() - i.min() + 1); + static Stmt MakeAllocStmt(const Buffer& buffer, Stmt body, bool is_double_buffer) { + if (IsReduceTempBuffer(buffer)) { + return body; } - n->shape = std::move(shape); - return Buffer(n); - } - - /*! - * \brief Transform indices from the absolute indices to relative indices - * \note T can be BufferLoad or BufferStore - */ - std::vector ComputeRelativeIndices(const Buffer& buffer, - const Array& indices) { - auto it = buffers_region_.find(buffer); - ICHECK(it != buffers_region_.end()); - const auto& region = it->second; - std::vector new_indices; - for (size_t i = 0; i < region.size(); ++i) { - if (arg_buffers_.count(buffer)) { - new_indices.push_back(indices[i]); - } else { - new_indices.push_back(indices[i] - region[i].min()); + String storage_scope = buffer->scope; + if (storage_scope.empty()) { + storage_scope = "global"; + } + PrimExpr area = BufferArea(buffer); + body = Allocate(buffer->data, buffer->dtype, {area}, const_true(), body); + body = AttrStmt(buffer->data, attr::storage_scope, StringImm(storage_scope), body); + if (is_double_buffer) { + body = AttrStmt(buffer->data, attr::double_buffer_scope, Integer(1), body); + } + return body; + } + + static Stmt MakeLaunchThread(const PrimExpr& min, const PrimExpr& extent, const Var& var, + const String& thread_tag, Stmt body) { + IterVar iter_var(/*dom=*/Range::FromMinExtent(min, extent), + /*var=*/var, + /*iter_type=*/IterVarType::kThreadIndex, + /*thread_tag=*/thread_tag); + String attr_key = thread_tag == "vthread" ? attr::virtual_thread : attr::thread_extent; + body = AttrStmt(iter_var, attr_key, extent, body); + return body; + } + + static bool IsDoubleBufferScope(const Map& annotations) { + if (Optional ann_value = annotations.Get(attr::double_buffer_scope)) { + const auto* value = TVM_TYPE_AS(value, ann_value, PrimExprNode); + if (is_one(GetRef(value))) { + return true; } } - return new_indices; + return false; } + + SSet double_buffered_; }; PrimFunc BufferFlatten(PrimFunc f) { - auto fptr = f.CopyOnWrite(); - - // Check memory and execution hierarchy + PrimFuncNode* fptr = f.CopyOnWrite(); + // Step 0. Check memory and execution hierarchy VerifyExecScope(f); - - // Transform the reduction calls to BufferStore - ReductionTransformer reduction_transformer; - fptr->body = reduction_transformer(fptr->body); - - // Find the LCA of each Buffer access - LCADetector lca_detector(fptr->buffer_map); - lca_detector(fptr->body); - - // Recalculate the buffer region - RegionGatherer region_gatherer(lca_detector.buffers_lca_, fptr->buffer_map); - region_gatherer(fptr->body); - - // Transform BufferLoad/BufferStore into Load/Store - BufferFlattener flattener(region_gatherer.block_var_, region_gatherer.unit_loops_, - region_gatherer.buffers_region_, lca_detector.buffers_lca_, - lca_detector.arg_buffers_); - fptr->body = flattener(fptr->body); - + // Step 1. Transform the reduction calls to BufferStore + fptr->body = ReductionTransformer::Transform(f); + // Step 2. Recalculate the buffer region + fptr->body = BufferAllocator::Alloc(f); + // Step 3. Transform BufferLoad/BufferStore into Load/Store + fptr->body = BufferFlattener::Flatten(f); return f; }