From 030a0cd39a9503563487fb939e081a2171c23b61 Mon Sep 17 00:00:00 2001 From: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Date: Sat, 21 May 2022 16:55:23 -0400 Subject: [PATCH] fix (#1) --- src/arith/ir_mutator_with_analyzer.cc | 7 -- src/arith/ir_mutator_with_analyzer.h | 1 - .../schedule_rule/multi_level_tiling.cc | 74 +++++++++---------- src/tir/schedule/analysis.h | 3 +- src/tir/schedule/analysis/analysis.cc | 1 - 5 files changed, 36 insertions(+), 50 deletions(-) diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index c1e2388b0343..9cae3b7a6ac8 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -129,13 +129,6 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) { } } -Stmt IRMutatorWithAnalyzer::VisitStmt_(const BlockNode* op) { - for (const IterVar& iter_var : op->iter_vars) { - analyzer_->Bind(iter_var->var, iter_var->dom); - } - return StmtExprMutator::VisitStmt_(op); -} - PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { // add condition context to if_then_else static auto op_if_then_else = Op::Get("tir.if_then_else"); diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h index 01ff8a009314..3bd3a98a8445 100644 --- a/src/arith/ir_mutator_with_analyzer.h +++ b/src/arith/ir_mutator_with_analyzer.h @@ -55,7 +55,6 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator { tir::Stmt VisitStmt_(const tir::IfThenElseNode* op) override; tir::Stmt VisitStmt_(const tir::AttrStmtNode* op) override; tir::Stmt VisitStmt_(const tir::AssertStmtNode* op) override; - tir::Stmt VisitStmt_(const tir::BlockNode* op) override; PrimExpr VisitExpr_(const tir::LetNode* op) override; PrimExpr VisitExpr_(const tir::SelectNode* op) override; PrimExpr VisitExpr_(const tir::CallNode* op) override; diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 173660074c6f..cac22994fdaa 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -290,8 +290,6 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional MultiLevelTilingNode::TransformWithTensorIntrin(State& state, const String& intrin_name) const { - // Optional opt_tensorize_info = GetTensorizeLoopMapping( - // sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc); BlockRV block_rv = state.block_rv; Optional opt_layout_info = GetTensorizeLayoutInfo(state.sch->state(), state.sch->GetSRef(block_rv), @@ -310,62 +308,60 @@ Optional MultiLevelTilingNode::TransformWithTensorIntrin(State& state, c for (size_t i = 0; i < block->writes.size(); ++i) { buffers[block->writes[i]->buffer] = std::move(std::make_pair(i, false)); } - + // Reindex buffers and insert reindex stage state.tensor_core_reindex_store = state.sch->ReIndex(block_rv, 0, true); state.tensor_core_reindex_A = state.sch->ReIndex(block_rv, 0, false); state.tensor_core_reindex_B = state.sch->ReIndex(block_rv, 1, false); - state.sch->TransformBlockLayout(state.tensor_core_reindex_store.value(), info->mapping); - state.sch->TransformBlockLayout(state.tensor_core_reindex_A.value(), info->mapping); - state.sch->TransformBlockLayout(state.tensor_core_reindex_B.value(), info->mapping); - state.sch->TransformBlockLayout(state.block_rv, info->mapping); - - size_t offset = info->mapping->final_indices.size() - info->rhs_iters.size(); + // Transform the layout of reindex buffers accordingly + std::unordered_set unmapped_vars; + std::unordered_map representer_map; std::unordered_map tgt_iter_map; - + size_t offset = info->mapping->final_indices.size() - info->rhs_iters.size(); + ICHECK_EQ(info->lhs_iters.size(), info->mapping->initial_indices.size()); + for (size_t i = 0; i < info->lhs_iters.size(); ++i) { + representer_map[info->lhs_iters[i]->var] = info->mapping->initial_indices[i]; + } + for (size_t i = 0; i < offset; ++i) { + const tir::VarNode* var_ptr = info->mapping->final_indices[i].as(); + ICHECK(var_ptr != nullptr); + unmapped_vars.insert(Downcast(info->mapping->final_indices[i])); + } for (size_t i = offset; i < info->mapping->final_indices.size(); ++i) { tgt_iter_map[info->rhs_iters[i - offset]->var] = info->mapping->final_indices[i]; } - for (const auto& it : buffers) { // organize the mappings for buffer layout transformation const tir::Buffer& rhs_buffer = info->lhs_buffer_map[it.first]; - std::vector new_representers; - std::vector new_tgt_iters; - std::unordered_set covered; - auto collect = [&](const ObjectRef& obj) -> bool { - if (const tir::VarNode* var = obj.as()) { - covered.insert(GetRef(var)); - } - return true; - }; - // new target iters - for (const PrimExpr& index : info->lhs_indices_map[it.first]) { - tir::PreOrderVisit(index, collect); - } - for (size_t i = 0; i < offset; ++i) { - if (covered.count(info->lhs_iters[i]->var)) { - covered.insert(info->mapping->initial_indices[i]); - new_tgt_iters.push_back(info->mapping->final_indices[i]); + std::vector sub_representers; + std::vector sub_target_iters; + // Refresh block sref and handler + block_sref = state.sch->GetSRef(state.block_rv); + block = TVM_SREF_TO_BLOCK(block, block_sref); + const tir::BufferRegion& region = it.second.second ? block->reads[it.second.first] : block->writes[it.second.first]; + for (const Range& range : region->region) { + ICHECK(tir::is_one(range->extent)); + const tir::VarNode* var_ptr = range->min.as(); + ICHECK(var_ptr != nullptr); + sub_representers.push_back(representer_map[GetRef(var_ptr)]); + + if (unmapped_vars.find(GetRef(var_ptr)) != unmapped_vars.end()) { + sub_target_iters.push_back(GetRef(var_ptr)); } } for (size_t i = 0; i < info->rhs_indices_map[rhs_buffer].size(); ++i) { const tir::VarNode* var = info->rhs_indices_map[rhs_buffer][i].as(); ICHECK(var != nullptr); - new_tgt_iters.push_back(tgt_iter_map[GetRef(var)]); - tir::PreOrderVisit(new_tgt_iters.back(), collect); - } - // new representers - for (const auto& representer : info->mapping->initial_indices) { - if (covered.count(representer)) { - new_representers.push_back(representer); - } + sub_target_iters.push_back(tgt_iter_map[GetRef(var)]); } - LOG(INFO) << "TransformaLayout " << it.second.first << it.first << " " << rhs_buffer; state.sch->TransformLayout(state.block_rv, it.second.first, it.second.second ? tir::BufferIndexType::kRead : tir::BufferIndexType::kWrite, - tir::IndexMap(new_representers, new_tgt_iters)); - LOG(INFO) << "OK"; + tir::IndexMap(sub_representers, sub_target_iters)); } + // Transform the layout of current block and reindex blocks + state.sch->TransformBlockLayout(state.tensor_core_reindex_store.value(), info->mapping); + state.sch->TransformBlockLayout(state.tensor_core_reindex_A.value(), info->mapping); + state.sch->TransformBlockLayout(state.tensor_core_reindex_B.value(), info->mapping); + state.sch->TransformBlockLayout(state.block_rv, info->mapping); Array loops = state.sch->GetLoops(state.block_rv); return loops[loops.size() - info->rhs_iters.size()]; diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 2592aad5784a..182251663c52 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -785,12 +785,11 @@ class LayoutInfoNode : public Object { public: IndexMap mapping; Map lhs_buffer_map; - Map> lhs_indices_map, rhs_indices_map; + Map> rhs_indices_map; Array lhs_iters, rhs_iters; void VisitAttrs(AttrVisitor* v) { v->Visit("mapping", &mapping); - v->Visit("lhs_indices_map", &lhs_indices_map); v->Visit("rhs_indices_map", &rhs_indices_map); v->Visit("lhs_iters", &lhs_iters); v->Visit("rhs_iters", &rhs_iters); diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 034667d81926..d3169767ae85 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1312,7 +1312,6 @@ Optional GetTensorizeLayoutInfo(const tir::ScheduleState& self, // Only using 1 layout now ret->mapping = std::move(proposer.mappings_[0]); ret->lhs_buffer_map = std::move(proposer.lhs_buffer_map_); - ret->lhs_indices_map = std::move(extractor.lhs_buffer_indices_map_); ret->rhs_indices_map = std::move(extractor.rhs_buffer_indices_map_); ret->lhs_iters = std::move(extractor.lhs_iters_); ret->rhs_iters = std::move(extractor.rhs_iters_);