fix (#1)
spectrometerHBH authored and vinx13 committed Jun 2, 2022
1 parent 53fef8e commit 030a0cd
Showing 5 changed files with 36 additions and 50 deletions.
7 changes: 0 additions & 7 deletions src/arith/ir_mutator_with_analyzer.cc
@@ -129,13 +129,6 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) {
   }
 }

-Stmt IRMutatorWithAnalyzer::VisitStmt_(const BlockNode* op) {
-  for (const IterVar& iter_var : op->iter_vars) {
-    analyzer_->Bind(iter_var->var, iter_var->dom);
-  }
-  return StmtExprMutator::VisitStmt_(op);
-}
-
 PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) {
   // add condition context to if_then_else
   static auto op_if_then_else = Op::Get("tir.if_then_else");
1 change: 0 additions & 1 deletion src/arith/ir_mutator_with_analyzer.h
@@ -55,7 +55,6 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
   tir::Stmt VisitStmt_(const tir::IfThenElseNode* op) override;
   tir::Stmt VisitStmt_(const tir::AttrStmtNode* op) override;
   tir::Stmt VisitStmt_(const tir::AssertStmtNode* op) override;
-  tir::Stmt VisitStmt_(const tir::BlockNode* op) override;
   PrimExpr VisitExpr_(const tir::LetNode* op) override;
   PrimExpr VisitExpr_(const tir::SelectNode* op) override;
   PrimExpr VisitExpr_(const tir::CallNode* op) override;
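For context on what the two removals above drop: the deleted override bound every block iter_var's domain into the arith::Analyzer so that nested visits could simplify expressions under those ranges. The standalone sketch below shows that effect with TVM's analyzer API; the variable name and the [0, 16) range are illustrative and not taken from this commit.

#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>
#include <tvm/tir/var.h>

using namespace tvm;

int main() {
  arith::Analyzer analyzer;
  tir::Var i("i");
  // Bind i to the range [0, 16), as the removed VisitStmt_(BlockNode*) did for each iter_var.
  analyzer.Bind(i, Range::FromMinExtent(0, 16));
  // With the range known, the analyzer can fold range-dependent expressions...
  PrimExpr folded = analyzer.Simplify(floordiv(i, PrimExpr(16)));  // folds to 0
  // ...and prove range-dependent predicates.
  bool provable = analyzer.CanProve(i < PrimExpr(16));  // true
  (void)folded;
  (void)provable;
  return 0;
}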
74 changes: 35 additions & 39 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -290,8 +290,6 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<Str
 }

 Optional<LoopRV> MultiLevelTilingNode::TransformWithTensorIntrin(State& state, const String& intrin_name) const {
-  // Optional<tir::TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
-  //     sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc);
   BlockRV block_rv = state.block_rv;
   Optional<tir::LayoutInfo> opt_layout_info =
       GetTensorizeLayoutInfo(state.sch->state(), state.sch->GetSRef(block_rv),
@@ -310,62 +308,60 @@ Optional<LoopRV> MultiLevelTilingNode::TransformWithTensorIntrin(State& state, c
for (size_t i = 0; i < block->writes.size(); ++i) {
buffers[block->writes[i]->buffer] = std::move(std::make_pair(i, false));
}

// Reindex buffers and insert reindex stage
state.tensor_core_reindex_store = state.sch->ReIndex(block_rv, 0, true);
state.tensor_core_reindex_A = state.sch->ReIndex(block_rv, 0, false);
state.tensor_core_reindex_B = state.sch->ReIndex(block_rv, 1, false);
state.sch->TransformBlockLayout(state.tensor_core_reindex_store.value(), info->mapping);
state.sch->TransformBlockLayout(state.tensor_core_reindex_A.value(), info->mapping);
state.sch->TransformBlockLayout(state.tensor_core_reindex_B.value(), info->mapping);
state.sch->TransformBlockLayout(state.block_rv, info->mapping);

size_t offset = info->mapping->final_indices.size() - info->rhs_iters.size();
// Transform the layout of reindex buffers accordingly
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> unmapped_vars;
std::unordered_map<tir::Var, tir::Var, ObjectPtrHash, ObjectPtrEqual> representer_map;
std::unordered_map<tir::Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> tgt_iter_map;

size_t offset = info->mapping->final_indices.size() - info->rhs_iters.size();
ICHECK_EQ(info->lhs_iters.size(), info->mapping->initial_indices.size());
for (size_t i = 0; i < info->lhs_iters.size(); ++i) {
representer_map[info->lhs_iters[i]->var] = info->mapping->initial_indices[i];
}
for (size_t i = 0; i < offset; ++i) {
const tir::VarNode* var_ptr = info->mapping->final_indices[i].as<tir::VarNode>();
ICHECK(var_ptr != nullptr);
unmapped_vars.insert(Downcast<tir::Var>(info->mapping->final_indices[i]));
}
for (size_t i = offset; i < info->mapping->final_indices.size(); ++i) {
tgt_iter_map[info->rhs_iters[i - offset]->var] = info->mapping->final_indices[i];
}

for (const auto& it : buffers) {
// organize the mappings for buffer layout transformation
const tir::Buffer& rhs_buffer = info->lhs_buffer_map[it.first];
std::vector<tir::Var> new_representers;
std::vector<PrimExpr> new_tgt_iters;
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> covered;
auto collect = [&](const ObjectRef& obj) -> bool {
if (const tir::VarNode* var = obj.as<tir::VarNode>()) {
covered.insert(GetRef<tir::Var>(var));
}
return true;
};
// new target iters
for (const PrimExpr& index : info->lhs_indices_map[it.first]) {
tir::PreOrderVisit(index, collect);
}
for (size_t i = 0; i < offset; ++i) {
if (covered.count(info->lhs_iters[i]->var)) {
covered.insert(info->mapping->initial_indices[i]);
new_tgt_iters.push_back(info->mapping->final_indices[i]);
std::vector<tir::Var> sub_representers;
std::vector<PrimExpr> sub_target_iters;
// Refresh block sref and handler
block_sref = state.sch->GetSRef(state.block_rv);
block = TVM_SREF_TO_BLOCK(block, block_sref);
const tir::BufferRegion& region = it.second.second ? block->reads[it.second.first] : block->writes[it.second.first];
for (const Range& range : region->region) {
ICHECK(tir::is_one(range->extent));
const tir::VarNode* var_ptr = range->min.as<tir::VarNode>();
ICHECK(var_ptr != nullptr);
sub_representers.push_back(representer_map[GetRef<tir::Var>(var_ptr)]);

if (unmapped_vars.find(GetRef<tir::Var>(var_ptr)) != unmapped_vars.end()) {
sub_target_iters.push_back(GetRef<tir::Var>(var_ptr));
}
}
for (size_t i = 0; i < info->rhs_indices_map[rhs_buffer].size(); ++i) {
const tir::VarNode* var = info->rhs_indices_map[rhs_buffer][i].as<tir::VarNode>();
ICHECK(var != nullptr);
new_tgt_iters.push_back(tgt_iter_map[GetRef<tir::Var>(var)]);
tir::PreOrderVisit(new_tgt_iters.back(), collect);
}
// new representers
for (const auto& representer : info->mapping->initial_indices) {
if (covered.count(representer)) {
new_representers.push_back(representer);
}
sub_target_iters.push_back(tgt_iter_map[GetRef<tir::Var>(var)]);
}
LOG(INFO) << "TransformaLayout " << it.second.first << it.first << " " << rhs_buffer;
state.sch->TransformLayout(state.block_rv, it.second.first,
it.second.second ? tir::BufferIndexType::kRead : tir::BufferIndexType::kWrite,
tir::IndexMap(new_representers, new_tgt_iters));
LOG(INFO) << "OK";
tir::IndexMap(sub_representers, sub_target_iters));
}
// Transform the layout of current block and reindex blocks
state.sch->TransformBlockLayout(state.tensor_core_reindex_store.value(), info->mapping);
state.sch->TransformBlockLayout(state.tensor_core_reindex_A.value(), info->mapping);
state.sch->TransformBlockLayout(state.tensor_core_reindex_B.value(), info->mapping);
state.sch->TransformBlockLayout(state.block_rv, info->mapping);

Array<LoopRV> loops = state.sch->GetLoops(state.block_rv);
return loops[loops.size() - info->rhs_iters.size()];
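Stripped of the bookkeeping, the per-buffer loop above assembles a tir::IndexMap from representer variables and target-iter expressions and hands it to TransformLayout. The condensed sketch below shows only that call pattern; the function name, the 2-D buffer, and the 16x16 tiling are illustrative assumptions, not code from this commit.

#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <tvm/tir/schedule/schedule.h>

using namespace tvm;

// Rewrite the layout of the first read buffer of `block`, mirroring the
// TransformLayout call issued for each reindex buffer in the diff above.
void TileReadBufferLayout(tir::Schedule sch, const tir::BlockRV& block) {
  // Representer variables stand for the buffer's original axes.
  tir::Var i("i"), j("j");
  PrimExpr tile(16);
  // Target indices re-express each axis as (outer, inner) tile coordinates.
  Array<PrimExpr> target = {floordiv(i, tile), floordiv(j, tile),
                            floormod(i, tile), floormod(j, tile)};
  sch->TransformLayout(block, /*buffer_index=*/0, tir::BufferIndexType::kRead,
                       tir::IndexMap({i, j}, target));
}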
3 changes: 1 addition & 2 deletions src/tir/schedule/analysis.h
@@ -785,12 +785,11 @@ class LayoutInfoNode : public Object {
  public:
   IndexMap mapping;
   Map<Buffer, Buffer> lhs_buffer_map;
-  Map<Buffer, Array<PrimExpr>> lhs_indices_map, rhs_indices_map;
+  Map<Buffer, Array<PrimExpr>> rhs_indices_map;
   Array<IterVar> lhs_iters, rhs_iters;

   void VisitAttrs(AttrVisitor* v) {
     v->Visit("mapping", &mapping);
-    v->Visit("lhs_indices_map", &lhs_indices_map);
     v->Visit("rhs_indices_map", &rhs_indices_map);
     v->Visit("lhs_iters", &lhs_iters);
     v->Visit("rhs_iters", &rhs_iters);
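For background on the node being trimmed here: TVM object nodes expose their fields to reflection through VisitAttrs and are passed around via an ObjectRef handle (LayoutInfo wrapping LayoutInfoNode). A minimal, hypothetical node following the same pattern is sketched below; the names and fields are illustrative, not part of this file.

#include <tvm/ir/expr.h>
#include <tvm/node/node.h>
#include <tvm/runtime/container/array.h>

namespace tvm {

// Hypothetical node mirroring the LayoutInfoNode pattern above.
class ExampleInfoNode : public Object {
 public:
  PrimExpr extent;
  Array<PrimExpr> indices;

  void VisitAttrs(AttrVisitor* v) {
    // Every field registered here becomes visible to reflection and serialization.
    v->Visit("extent", &extent);
    v->Visit("indices", &indices);
  }

  static constexpr const char* _type_key = "example.ExampleInfo";
  TVM_DECLARE_FINAL_OBJECT_INFO(ExampleInfoNode, Object);
};

// Reference wrapper used to pass the node by handle, as LayoutInfo does.
class ExampleInfo : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(ExampleInfo, ObjectRef, ExampleInfoNode);
};

}  // namespace tvm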
1 change: 0 additions & 1 deletion src/tir/schedule/analysis/analysis.cc
@@ -1312,7 +1312,6 @@ Optional<LayoutInfo> GetTensorizeLayoutInfo(const tir::ScheduleState& self,
   // Only using 1 layout now
   ret->mapping = std::move(proposer.mappings_[0]);
   ret->lhs_buffer_map = std::move(proposer.lhs_buffer_map_);
-  ret->lhs_indices_map = std::move(extractor.lhs_buffer_indices_map_);
   ret->rhs_indices_map = std::move(extractor.rhs_buffer_indices_map_);
   ret->lhs_iters = std::move(extractor.lhs_iters_);
   ret->rhs_iters = std::move(extractor.rhs_iters_);
