diff --git a/src/tir/transforms/memhammer_intermediate_stage.cc b/src/tir/transforms/memhammer_intermediate_stage.cc
index 3503b371bdc87..1b339620d3a5a 100644
--- a/src/tir/transforms/memhammer_intermediate_stage.cc
+++ b/src/tir/transforms/memhammer_intermediate_stage.cc
@@ -53,9 +53,9 @@ std::pair<Stmt, For> LiftThreadBindingLoops(Stmt stmt) {
     }
     body = loop->body;
   }
-  body = CopyLoopChain(normal_loops, body);
+  body = CopyLoopChain(normal_loops, std::move(body));
   For compute_location;
-  body = CopyLoopChain(thread_binding_loops, body,
+  body = CopyLoopChain(thread_binding_loops, std::move(body),
                        static_cast<int>(thread_binding_loops.size()) - 1, &compute_location);
   return std::make_pair(body, compute_location);
 }
@@ -85,7 +85,7 @@ class IndexPatternFinder : public ExprVisitor {
    * \param rewrite_indices The access indices after rank promotion
    * \return The new buffer shape after rank promotion.
    */
-  static Array<Array<PrimExpr>> getRankPromotedShape(Array<PrimExpr> indices,
+  static Array<Array<PrimExpr>> GetRankPromotedShape(Array<PrimExpr> indices,
                                                      const Map<Var, Range>& var_range,
                                                      Array<PrimExpr>* rewrite_indices) {
     Map<Var, arith::IntSet> var_dom = AsIntSet(var_range);
@@ -169,7 +169,7 @@ class RankPromoter : public StmtExprMutator {
   /*!
    * \brief Flatten the buffer shape like performing inverse rank promotion.
    * For example, [[i0, i1], [j0, j1]] to [i0 * i1, j0 * j1]
-   * \param new_shape The buffer shape in the special form as returned by getRankPromotedShape
+   * \param new_shape The buffer shape in the special form as returned by GetRankPromotedShape
    * \return The buffer shape after flatten
    */
   static Array<PrimExpr> FlattenNewShape(const Array<Array<PrimExpr>>& new_shape) {
@@ -271,7 +271,7 @@ class RankPromoter : public StmtExprMutator {
   /*!
    * \brief Rewrite the indices after performing buffer rank promotion +
    * buffer compacting + buffer flattening.
-   * \param indices The origina indices
+   * \param indices The original indices
    * \return The indices after these transformations
    */
   Array<PrimExpr> ConvertIndices(const Array<PrimExpr>& indices) {
@@ -290,20 +290,9 @@ class RankPromoter : public StmtExprMutator {
   Array<Range> relaxed_region_;
 };
 
-/*!
- * \brief Insert a cache stage to the compute location
- * \param stmt the stmt
- * \param is_write_cache whether to write a read cache or write cache
- * \param storage_scope the storage scope of the new cache
- * \param compute_location the compute location.
- * \param outer_loops the outer loops of this stmt
- * \param alloc_buffer the new cache block
- * \return a pair. The first is the stmt after transformation.
- * The second is the SeqStmt that contains 2 stages (one original and another inserted).
- */
 std::pair<Stmt, SeqStmt> InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope,
-                                          For compute_location, const Array<For>& outer_loops,
-                                          Buffer* alloc_buffer) {
+                                          Optional<For> compute_location,
+                                          const Array<For>& outer_loops, Buffer* alloc_buffer) {
   Stmt body = stmt;
   std::vector<const ForNode*> loops;
   bool need_relax = !compute_location.defined();
@@ -340,13 +329,14 @@ std::pair<Stmt, SeqStmt> InsertCacheStage(Stmt stmt, bool is_write_cache, String
   }
   const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode);
+  // TODO: the assumption that the RHS of BufferStore is BufferLoad may not be accurate
   const BufferLoadNode* buf_load = TVM_TYPE_AS(buf_load, buf_store->value, BufferLoadNode);
   Buffer orig_buffer = is_write_cache ? buf_store->buffer : buf_load->buffer;
   Array<PrimExpr> indices = is_write_cache ? buf_store->indices : buf_load->indices;
   // Step 1.2 get the new shape and new access indices after rank promotion
   Array<PrimExpr> rewrite_indices;
   Array<Array<PrimExpr>> new_shape =
-      IndexPatternFinder::getRankPromotedShape(indices, all_var_range, &rewrite_indices);
+      IndexPatternFinder::GetRankPromotedShape(indices, all_var_range, &rewrite_indices);
   // Step 2. relax the access region after rank promotion
   Region relaxed_region;
   auto relax_var_intset = AsIntSet(relax_var_range);
@@ -404,7 +394,7 @@ std::pair<Stmt, SeqStmt> InsertCacheStage(Stmt stmt, bool is_write_cache, String
   // Step 3.3 rewrite the original body to load from cache
   Stmt rewrite_body;
   if (compute_location.defined()) {
-    rewrite_body = compute_location->body;
+    rewrite_body = compute_location.value()->body;
   } else {
     rewrite_body = stmt;
   }
diff --git a/src/tir/transforms/memhammer_rewrite_rule.h b/src/tir/transforms/memhammer_rewrite_rule.h
index cc074643619a0..1cb0ea496a030 100644
--- a/src/tir/transforms/memhammer_rewrite_rule.h
+++ b/src/tir/transforms/memhammer_rewrite_rule.h
@@ -211,5 +211,20 @@ class WmmaToShared : public RewriteRule {
   }
 };
 
+/*!
+ * \brief Insert a cache stage to the compute location
+ * \param stmt the stmt
+ * \param is_write_cache whether to write a read cache or write cache
+ * \param storage_scope the storage scope of the new cache
+ * \param compute_location the compute location.
+ * \param outer_loops the outer loops of this stmt
+ * \param alloc_buffer the new cache block
+ * \return a pair. The first is the stmt after transformation.
+ * The second is the SeqStmt that contains 2 stages (one original and another inserted).
+ */
+std::pair<Stmt, SeqStmt> InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope,
+                                          Optional<For> compute_location,
+                                          const Array<For>& outer_loops, Buffer* alloc_buffer);
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/transforms/memhammer_tensorcore_rewrite.cc b/src/tir/transforms/memhammer_tensorcore_rewrite.cc
index e97e89488c30c..c7714625c3319 100644
--- a/src/tir/transforms/memhammer_tensorcore_rewrite.cc
+++ b/src/tir/transforms/memhammer_tensorcore_rewrite.cc
@@ -237,10 +237,6 @@ class WmmaToGlobalRewriter : public StmtExprMutator {
   const ConstraintSet& constraints_;
 };
 
-std::pair<Stmt, SeqStmt> InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope,
-                                          For compute_location, const Array<For>& outer_loops,
-                                          Buffer* alloc_buffer);
-
 Stmt WmmaToGlobal::Rewrite(const Stmt& stmt, const ConstraintSet& constraints,
                            OutputSet* output) const {
   Stmt body;