Skip to content

Commit

Permalink
Update CreateLocalStage
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Jan 20, 2022
1 parent 3bc6181 commit b27132f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
25 changes: 13 additions & 12 deletions src/tir/transforms/memhammer_intermediate_stage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Stmt CopyLoopChain(const std::vector<const ForNode*> loops, const Stmt& inner_bo
* \return a pair. The first is the transformed stmt.
* The second is the lowest thread binding loop.
*/
std::pair<Stmt, For> LiftThreadBindingLoops(Stmt stmt) {
std::pair<Stmt, Optional<For>> LiftThreadBindingLoops(Stmt stmt) {
std::vector<const ForNode*> normal_loops;
std::vector<const ForNode*> thread_binding_loops;
Stmt body = stmt;
Expand All @@ -54,10 +54,9 @@ std::pair<Stmt, For> LiftThreadBindingLoops(Stmt stmt) {
body = loop->body;
}
body = CopyLoopChain(normal_loops, std::move(body));
For compute_location;
For compute_location{nullptr};
body = CopyLoopChain(thread_binding_loops, std::move(body),
static_cast<int>(thread_binding_loops.size()) - 1, &compute_location);

return std::make_pair(body, compute_location);
}

Expand Down Expand Up @@ -338,14 +337,17 @@ std::pair<Stmt, SeqStmt> InsertCacheStage(Stmt stmt, bool is_write_cache, String
Array<Array<PrimExpr>> new_shape =
IndexPatternFinder::GetRankPromotedShape(indices, all_var_range, &rewrite_indices);
// Step 2. relax the access region after rank promotion
Region relaxed_region;
auto relax_var_intset = AsIntSet(relax_var_range);
arith::Analyzer analyzer;
analyzer.Bind(all_var_range);
for (const PrimExpr& index : rewrite_indices) {
auto int_set = arith::EvalSet(index, relax_var_intset);
relaxed_region.push_back(
Range::FromMinExtent(int_set.min(), analyzer.Simplify(int_set.max() - int_set.min() + 1)));
Array<Range> relaxed_region;
relaxed_region.reserve(rewrite_indices.size());
{
Map<Var, arith::IntSet> relax_var_intset = AsIntSet(relax_var_range);
for (const PrimExpr& index : rewrite_indices) {
arith::IntSet int_set = arith::EvalSet(index, relax_var_intset);
relaxed_region.push_back(Range::FromMinExtent(
int_set.min(), analyzer.Simplify(int_set.max() - int_set.min() + 1)));
}
}
// Step 3. generate the data copy bodies
// preparation work
Expand All @@ -368,8 +370,7 @@ std::pair<Stmt, SeqStmt> InsertCacheStage(Stmt stmt, bool is_write_cache, String
}
// Step 3.1 create a buffer for the cache
Buffer new_buffer = WithScope(orig_buffer, storage_scope);
BufferNode* buffer_ptr = new_buffer.CopyOnWrite();
buffer_ptr->shape = RankPromoter::FlattenNewShape(relaxed_new_shape);
new_buffer.CopyOnWrite()->shape = RankPromoter::FlattenNewShape(relaxed_new_shape);
*alloc_buffer = new_buffer;
Array<PrimExpr> real_orig_buf_indices =
RankPromoter::RewriteBackIndex(orig_buf_indices, new_shape);
Expand Down Expand Up @@ -413,7 +414,7 @@ std::pair<Stmt, SeqStmt> InsertCacheStage(Stmt stmt, bool is_write_cache, String
Stmt CreateLocalStage::Rewrite(const Stmt& stmt, const ConstraintSet& constraints,
OutputSet* output) const {
Stmt body;
For compute_location;
Optional<For> compute_location;
std::tie(body, compute_location) = LiftThreadBindingLoops(std::move(stmt));
Buffer cache_buffer;
Stmt after_caching = InsertCacheStage(body, false, "local", compute_location,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,4 +386,10 @@ def test_auto_padding():


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))
test_coalesce_vectorize()
test_inverse()
test_local_stage()
test_rewrite_shared_to_wmma()
test_rewrite_wmma_to_shared()
test_rewrite_wmma_to_global()
test_auto_padding()

0 comments on commit b27132f

Please sign in to comment.