Lift the definition of InsertCacheStage
junrushao committed Jan 19, 2022
1 parent 686eca2 commit 3bc6181
Showing 3 changed files with 25 additions and 24 deletions.
30 changes: 10 additions & 20 deletions src/tir/transforms/memhammer_intermediate_stage.cc
@@ -53,9 +53,9 @@ std::pair<Stmt, For> LiftThreadBindingLoops(Stmt stmt) {
}
body = loop->body;
}
- body = CopyLoopChain(normal_loops, body);
+ body = CopyLoopChain(normal_loops, std::move(body));
For compute_location;
- body = CopyLoopChain(thread_binding_loops, body,
+ body = CopyLoopChain(thread_binding_loops, std::move(body),
static_cast<int>(thread_binding_loops.size()) - 1, &compute_location);

return std::make_pair(body, compute_location);
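The two std::move additions in this hunk are a small reference-counting optimization: Stmt is a TVM ObjectRef, so passing it by value copies a reference-counted handle, while moving it only transfers the underlying pointer. A minimal sketch of the difference, not part of the commit (Consume and Example are hypothetical names used for illustration):

#include <utility>
#include <tvm/tir/stmt.h>

using tvm::tir::Stmt;

void Consume(Stmt s) {}        // hypothetical sink that takes the ref by value

void Example(Stmt body) {
  Consume(body);               // copy: the handle's reference count is bumped and later dropped
  Consume(std::move(body));    // move: the pointer is transferred, no refcount traffic
}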
@@ -85,7 +85,7 @@ class IndexPatternFinder : public ExprVisitor {
* \param rewrite_indices The access indices after rank promotion
* \return The new buffer shape after rank promotion.
*/
- static Array<Array<PrimExpr>> getRankPromotedShape(Array<PrimExpr> indices,
+ static Array<Array<PrimExpr>> GetRankPromotedShape(Array<PrimExpr> indices,
const Map<Var, Range>& var_range,
Array<PrimExpr>* rewrite_indices) {
Map<Var, arith::IntSet> var_dom = AsIntSet(var_range);
@@ -169,7 +169,7 @@ class RankPromoter : public StmtExprMutator {
/*!
* \brief Flatten the buffer shape like performing inverse rank promotion.
* For example, [[i0, i1], [j0, j1]] to [i0 * i1, j0 * j1]
- * \param new_shape The buffer shape in the special form as returned by getRankPromotedShape
+ * \param new_shape The buffer shape in the special form as returned by GetRankPromotedShape
* \return The buffer shape after flatten
*/
static Array<PrimExpr> FlattenNewShape(const Array<Array<PrimExpr>>& new_shape) {
@@ -271,7 +271,7 @@ class RankPromoter : public StmtExprMutator {
/*!
* \brief Rewrite the indices after performing buffer rank promotion +
* buffer compacting + buffer flattening.
- * \param indices The origina indices
+ * \param indices The original indices
* \return The indices after these transformations
*/
Array<PrimExpr> ConvertIndices(const Array<PrimExpr>& indices) {
@@ -290,20 +290,9 @@
Array<Range> relaxed_region_;
};

- /*!
- * \brief Insert a cache stage to the compute location
- * \param stmt the stmt
- * \param is_write_cache whether to write a read cache or write cache
- * \param storage_scope the storage scope of the new cache
- * \param compute_location the compute location.
- * \param outer_loops the outer loops of this stmt
- * \param alloc_buffer the new cache block
- * \return a pair. The first is the stmt after transformation.
- * The second is the SeqStmt that contains 2 stages (one original and another inserted).
- */
std::pair<Stmt, SeqStmt> InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope,
- For compute_location, const Array<For>& outer_loops,
- Buffer* alloc_buffer) {
+ Optional<For> compute_location,
+ const Array<For>& outer_loops, Buffer* alloc_buffer) {
Stmt body = stmt;
std::vector<const ForNode*> loops;
bool need_relax = !compute_location.defined();
@@ -340,13 +329,14 @@ std::pair<Stmt, SeqStmt> InsertCacheStage(Stmt stmt, bool is_write_cache, String
}

const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode);
+ // TODO: the assumption that the RHS of BufferStore is BufferLoad may not be accurate
const BufferLoadNode* buf_load = TVM_TYPE_AS(buf_load, buf_store->value, BufferLoadNode);
Buffer orig_buffer = is_write_cache ? buf_store->buffer : buf_load->buffer;
Array<PrimExpr> indices = is_write_cache ? buf_store->indices : buf_load->indices;
// Step 1.2 get the new shape and new access indices after rank promotion
Array<PrimExpr> rewrite_indices;
Array<Array<PrimExpr>> new_shape =
- IndexPatternFinder::getRankPromotedShape(indices, all_var_range, &rewrite_indices);
+ IndexPatternFinder::GetRankPromotedShape(indices, all_var_range, &rewrite_indices);
// Step 2. relax the access region after rank promotion
Region relaxed_region;
auto relax_var_intset = AsIntSet(relax_var_range);
@@ -404,7 +394,7 @@ std::pair<Stmt, SeqStmt> InsertCacheStage(Stmt stmt, bool is_write_cache, String
// Step 3.3 rewrite the original body to load from cache
Stmt rewrite_body;
if (compute_location.defined()) {
- rewrite_body = compute_location->body;
+ rewrite_body = compute_location.value()->body;
} else {
rewrite_body = stmt;
}
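The last two hunks follow from the signature change from For to Optional<For>: an unset compute_location must now be checked with defined() and unwrapped with value() before its body is read. A minimal sketch of that access pattern using TVM's runtime Optional (BodyOrFallback is a hypothetical helper, not part of the commit):

#include <tvm/runtime/container/optional.h>
#include <tvm/tir/stmt.h>

using tvm::runtime::Optional;
using tvm::tir::For;
using tvm::tir::Stmt;

// Hypothetical helper: return the loop body when a compute location is given,
// otherwise fall back to the original statement (mirroring Step 3.3 above).
Stmt BodyOrFallback(const Optional<For>& compute_location, Stmt original) {
  if (compute_location.defined()) {
    return compute_location.value()->body;  // unwrap the optional before dereferencing
  }
  return original;  // no compute location: keep the statement unchanged
}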
15 changes: 15 additions & 0 deletions src/tir/transforms/memhammer_rewrite_rule.h
@@ -211,5 +211,20 @@ class WmmaToShared : public RewriteRule {
}
};

+ /*!
+ * \brief Insert a cache stage at the compute location
+ * \param stmt the statement to transform
+ * \param is_write_cache whether to create a write cache (true) or a read cache (false)
+ * \param storage_scope the storage scope of the new cache
+ * \param compute_location the loop under which the cache stage is computed; may be undefined (NullOpt)
+ * \param outer_loops the outer loops of this stmt
+ * \param alloc_buffer the newly allocated cache buffer (output parameter)
+ * \return a pair. The first is the stmt after transformation.
+ * The second is the SeqStmt that contains 2 stages (one original and another inserted).
+ */
+ std::pair<Stmt, SeqStmt> InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope,
+ Optional<For> compute_location,
+ const Array<For>& outer_loops, Buffer* alloc_buffer);
+
} // namespace tir
} // namespace tvm
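With the declaration now exported from memhammer_rewrite_rule.h, any rewrite rule in the pass can call InsertCacheStage directly instead of re-declaring it locally. A hedged usage sketch follows; the caller name, the "shared" storage scope, and the surrounding variables are assumptions for illustration, not code from this commit:

#include "memhammer_rewrite_rule.h"

namespace tvm {
namespace tir {

// Hypothetical caller: insert a write cache in shared memory, with no
// particular compute location (NullOpt means the access region is relaxed
// over the outer loops).
Stmt CacheWriteToShared(Stmt stmt, const Array<For>& outer_loops) {
  Buffer cache_buffer;  // output parameter, filled in by InsertCacheStage
  std::pair<Stmt, SeqStmt> result =
      InsertCacheStage(std::move(stmt), /*is_write_cache=*/true,
                       /*storage_scope=*/"shared", /*compute_location=*/NullOpt,
                       outer_loops, &cache_buffer);
  return result.first;  // result.second is the SeqStmt holding both stages
}

}  // namespace tir
}  // namespace tvm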
4 changes: 0 additions & 4 deletions src/tir/transforms/memhammer_tensorcore_rewrite.cc
@@ -237,10 +237,6 @@ class WmmaToGlobalRewriter : public StmtExprMutator {
const ConstraintSet& constraints_;
};

- std::pair<Stmt, SeqStmt> InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope,
- For compute_location, const Array<For>& outer_loops,
- Buffer* alloc_buffer);

Stmt WmmaToGlobal::Rewrite(const Stmt& stmt, const ConstraintSet& constraints,
OutputSet* output) const {
Stmt body;
