diff --git a/include/tvm/tir/schedule/block_scope.h b/include/tvm/tir/schedule/block_scope.h index c708528288..51f511384b 100644 --- a/include/tvm/tir/schedule/block_scope.h +++ b/include/tvm/tir/schedule/block_scope.h @@ -113,7 +113,7 @@ class StmtSRef : public runtime::ObjectRef { }; /*! \brief Type of dependency */ -enum class DepType : int32_t { +enum class DepKind : int32_t { kRAW = 0, kWAW = 1, kWAR = 2, @@ -121,40 +121,43 @@ enum class DepType : int32_t { }; /*! \brief An edge representing certain types of dependency, e.g. read-after-write */ -class DepEdgeNode : public runtime::Object { +class DependencyNode : public runtime::Object { public: - /*! \brief The destination block */ + /*! \brief The source of the dependency relation */ + StmtSRef src; + /*! \brief The destination of the dependency relation */ StmtSRef dst; - /*! \brief The dependency type */ - DepType type; + /*! \brief The dependency kind */ + DepKind kind; void VisitAttrs(AttrVisitor* v) { + v->Visit("src", &src); v->Visit("dst", &dst); - v->Visit("type", &type); + v->Visit("kind", &kind); } - static constexpr const char* _type_key = "tir.DepEdge"; - TVM_DECLARE_FINAL_OBJECT_INFO(DepEdgeNode, Object); + static constexpr const char* _type_key = "tir.Dependency"; + TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object); }; /*! - * \brief Managed reference to DepEdgeNode - * \sa DepEdgeNode + * \brief Managed reference to DependencyNode + * \sa DependencyNode */ -class DepEdge : public runtime::ObjectRef { +class Dependency : public runtime::ObjectRef { public: /*! \brief Constructor */ - explicit DepEdge(StmtSRef dst, DepType type); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DepEdge, ObjectRef, DepEdgeNode); + explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode); }; /*! \brief An object recording the producer-consumer dependency between child blocks of a scope */ class BlockScopeNode : public runtime::Object { public: /*! \brief The forward dependency edges of the block */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> forward_edges; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> src2deps; /*! \brief The backward dependency edges of the block */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> backward_edges; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dst2deps; /*! \brief The mapping from the buffer to the blocks who write it */ std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_writers; @@ -164,19 +167,19 @@ class BlockScopeNode : public runtime::Object { TVM_DECLARE_FINAL_OBJECT_INFO(BlockScopeNode, runtime::Object); public: - /******** Dependency ********/ + /******** DependencyNode ********/ /*! * \brief Get all blocks the block depends on * \param block_sref The queried block * \return The predecessors edges */ - TVM_DLL Array GetPredecessors(const StmtSRef& block_sref) const; + TVM_DLL Array GetDepsBySrc(const StmtSRef& block_sref) const; /*! * \brief Get all blocks that depends on the block * \param block_sref The queried block * \return The successor edges */ - TVM_DLL Array GetSuccessors(const StmtSRef& block_sref) const; + TVM_DLL Array GetDepsByDst(const StmtSRef& block_sref) const; /******** Property of a block ********/ /*! diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h index 5a5ed8055a..60321cdd7a 100644 --- a/include/tvm/tir/schedule/state.h +++ b/include/tvm/tir/schedule/state.h @@ -73,12 +73,10 @@ class ScheduleStateNode : public runtime::Object { * * \param src_sref The sref to the statement to be replaced * \param tgt_stmt The statement to be replaced to - * \param block_sref_reuse Maps an new block (replaced to) back to an old block (to be replaced), + * \param block_sref_reuse Maps an old block (to be replaced) to a new block (replaced to), * and enforces reuse of srefs between them (rather than create new srefs) * i.e. after being replaced, the sref that points to the old block will point to the new one * \note `loop_sref_reuse` will be automatically detected via loop vars - * - * TODO(@junrushao1994): change `block_sref_reuse` from "new -> old" to "old -> new" */ TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt, const Map& block_sref_reuse); diff --git a/python/tvm/tir/schedule.py b/python/tvm/tir/schedule.py index a6364c7bfe..d4fdaff135 100644 --- a/python/tvm/tir/schedule.py +++ b/python/tvm/tir/schedule.py @@ -47,8 +47,8 @@ def inline_mark() -> StmtSRef: return _ffi_api_schedule.StmtSRefInlineMark() # pylint: disable=no-member -@_register_object("tir.DepEdge") -class DepEdge(Object): +@_register_object("tir.Dependency") +class Dependency(Object): """An edge in the dependency graph""" kRAW = 0 @@ -56,15 +56,16 @@ class DepEdge(Object): kWAR = 2 kOpaque = 3 + src: StmtSRef dst: StmtSRef - type: int + kind: int @_register_object("tir.BlockScope") class BlockScope(Object): """Dependency Graph that stores read/write dependency between Blocks""" - def get_predecessors(self, block: StmtSRef) -> List[DepEdge]: + def get_deps_by_src(self, block: StmtSRef) -> List[Dependency]: """Get the dependency predecessors of the block Parameters @@ -74,12 +75,12 @@ def get_predecessors(self, block: StmtSRef) -> List[DepEdge]: Returns ------- - blocks: List of DepEdge + blocks: List of Dependency The predecessors of the block """ - return _ffi_api_schedule.BlockScopeGetPredecessors(self, block) # pylint: disable=no-member + return _ffi_api_schedule.BlockScopeGetDepsBySrc(self, block) # pylint: disable=no-member - def get_successor(self, block: StmtSRef) -> List[DepEdge]: + def get_deps_by_dst(self, block: StmtSRef) -> List[Dependency]: """Get the dependency successor of the block Parameters @@ -89,10 +90,10 @@ def get_successor(self, block: StmtSRef) -> List[DepEdge]: Returns ------- - blocks: List of DepEdge + blocks: List of Dependency The predecessors of the block """ - return _ffi_api_schedule.BlockScopeGetSuccessors(self, block) # pylint: disable=no-member + return _ffi_api_schedule.BlockScopeGetDepsByDst(self, block) # pylint: disable=no-member @_register_object("tir.ScheduleState") diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 1486d9f6a6..ce960edcfa 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -269,7 +269,7 @@ inline void DelAnn(const tir::ScheduleState& sch, const tir::StmtSRef& sref, ObjectPtr n = make_object(*block); n->annotations = std::move(new_ann); tir::Block p(n); - sch->Replace(sref, p, {{p, GetRef(block)}}); + sch->Replace(sref, p, {{GetRef(block), p}}); } else { LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); throw; @@ -303,7 +303,7 @@ inline void AddAnn(const tir::ScheduleState& sch, const tir::StmtSRef& sref, con ObjectPtr n = make_object(*block); n->annotations = std::move(new_ann); tir::Block p(n); - sch->Replace(sref, p, {{p, GetRef(block)}}); + sch->Replace(sref, p, {{GetRef(block), p}}); } else { LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); throw; diff --git a/src/tir/schedule/analysis.cc b/src/tir/schedule/analysis.cc index 77e45a5122..2579b1b18d 100644 --- a/src/tir/schedule/analysis.cc +++ b/src/tir/schedule/analysis.cc @@ -95,13 +95,13 @@ void VerifyRegionCover(const ScheduleState& self, const StmtSRef& consumer_block // Maps a buffer var to its producers std::unordered_map> buffer_producers; // Collect all producers to a buffer by enumerating all RAW predecessors of the consumer - for (const DepEdge& edge : - self->scopes.at(parent_block_sref)->GetPredecessors(consumer_block_sref)) { - if (edge->type != DepType::kRAW) { + for (const Dependency& edge : + self->scopes.at(parent_block_sref)->GetDepsByDst(consumer_block_sref)) { + if (edge->kind != DepKind::kRAW) { continue; } // i.e. the RAW predecessor is producer - const StmtSRef& producer_block_sref = edge->dst; + const StmtSRef& producer_block_sref = edge->src; for (const BufferRegion& output_region : producer_block_sref->GetStmt()->writes) { const VarNode* buffer_var = output_region->buffer->data.get(); buffer_producers[buffer_var].emplace_back(producer_block_sref, output_region); @@ -340,27 +340,27 @@ Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent } Array GetProducers(const ScheduleState& self, const StmtSRef& block_sref) { - Array pred_edges = self->scopes - .at(GetScopeRoot(block_sref)) // - ->GetPredecessors(block_sref); + Array pred_edges = self->scopes + .at(GetScopeRoot(block_sref)) // + ->GetDepsByDst(block_sref); Array results; results.reserve(pred_edges.size()); - for (const DepEdge edge : pred_edges) { - if (edge->type == DepType::kRAW || edge->type == DepType::kWAW) { - results.push_back(edge->dst); + for (const Dependency& edge : pred_edges) { + if (edge->kind == DepKind::kRAW || edge->kind == DepKind::kWAW) { + results.push_back(edge->src); } } return results; } Array GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) { - Array succ_edges = self->scopes - .at(GetScopeRoot(block_sref)) // - ->GetSuccessors(block_sref); + Array succ_edges = self->scopes + .at(GetScopeRoot(block_sref)) // + ->GetDepsBySrc(block_sref); Array results; results.reserve(succ_edges.size()); - for (const DepEdge edge : succ_edges) { - if (edge->type == DepType::kRAW || edge->type == DepType::kWAW) { + for (const Dependency& edge : succ_edges) { + if (edge->kind == DepKind::kRAW || edge->kind == DepKind::kWAW) { results.push_back(edge->dst); } } diff --git a/src/tir/schedule/block_scope.cc b/src/tir/schedule/block_scope.cc index 448d9a5e0d..43a34b9199 100644 --- a/src/tir/schedule/block_scope.cc +++ b/src/tir/schedule/block_scope.cc @@ -30,14 +30,15 @@ using TBufferReaderWriter = /*! * \brief Add a dependency edge. - * \param from The source of the dependency - * \param to The destination of the dependecy - * \param type Type of the dependency + * \param src The source of the dependency + * \param dst The destination of the dependecy + * \param kind Type of the dependency */ -void AddEdge(BlockScopeNode* self, const StmtSRef& from, const StmtSRef& to, DepType type) { - if (!from.same_as(to)) { - self->forward_edges[from].push_back(DepEdge(to, type)); - self->backward_edges[to].push_back(DepEdge(from, type)); +void AddEdge(BlockScopeNode* self, const StmtSRef& src, const StmtSRef& dst, DepKind kind) { + if (!src.same_as(dst)) { + Dependency dep(src, dst, kind); + self->src2deps[src].push_back(dep); + self->dst2deps[dst].push_back(dep); } } @@ -67,7 +68,7 @@ void AddChildBlock(BlockScopeNode* self, const StmtSRef& child_sref, for (const BufferRegion& region : block->reads) { if (buffer_writers.count(region->buffer)) { for (const StmtSRef& from : buffer_writers[region->buffer]) { - AddEdge(self, from, child_sref, DepType::kRAW); + AddEdge(self, from, child_sref, DepKind::kRAW); } } } @@ -75,7 +76,7 @@ void AddChildBlock(BlockScopeNode* self, const StmtSRef& child_sref, for (const BufferRegion& region : block->writes) { if (buffer_writers.count(region->buffer)) { for (const StmtSRef& from : buffer_writers[region->buffer]) { - AddEdge(self, from, child_sref, DepType::kWAW); + AddEdge(self, from, child_sref, DepKind::kWAW); } } } @@ -140,10 +141,11 @@ StmtSRef StmtSRef::RootMark() { return result; } -DepEdge::DepEdge(StmtSRef dst, DepType type) { - ObjectPtr node = make_object(); +Dependency::Dependency(StmtSRef src, StmtSRef dst, DepKind kind) { + ObjectPtr node = make_object(); + node->src = std::move(src); node->dst = std::move(dst); - node->type = type; + node->kind = kind; data_ = std::move(node); } @@ -160,9 +162,9 @@ BlockScope::BlockScope(const Array& leaf_block_srefs) { /******** Dependency ********/ -Array BlockScopeNode::GetPredecessors(const StmtSRef& block_sref) const { - const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& edges = - this->backward_edges; +Array BlockScopeNode::GetDepsBySrc(const StmtSRef& block_sref) const { + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& edges = + this->src2deps; auto iter = edges.find(block_sref); if (iter != edges.end()) { return iter->second; @@ -171,9 +173,9 @@ Array BlockScopeNode::GetPredecessors(const StmtSRef& block_sref) const } } -Array BlockScopeNode::GetSuccessors(const StmtSRef& block_sref) const { - const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& edges = - this->forward_edges; +Array BlockScopeNode::GetDepsByDst(const StmtSRef& block_sref) const { + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& edges = + this->dst2deps; auto iter = edges.find(block_sref); if (iter != edges.end()) { return iter->second; @@ -336,17 +338,17 @@ bool BlockScopeNode::CanMergeReduction(const StmtSRef& init_sref, /******** FFI ********/ TVM_REGISTER_NODE_TYPE(StmtSRefNode); -TVM_REGISTER_NODE_TYPE(DepEdgeNode); +TVM_REGISTER_NODE_TYPE(DependencyNode); TVM_REGISTER_NODE_TYPE(BlockScopeNode); TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefRootMark") // .set_body_typed(StmtSRef::RootMark); TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefInlineMark") // .set_body_typed(StmtSRef::InlineMark); -TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetPredecessors") - .set_body_method(&BlockScopeNode::GetPredecessors); -TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetSuccessors") - .set_body_method(&BlockScopeNode::GetSuccessors); +TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsBySrc") + .set_body_method(&BlockScopeNode::GetDepsBySrc); +TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsByDst") + .set_body_method(&BlockScopeNode::GetDepsByDst); TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefStmt") .set_body_typed([](StmtSRef sref) -> Optional { return sref->stmt != nullptr ? GetRef(sref->stmt) : Optional(NullOpt); diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 92ecbc20df..93b6c5b536 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -89,19 +89,19 @@ struct SRefTranslator { return result; } - /*! \brief Translate Array */ - Array Trans(const Array& list) { - Array result; + /*! \brief Translate Array */ + Array Trans(const Array& list) { + Array result; result.reserve(list.size()); - for (const DepEdge& elem : list) { - result.push_back(DepEdge(Trans(elem->dst), elem->type)); + for (const Dependency& elem : list) { + result.push_back(Dependency(Trans(elem->src), Trans(elem->dst), elem->kind)); } return result; } - /*! \brief Translate SMap> */ - SMap> Trans(const SMap>& map) { - SMap> result; + /*! \brief Translate SMap> */ + SMap> Trans(const SMap>& map) { + SMap> result; result.reserve(map.size()); for (const auto& kv : map) { result[Trans(kv.first)] = Trans(kv.second); @@ -126,8 +126,8 @@ struct SRefTranslator { const StmtSRef& old_sref = kv.first; const BlockScope& old_scope = kv.second; ObjectPtr scope = make_object(); - scope->forward_edges = Trans(old_scope->forward_edges); - scope->backward_edges = Trans(old_scope->backward_edges); + scope->src2deps = Trans(old_scope->src2deps); + scope->dst2deps = Trans(old_scope->dst2deps); scope->buffer_writers = Trans(old_scope->buffer_writers); result.Set(Trans(old_sref), BlockScope(std::move(scope))); } diff --git a/src/tir/schedule/primitives/bind_annotate.cc b/src/tir/schedule/primitives/bind_annotate.cc index ab990e5076..1218f6c5ff 100644 --- a/src/tir/schedule/primitives/bind_annotate.cc +++ b/src/tir/schedule/primitives/bind_annotate.cc @@ -217,7 +217,7 @@ void DoubleBuffer(ScheduleState self, const StmtSRef& block_sref) { << "ValueError: 'double_buffer' expects 'block' with only one write buffer"; Block new_block = WithAnnotation(block_ptr, tir::attr::double_buffer_scope, IntImm(DataType::Int(32), 1)); - self->Replace(block_sref, new_block, {{new_block, GetRef(block_ptr)}}); + self->Replace(block_sref, new_block, {{GetRef(block_ptr), new_block}}); } } // namespace schedule diff --git a/src/tir/schedule/primitives/blockize_tensorize.cc b/src/tir/schedule/primitives/blockize_tensorize.cc index c70b48d0f1..41c48f7b4c 100644 --- a/src/tir/schedule/primitives/blockize_tensorize.cc +++ b/src/tir/schedule/primitives/blockize_tensorize.cc @@ -557,7 +557,7 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, const String& e "blockized_" + block->name_hint, body, new_init); auto outer_realize = BlockRealize(outer_bindings, division.back()->outer_extent, outer_block); - self->Replace(loop_sref, outer_realize, {{inner_block, block}}); + self->Replace(loop_sref, outer_realize, {{block, inner_block}}); UpdateScope(self, GetScopeRoot(self->stmt2ref.at(outer_block.get()))); UpdateScope(self, self->stmt2ref.at(outer_block.get())); // Check loop binding @@ -823,7 +823,7 @@ void Tensorize(ScheduleState self, const StmtSRef& loop_sref, const TensorIntrin new_block_ptr->body = new_body; Block new_block(new_block_ptr); self->Replace(self->stmt2ref.at(block_realize->block.get()), new_block, - {{new_block, block_realize->block}}); + {{block_realize->block, new_block}}); } } // namespace schedule diff --git a/src/tir/schedule/primitives/cache_read_write.cc b/src/tir/schedule/primitives/cache_read_write.cc index 069ab5ebb2..384c4d972c 100644 --- a/src/tir/schedule/primitives/cache_read_write.cc +++ b/src/tir/schedule/primitives/cache_read_write.cc @@ -299,8 +299,8 @@ class CacheLocDetector : public StmtVisitor { static void Detect(const ScheduleState self, const StmtSRef& block_sref, const StmtSRef& scope_sref, CacheStageInfo* info) { std::vector related_blocks; - for (const DepEdge& x : self->scopes.at(scope_sref)->GetSuccessors(block_sref)) { - if (x->type == DepType::kRAW) { + for (const Dependency& x : self->scopes.at(scope_sref)->GetDepsBySrc(block_sref)) { + if (x->kind == DepKind::kRAW) { related_blocks.push_back(x->dst); } } @@ -395,7 +395,7 @@ class CacheReadRewriter : public StmtExprMutator { stmt = Block(n); } } - info_->block_map[stmt] = old_stmt; + info_->block_map[old_stmt] = stmt; return std::move(stmt); } @@ -479,7 +479,7 @@ class CacheWriteRewriter : public StmtExprMutator { stmt = Block(n); } } - info_->block_map[stmt] = old_stmt; + info_->block_map[old_stmt] = stmt; return std::move(stmt); } @@ -608,16 +608,10 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int i, Stmt new_scope = CacheWriteRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info); // Handling block remapping std::unordered_map& block_map = info.block_map; - for (const auto& mapping : block_map) { - const Block& new_block = mapping.first; - const Block& old_block = mapping.second; - if (old_block.get() == block_sref->stmt) { - // It is okay to mutate inside iteration, because it is going to break anyways - block_map[cache_write_stage] = old_block; - cache_write_stage = new_block; - block_map.erase(new_block); - break; - } + { + auto it = block_map.find(GetRef(block)); + ICHECK(it != block_map.end()); + std::swap(it->second, cache_write_stage); } self->Replace(scope_sref, new_scope, block_map); return self->stmt2ref.at(cache_write_stage.get()); diff --git a/src/tir/schedule/primitives/compute_location.cc b/src/tir/schedule/primitives/compute_location.cc index eaf6128c84..32d46184ef 100644 --- a/src/tir/schedule/primitives/compute_location.cc +++ b/src/tir/schedule/primitives/compute_location.cc @@ -45,10 +45,10 @@ Array GatherVars(const ObjectRef& stmt_or_expr) { * \param blocks A list of candidate blocks * \return True if there is at least one edge that points to a block in the list */ -bool AnyEdgePointsToABlock(const Array& edges, const Array& blocks) { - for (const DepEdge& edge : edges) { +bool AnyEdgePointsToABlock(const Array& edges, const Array& blocks) { + for (const StmtSRef& edge : edges) { for (const StmtSRef& block : blocks) { - if (edge->dst.same_as(block)) { + if (edge.same_as(block)) { return true; } } @@ -56,22 +56,40 @@ bool AnyEdgePointsToABlock(const Array& edges, const Array& b return false; } +Array GetProducersFromDependency(const Array& deps) { + Array result; + result.reserve(deps.size()); + for (const Dependency& dep : deps) { + result.push_back(dep->src); + } + return result; +} + +Array GetConsumersFromDependency(const Array& deps) { + Array result; + result.reserve(deps.size()); + for (const Dependency& dep : deps) { + result.push_back(dep->dst); + } + return result; +} + /*! * \brief Helper function to check if every edge points to a block in the given set of blocks * \param edges A list of edges to be check * \param blocks A list of candidate blocks - * \param raw_edge_only Only consider RAW-dependency edges * \return True if all edges that have a corresponding block */ -bool EachEdgePointsToABlock(const Array& edges, const Array& blocks, - bool raw_edge_only) { - for (const DepEdge& edge : edges) { - if (raw_edge_only && edge->type != DepType::kRAW) { +bool EachEdgePointsToABlock(const Array& edges, const Array& blocks, + bool use_dst) { + for (const Dependency& edge : edges) { + if (edge->kind != DepKind::kRAW) { continue; } bool found = false; + const StmtSRef& sref = use_dst ? edge->dst : edge->src; for (const StmtSRef& block : blocks) { - if (edge->dst.same_as(block)) { + if (sref.same_as(block)) { found = true; break; } @@ -84,14 +102,14 @@ bool EachEdgePointsToABlock(const Array& edges, const Array& } /*! - * \brief Extract StmtSRef from DepEdgeNode::dst + * \brief Extract StmtSRef from DependencyNode::dst * \param edges List of edges to be extracted * \return A list of StmtSRef as the result */ -std::vector EdgesToSRefs(const Array& edges) { +std::vector EdgesToSRefs(const Array& edges) { std::vector result; result.reserve(edges.size()); - for (const DepEdge& edge : edges) { + for (const Dependency& edge : edges) { result.push_back(edge->dst); } return result; @@ -404,7 +422,7 @@ class StatementInliner : public StmtExprMutator { block_node->alloc_buffers = alloc_buffers; Block block(block_node); - block_sref_map_->Set(block, origin_block); + block_sref_map_->Set(origin_block, block); return std::move(block); } @@ -483,7 +501,7 @@ class ReverseStatementInliner : public StmtExprMutator { block_node->alloc_buffers = alloc_buffers; Block block(block_node); - if (is_producer) block_sref_map_->Set(block, origin_producer); + if (is_producer) block_sref_map_->Set(origin_producer, block); return std::move(Block(block)); } @@ -570,8 +588,10 @@ void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& l const StmtSRef& parent_block_sref = GetScopeRoot(block_sref); const auto* parent_block = parent_block_sref->GetStmt(); const BlockScope& scope = self->scopes.at(parent_block_sref); - Array edges_to_pred = scope->GetPredecessors(block_sref); - Array edges_to_succ = scope->GetSuccessors(block_sref); + Array edges_to_pred = scope->GetDepsByDst(block_sref); + Array edges_to_succ = scope->GetDepsBySrc(block_sref); + Array producers = GetProducersFromDependency(edges_to_pred); + Array consumers = GetConsumersFromDependency(edges_to_succ); // Cond 0. `block` and `loop` are in the same scope CHECK_EQ(parent_block_sref.get(), GetScopeRoot(loop_sref).get()) << "ValueError: 'compute_at' expects 'block' and 'loop' be in the same block"; @@ -580,7 +600,7 @@ void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& l << "ValueError: 'compute_at' expects 'block' to be a complete or reduction block"; // Cond 2. Check all RAW successors are in the subtree rooted by loop_sref CHECK(EachEdgePointsToABlock(edges_to_succ, GetChildBlocks(self, loop_sref, true), - /*raw_edge_only=*/true)) + /*use_dst=*/true)) << "ValueError: 'compute_at' does not apply to a block that some other " << "blocks outside the scope depends on"; // Cond 3. The subtree has compact data flow @@ -602,8 +622,7 @@ void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& l int n_stmts = loop_body.size(); for (insert_pos = n_stmts; insert_pos > 0; --insert_pos) { const StmtNode* stmt = loop_body[insert_pos - 1].get(); - if (AnyEdgePointsToABlock(edges_to_pred, - GetChildBlocks(self, self->stmt2ref.at(stmt), true))) { + if (AnyEdgePointsToABlock(producers, GetChildBlocks(self, self->stmt2ref.at(stmt), true))) { break; } } @@ -611,8 +630,7 @@ void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& l int before_pos; for (before_pos = 0; before_pos < n_stmts; before_pos++) { const StmtNode* stmt = loop_body[before_pos].get(); - if (AnyEdgePointsToABlock(edges_to_succ, - GetChildBlocks(self, self->stmt2ref.at(stmt), true))) { + if (AnyEdgePointsToABlock(consumers, GetChildBlocks(self, self->stmt2ref.at(stmt), true))) { break; } } @@ -625,7 +643,7 @@ void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& l SolveCover(block, GatherRequirements(/*produced_regions=*/block->writes, /*lca_loop_sref=*/loop_sref, - /*consumer_blocks=*/EdgesToSRefs(edges_to_succ), + /*consumer_blocks=*/{consumers.begin(), consumers.end()}, /*relax_vars=*/RelaxForExecScope(loop_sref, block_sref), /*gather_read=*/true), true), @@ -641,7 +659,7 @@ void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& l StmtSRef lca = LowestCommonAncestor({block_sref, loop_sref}, root); Stmt replaced = StmtReplacer(replace_map)(GetRef(lca->stmt)); if (const auto* replaced_block = replaced.as()) { - self->Replace(lca, replaced, {{GetRef(replaced_block), GetRef(parent_block)}}); + self->Replace(lca, replaced, {{GetRef(parent_block), GetRef(replaced_block)}}); } else { self->Replace(lca, replaced, {}); } @@ -671,8 +689,10 @@ void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const Stmt const StmtSRef& parent_block_sref = GetScopeRoot(block_sref); const auto* parent_block = parent_block_sref->GetStmt(); const BlockScope& scope = self->scopes.at(parent_block_sref); - Array edges_to_pred = scope->GetPredecessors(block_sref); - Array edges_to_succ = scope->GetSuccessors(block_sref); + Array edges_to_pred = scope->GetDepsByDst(block_sref); + Array edges_to_succ = scope->GetDepsBySrc(block_sref); + Array producers = GetProducersFromDependency(edges_to_pred); + Array consumers = GetConsumersFromDependency(edges_to_succ); // Cond 0. `block` and `loop` are in the same scope CHECK_EQ(parent_block_sref.get(), GetScopeRoot(loop_sref).get()) << "ValueError: 'reverse_compute_at' expects 'block' and 'loop' be in the same block"; @@ -681,7 +701,7 @@ void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const Stmt << "ValueError: 'reverse_compute_at' expects 'block' to be a complete or reduction block"; // Cond 2. Check all RAW predecessors are in the subtree rooted by loop_sref CHECK(EachEdgePointsToABlock(edges_to_pred, GetChildBlocks(self, loop_sref, true), - /*raw_edge_only=*/true)) + /*use_dst=*/false)) << "ValueError: 'reverse_compute_at' does not apply to a block that some other " << "blocks outside the scope depends on"; // Cond 3. The subtree has compact data flow @@ -703,8 +723,7 @@ void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const Stmt int n_stmts = loop_body.size(); for (insert_pos = n_stmts; insert_pos > 0; --insert_pos) { const StmtNode* stmt = loop_body[insert_pos - 1].get(); - if (AnyEdgePointsToABlock(edges_to_pred, - GetChildBlocks(self, self->stmt2ref.at(stmt), true))) { + if (AnyEdgePointsToABlock(producers, GetChildBlocks(self, self->stmt2ref.at(stmt), true))) { break; } } @@ -712,8 +731,7 @@ void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const Stmt int before_pos; for (before_pos = 0; before_pos < n_stmts; before_pos++) { const StmtNode* stmt = loop_body[before_pos].get(); - if (AnyEdgePointsToABlock(edges_to_succ, - GetChildBlocks(self, self->stmt2ref.at(stmt), true))) { + if (AnyEdgePointsToABlock(consumers, GetChildBlocks(self, self->stmt2ref.at(stmt), true))) { break; } } @@ -721,16 +739,16 @@ void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const Stmt "point that satisfies dependency"; } // Generate new LoopNode to substitute loop_sref->stmt - For new_loop = - RegenerateLoops(block_sref, loop_sref, insert_pos, - SolveCover(block, - GatherRequirements(/*produced_regions=*/block->reads, - /*lca_loop_sref=*/loop_sref, - /*consumer_blocks=*/EdgesToSRefs(edges_to_pred), - /*relax_vars=*/{}, - /*gather_read=*/false), - false), - preserve_trivial_loop); + For new_loop = RegenerateLoops( + block_sref, loop_sref, insert_pos, + SolveCover(block, + GatherRequirements(/*produced_regions=*/block->reads, + /*lca_loop_sref=*/loop_sref, + /*consumer_blocks=*/{producers.begin(), producers.end()}, + /*relax_vars=*/{}, + /*gather_read=*/false), + false), + preserve_trivial_loop); // Remove leaf StmtSRef root = GetSRefTreeRoot(block_sref); std::pair removed = RemoveLeaf(block_sref, root); @@ -742,7 +760,7 @@ void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const Stmt StmtSRef lca = LowestCommonAncestor({block_sref, loop_sref}, root); Stmt replaced = StmtReplacer(replace_map)(GetRef(lca->stmt)); if (const auto* replaced_block = replaced.as()) { - self->Replace(lca, replaced, {{GetRef(replaced_block), GetRef(parent_block)}}); + self->Replace(lca, replaced, {{GetRef(parent_block), GetRef(replaced_block)}}); } else { self->Replace(lca, replaced, {}); } @@ -800,12 +818,12 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref) { CHECK(block->body.as()) << "ValueError: 'reverse_compute_inline' expects the 'block' contains a single BufferStore"; // Cond 3. block_sref has only one RAW producer - const auto& producers = scope->GetPredecessors(block_sref); + Array producers = scope->GetDepsByDst(block_sref); CHECK_EQ(producers.size(), 1) << "ValueError: 'reverse_compute_inline' expects the 'block' has only one producer"; - CHECK(producers[0]->type == DepType::kRAW) + CHECK(producers[0]->kind == DepKind::kRAW) << "ValueError: 'reverse_compute_inline' expects the 'block' has only one producer"; - const StmtSRef& producer_sref = producers[0]->dst; + const StmtSRef& producer_sref = producers[0]->src; // Cond 4. The producer is complete CHECK(scope->IsComplete(producer_sref)) << "ValueError: 'reverse_compute_inline' expects the producer of 'block' to be complete"; @@ -815,7 +833,7 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref) { << "ValueError: 'reverse_compute_inline' expects the producer of 'block' to contain a single " "BufferStore"; // Cond 6. The producer has only one consumer(which is block_sref) - const auto& consumers = scope->GetSuccessors(producer_sref); + Array consumers = scope->GetDepsBySrc(producer_sref); CHECK_EQ(consumers.size(), 1) << "ValueError: 'reverse_compute_inline' expects 'block' is the " "only consumer of its producer"; CHECK_EQ(consumers[0]->dst, block_sref) << "ValueError: 'reverse_compute_inline' expects 'block' " diff --git a/src/tir/schedule/primitives/reduction.cc b/src/tir/schedule/primitives/reduction.cc index 32e0dc140d..e9a2b115a2 100644 --- a/src/tir/schedule/primitives/reduction.cc +++ b/src/tir/schedule/primitives/reduction.cc @@ -165,7 +165,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, block_node->init = NullOpt; Block new_block = Block(block_node); self->Replace(GetRef(loop_sref->parent), new_block, - {{new_block, GetRef(parent)}}); + {{GetRef(parent), new_block}}); } else { LOG(FATAL) << "TypeError: 'decompose_reduction' is applied to loop whose parent's type is not " @@ -177,7 +177,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, update_block_node->name_hint = block->name_hint + "_update"; update_block_node->init = NullOpt; Block update_block(update_block_node); - self->Replace(block_sref, update_block, {{update_block, GetRef(block)}}); + self->Replace(block_sref, update_block, {{GetRef(block), update_block}}); // Update scope information UpdateScope(self, block_sref); return self->stmt2ref.at(init_block.get()); @@ -204,7 +204,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, block_node->body = body; Block new_block(block_node); - self->Replace(block_sref, new_block, {{new_block, GetRef(block)}}); + self->Replace(block_sref, new_block, {{GetRef(block), new_block}}); // Update scope information UpdateScope(self, block_sref); return self->stmt2ref.at(new_block.get()); @@ -301,7 +301,7 @@ void MergeReduction(ScheduleState self, const StmtSRef& init_sref, const StmtSRe auto merged_node = make_object(*update); merged_node->init = new_init; Block merged(merged_node); - self->Replace(update_sref, merged, {{merged, GetRef(update)}}); + self->Replace(update_sref, merged, {{GetRef(update), merged}}); // Update scope information UpdateScope(self, update_sref); } @@ -611,7 +611,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis) SeqStmt parent_body = insert(parent->body, top.value()->seq_index, {rf_body, wb_body}); self->Replace(GetRef(top.value()->parent), For(parent->loop_var, parent->min, parent->extent, ForKind::kSerial, parent_body), - {{wb_block, block}}); + {{block, wb_block}}); } else if (const auto* parent = top.value()->parent->GetStmt()) { SeqStmt parent_body = insert(parent->body, top.value()->seq_index, {rf_body, wb_body}); auto block_node = make_object(*parent); @@ -619,7 +619,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis) block_node->init = NullOpt; Block new_block = Block(block_node); self->Replace(GetRef(top.value()->parent), new_block, - {{new_block, GetRef(parent)}, {wb_block, block}}); + {{GetRef(parent), new_block}, {block, wb_block}}); } // Insert the rfactor buffer into the scope block's allocation. @@ -627,7 +627,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis) Block scope_block = GetRef(scope_sref->GetStmt()), new_scope_block = scope_block; new_scope_block.CopyOnWrite()->alloc_buffers.push_back(rf_buf); - self->Replace(scope_sref, new_scope_block, {{new_scope_block, scope_block}}); + self->Replace(scope_sref, new_scope_block, {{scope_block, new_scope_block}}); // Update scope information. UpdateScope(self, scope_sref); return self->stmt2ref.at(rf_block.get()); diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 82c4c88f31..d21e87197b 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -272,15 +272,15 @@ struct ReuseInfo { class ReuseCollector : public StmtVisitor { public: static ReuseInfo Collect(ScheduleStateNode* self, const Stmt& tgt_stmt, - const Map& block_sref_reverse_reuse) { + const Map& block_sref_reuse) { ReuseCollector collector(self); collector.VisitStmt(tgt_stmt); ReuseInfo result; result.intact = {collector.intact_.begin(), collector.intact_.end()}; result.loop_sref_reuse = {collector.loop_vars_.begin(), collector.loop_vars_.end()}; - for (const auto& kv : block_sref_reverse_reuse) { - const Block& new_block = kv.first; - const Block& old_block = kv.second; + for (const auto& kv : block_sref_reuse) { + const Block& old_block = kv.first; + const Block& new_block = kv.second; result.block_sref_reuse.emplace(old_block.get(), new_block.get()); } return result; diff --git a/tests/python/tir/test_block_dependency.py b/tests/python/tir/test_block_dependency.py index 77d3d75da0..d659f9bfc9 100644 --- a/tests/python/tir/test_block_dependency.py +++ b/tests/python/tir/test_block_dependency.py @@ -49,14 +49,14 @@ def test_element_wise_dependency(): root = s.get_sref(s.get_block("root")) block_b = s.get_sref(s.get_block("B")) block_c = s.get_sref(s.get_block("C")) - # Check get_predecessors - (predecessor_c,) = s.state.scopes[root].get_predecessors(block_c) - assert predecessor_c.dst.same_as(block_b) - assert predecessor_c.type == tir.schedule.DepEdge.kRAW - # Check get_successor - (successor_b,) = s.state.scopes[root].get_successor(block_b) + # Check get_deps_by_dst + (predecessor_c,) = s.state.scopes[root].get_deps_by_dst(block_c) + assert predecessor_c.src.same_as(block_b) + assert predecessor_c.kind == tir.schedule.Dependency.kRAW + # Check get_deps_by_src + (successor_b,) = s.state.scopes[root].get_deps_by_src(block_b) assert successor_b.dst.same_as(block_c) - assert predecessor_c.type == tir.schedule.DepEdge.kRAW + assert predecessor_c.kind == tir.schedule.Dependency.kRAW def test_matmul_dependency(): @@ -66,21 +66,21 @@ def test_matmul_dependency(): init = s.get_sref(s.get_block("init")) update = s.get_sref(s.get_block("update")) # Check predecessors - p0, p1 = s.state.scopes[root].get_predecessors(update) - assert p0.dst.same_as(init) - assert p1.dst.same_as(init) + p0, p1 = s.state.scopes[root].get_deps_by_dst(update) + assert p0.src.same_as(init) + assert p1.src.same_as(init) # WAW and RAW - assert (p0.type == tir.schedule.DepEdge.kRAW and p1.type == tir.schedule.DepEdge.kWAW) or ( - p0.type == tir.schedule.DepEdge.kWAW and p1.type == tir.schedule.DepEdge.kRAW - ) + assert ( + p0.kind == tir.schedule.Dependency.kRAW and p1.kind == tir.schedule.Dependency.kWAW + ) or (p0.kind == tir.schedule.Dependency.kWAW and p1.kind == tir.schedule.Dependency.kRAW) # Check successors - p0, p1 = s.state.scopes[root].get_successor(init) + p0, p1 = s.state.scopes[root].get_deps_by_src(init) assert p0.dst == update assert p1.dst == update # WAW and RAW - assert (p0.type == tir.schedule.DepEdge.kRAW and p1.type == tir.schedule.DepEdge.kWAW) or ( - p0.type == tir.schedule.DepEdge.kWAW and p1.type == tir.schedule.DepEdge.kRAW - ) + assert ( + p0.kind == tir.schedule.Dependency.kRAW and p1.kind == tir.schedule.Dependency.kWAW + ) or (p0.kind == tir.schedule.Dependency.kWAW and p1.kind == tir.schedule.Dependency.kRAW) def test_WAR_dependency(): diff --git a/tests/python/tir/test_schedule_replace.py b/tests/python/tir/test_schedule_replace.py index 1e143ebfbc..e60b20c18b 100644 --- a/tests/python/tir/test_schedule_replace.py +++ b/tests/python/tir/test_schedule_replace.py @@ -242,7 +242,13 @@ def test_replace_block_remap(): # The target stmt target = util.matmul_stmt_original().body.block.body.body.body[0].block sref = s.get_sref(s.mod["main"].body.block.body[0].body.body.block) - s.state.replace(sref, target, {target: s.mod["main"].body.block.body[0].body.body.block}) + s.state.replace( + sref, + target, + { + s.mod["main"].body.block.body[0].body.body.block: target, + }, + ) sref_new = s.get_sref(s.get_block("init")) # Check the original sref has been remapped assert sref.__hash__() == sref_new.__hash__()