Skip to content

Commit

Permalink
[Refactor] Change semantics of Replace; Update DepEdge => Dependency (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored and Hzfengsy committed Mar 26, 2021
1 parent d68997f commit e46b917
Show file tree
Hide file tree
Showing 15 changed files with 193 additions and 171 deletions.
39 changes: 21 additions & 18 deletions include/tvm/tir/schedule/block_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,48 +113,51 @@ class StmtSRef : public runtime::ObjectRef {
};

/*! \brief Type of dependency */
enum class DepType : int32_t {
enum class DepKind : int32_t {
kRAW = 0,
kWAW = 1,
kWAR = 2,
kOpaque = 3,
};

/*! \brief An edge representing certain types of dependency, e.g. read-after-write */
class DepEdgeNode : public runtime::Object {
class DependencyNode : public runtime::Object {
public:
/*! \brief The destination block */
/*! \brief The source of the dependency relation */
StmtSRef src;
/*! \brief The destination of the dependency relation */
StmtSRef dst;
/*! \brief The dependency type */
DepType type;
/*! \brief The dependency kind */
DepKind kind;

void VisitAttrs(AttrVisitor* v) {
v->Visit("src", &src);
v->Visit("dst", &dst);
v->Visit("type", &type);
v->Visit("kind", &kind);
}

static constexpr const char* _type_key = "tir.DepEdge";
TVM_DECLARE_FINAL_OBJECT_INFO(DepEdgeNode, Object);
static constexpr const char* _type_key = "tir.Dependency";
TVM_DECLARE_FINAL_OBJECT_INFO(DependencyNode, Object);
};

/*!
* \brief Managed reference to DepEdgeNode
* \sa DepEdgeNode
* \brief Managed reference to DependencyNode
* \sa DependencyNode
*/
class DepEdge : public runtime::ObjectRef {
class Dependency : public runtime::ObjectRef {
public:
/*! \brief Constructor */
explicit DepEdge(StmtSRef dst, DepType type);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DepEdge, ObjectRef, DepEdgeNode);
explicit Dependency(StmtSRef src, StmtSRef dst, DepKind kind);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Dependency, ObjectRef, DependencyNode);
};

/*! \brief An object recording the producer-consumer dependency between child blocks of a scope */
class BlockScopeNode : public runtime::Object {
public:
/*! \brief The forward dependency edges of the block */
std::unordered_map<StmtSRef, Array<DepEdge>, ObjectPtrHash, ObjectPtrEqual> forward_edges;
std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> src2deps;
/*! \brief The backward dependency edges of the block */
std::unordered_map<StmtSRef, Array<DepEdge>, ObjectPtrHash, ObjectPtrEqual> backward_edges;
std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps;
/*! \brief The mapping from the buffer to the blocks who write it */
std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;

Expand All @@ -164,19 +167,19 @@ class BlockScopeNode : public runtime::Object {
TVM_DECLARE_FINAL_OBJECT_INFO(BlockScopeNode, runtime::Object);

public:
/******** Dependency ********/
/******** DependencyNode ********/
/*!
* \brief Get all blocks the block depends on
* \param block_sref The queried block
* \return The predecessors edges
*/
TVM_DLL Array<DepEdge> GetPredecessors(const StmtSRef& block_sref) const;
TVM_DLL Array<Dependency> GetDepsBySrc(const StmtSRef& block_sref) const;
/*!
* \brief Get all blocks that depends on the block
* \param block_sref The queried block
* \return The successor edges
*/
TVM_DLL Array<DepEdge> GetSuccessors(const StmtSRef& block_sref) const;
TVM_DLL Array<Dependency> GetDepsByDst(const StmtSRef& block_sref) const;

/******** Property of a block ********/
/*!
Expand Down
4 changes: 1 addition & 3 deletions include/tvm/tir/schedule/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,10 @@ class ScheduleStateNode : public runtime::Object {
*
* \param src_sref The sref to the statement to be replaced
* \param tgt_stmt The statement to be replaced to
* \param block_sref_reuse Maps an new block (replaced to) back to an old block (to be replaced),
* \param block_sref_reuse Maps an old block (to be replaced) to a new block (replaced to),
* and enforces reuse of srefs between them (rather than create new srefs)
* i.e. after being replaced, the sref that points to the old block will point to the new one
* \note `loop_sref_reuse` will be automatically detected via loop vars
*
* TODO(@junrushao1994): change `block_sref_reuse` from "new -> old" to "old -> new"
*/
TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt,
const Map<Block, Block>& block_sref_reuse);
Expand Down
19 changes: 10 additions & 9 deletions python/tvm/tir/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,24 +47,25 @@ def inline_mark() -> StmtSRef:
return _ffi_api_schedule.StmtSRefInlineMark() # pylint: disable=no-member


@_register_object("tir.DepEdge")
class DepEdge(Object):
@_register_object("tir.Dependency")
class Dependency(Object):
"""An edge in the dependency graph"""

kRAW = 0
kWAW = 1
kWAR = 2
kOpaque = 3

src: StmtSRef
dst: StmtSRef
type: int
kind: int


@_register_object("tir.BlockScope")
class BlockScope(Object):
"""Dependency Graph that stores read/write dependency between Blocks"""

def get_predecessors(self, block: StmtSRef) -> List[DepEdge]:
def get_deps_by_src(self, block: StmtSRef) -> List[Dependency]:
"""Get the dependency predecessors of the block
Parameters
Expand All @@ -74,12 +75,12 @@ def get_predecessors(self, block: StmtSRef) -> List[DepEdge]:
Returns
-------
blocks: List of DepEdge
blocks: List of Dependency
The predecessors of the block
"""
return _ffi_api_schedule.BlockScopeGetPredecessors(self, block) # pylint: disable=no-member
return _ffi_api_schedule.BlockScopeGetDepsBySrc(self, block) # pylint: disable=no-member

def get_successor(self, block: StmtSRef) -> List[DepEdge]:
def get_deps_by_dst(self, block: StmtSRef) -> List[Dependency]:
"""Get the dependency successor of the block
Parameters
Expand All @@ -89,10 +90,10 @@ def get_successor(self, block: StmtSRef) -> List[DepEdge]:
Returns
-------
blocks: List of DepEdge
blocks: List of Dependency
The predecessors of the block
"""
return _ffi_api_schedule.BlockScopeGetSuccessors(self, block) # pylint: disable=no-member
return _ffi_api_schedule.BlockScopeGetDepsByDst(self, block) # pylint: disable=no-member


@_register_object("tir.ScheduleState")
Expand Down
4 changes: 2 additions & 2 deletions src/meta_schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ inline void DelAnn(const tir::ScheduleState& sch, const tir::StmtSRef& sref,
ObjectPtr<tir::BlockNode> n = make_object<tir::BlockNode>(*block);
n->annotations = std::move(new_ann);
tir::Block p(n);
sch->Replace(sref, p, {{p, GetRef<tir::Block>(block)}});
sch->Replace(sref, p, {{GetRef<tir::Block>(block), p}});
} else {
LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey();
throw;
Expand Down Expand Up @@ -303,7 +303,7 @@ inline void AddAnn(const tir::ScheduleState& sch, const tir::StmtSRef& sref, con
ObjectPtr<tir::BlockNode> n = make_object<tir::BlockNode>(*block);
n->annotations = std::move(new_ann);
tir::Block p(n);
sch->Replace(sref, p, {{p, GetRef<tir::Block>(block)}});
sch->Replace(sref, p, {{GetRef<tir::Block>(block), p}});
} else {
LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey();
throw;
Expand Down
30 changes: 15 additions & 15 deletions src/tir/schedule/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,13 @@ void VerifyRegionCover(const ScheduleState& self, const StmtSRef& consumer_block
// Maps a buffer var to its producers
std::unordered_map<const VarNode*, std::vector<Producer>> buffer_producers;
// Collect all producers to a buffer by enumerating all RAW predecessors of the consumer
for (const DepEdge& edge :
self->scopes.at(parent_block_sref)->GetPredecessors(consumer_block_sref)) {
if (edge->type != DepType::kRAW) {
for (const Dependency& edge :
self->scopes.at(parent_block_sref)->GetDepsByDst(consumer_block_sref)) {
if (edge->kind != DepKind::kRAW) {
continue;
}
// i.e. the RAW predecessor is producer
const StmtSRef& producer_block_sref = edge->dst;
const StmtSRef& producer_block_sref = edge->src;
for (const BufferRegion& output_region : producer_block_sref->GetStmt<BlockNode>()->writes) {
const VarNode* buffer_var = output_region->buffer->data.get();
buffer_producers[buffer_var].emplace_back(producer_block_sref, output_region);
Expand Down Expand Up @@ -340,27 +340,27 @@ Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent
}

Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sref) {
Array<DepEdge> pred_edges = self->scopes
.at(GetScopeRoot(block_sref)) //
->GetPredecessors(block_sref);
Array<Dependency> pred_edges = self->scopes
.at(GetScopeRoot(block_sref)) //
->GetDepsByDst(block_sref);
Array<StmtSRef> results;
results.reserve(pred_edges.size());
for (const DepEdge edge : pred_edges) {
if (edge->type == DepType::kRAW || edge->type == DepType::kWAW) {
results.push_back(edge->dst);
for (const Dependency& edge : pred_edges) {
if (edge->kind == DepKind::kRAW || edge->kind == DepKind::kWAW) {
results.push_back(edge->src);
}
}
return results;
}

Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref) {
Array<DepEdge> succ_edges = self->scopes
.at(GetScopeRoot(block_sref)) //
->GetSuccessors(block_sref);
Array<Dependency> succ_edges = self->scopes
.at(GetScopeRoot(block_sref)) //
->GetDepsBySrc(block_sref);
Array<StmtSRef> results;
results.reserve(succ_edges.size());
for (const DepEdge edge : succ_edges) {
if (edge->type == DepType::kRAW || edge->type == DepType::kWAW) {
for (const Dependency& edge : succ_edges) {
if (edge->kind == DepKind::kRAW || edge->kind == DepKind::kWAW) {
results.push_back(edge->dst);
}
}
Expand Down
48 changes: 25 additions & 23 deletions src/tir/schedule/block_scope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,15 @@ using TBufferReaderWriter =

/*!
* \brief Add a dependency edge.
* \param from The source of the dependency
* \param to The destination of the dependecy
* \param type Type of the dependency
* \param src The source of the dependency
* \param dst The destination of the dependecy
* \param kind Type of the dependency
*/
void AddEdge(BlockScopeNode* self, const StmtSRef& from, const StmtSRef& to, DepType type) {
if (!from.same_as(to)) {
self->forward_edges[from].push_back(DepEdge(to, type));
self->backward_edges[to].push_back(DepEdge(from, type));
void AddEdge(BlockScopeNode* self, const StmtSRef& src, const StmtSRef& dst, DepKind kind) {
if (!src.same_as(dst)) {
Dependency dep(src, dst, kind);
self->src2deps[src].push_back(dep);
self->dst2deps[dst].push_back(dep);
}
}

Expand Down Expand Up @@ -67,15 +68,15 @@ void AddChildBlock(BlockScopeNode* self, const StmtSRef& child_sref,
for (const BufferRegion& region : block->reads) {
if (buffer_writers.count(region->buffer)) {
for (const StmtSRef& from : buffer_writers[region->buffer]) {
AddEdge(self, from, child_sref, DepType::kRAW);
AddEdge(self, from, child_sref, DepKind::kRAW);
}
}
}
// Step 3. Update WAW dependency
for (const BufferRegion& region : block->writes) {
if (buffer_writers.count(region->buffer)) {
for (const StmtSRef& from : buffer_writers[region->buffer]) {
AddEdge(self, from, child_sref, DepType::kWAW);
AddEdge(self, from, child_sref, DepKind::kWAW);
}
}
}
Expand Down Expand Up @@ -140,10 +141,11 @@ StmtSRef StmtSRef::RootMark() {
return result;
}

DepEdge::DepEdge(StmtSRef dst, DepType type) {
ObjectPtr<DepEdgeNode> node = make_object<DepEdgeNode>();
Dependency::Dependency(StmtSRef src, StmtSRef dst, DepKind kind) {
ObjectPtr<DependencyNode> node = make_object<DependencyNode>();
node->src = std::move(src);
node->dst = std::move(dst);
node->type = type;
node->kind = kind;
data_ = std::move(node);
}

Expand All @@ -160,9 +162,9 @@ BlockScope::BlockScope(const Array<StmtSRef>& leaf_block_srefs) {

/******** Dependency ********/

Array<DepEdge> BlockScopeNode::GetPredecessors(const StmtSRef& block_sref) const {
const std::unordered_map<StmtSRef, Array<DepEdge>, ObjectPtrHash, ObjectPtrEqual>& edges =
this->backward_edges;
Array<Dependency> BlockScopeNode::GetDepsBySrc(const StmtSRef& block_sref) const {
const std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual>& edges =
this->src2deps;
auto iter = edges.find(block_sref);
if (iter != edges.end()) {
return iter->second;
Expand All @@ -171,9 +173,9 @@ Array<DepEdge> BlockScopeNode::GetPredecessors(const StmtSRef& block_sref) const
}
}

Array<DepEdge> BlockScopeNode::GetSuccessors(const StmtSRef& block_sref) const {
const std::unordered_map<StmtSRef, Array<DepEdge>, ObjectPtrHash, ObjectPtrEqual>& edges =
this->forward_edges;
Array<Dependency> BlockScopeNode::GetDepsByDst(const StmtSRef& block_sref) const {
const std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual>& edges =
this->dst2deps;
auto iter = edges.find(block_sref);
if (iter != edges.end()) {
return iter->second;
Expand Down Expand Up @@ -336,17 +338,17 @@ bool BlockScopeNode::CanMergeReduction(const StmtSRef& init_sref,
/******** FFI ********/

TVM_REGISTER_NODE_TYPE(StmtSRefNode);
TVM_REGISTER_NODE_TYPE(DepEdgeNode);
TVM_REGISTER_NODE_TYPE(DependencyNode);
TVM_REGISTER_NODE_TYPE(BlockScopeNode);

TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefRootMark") //
.set_body_typed(StmtSRef::RootMark);
TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefInlineMark") //
.set_body_typed(StmtSRef::InlineMark);
TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetPredecessors")
.set_body_method<BlockScope>(&BlockScopeNode::GetPredecessors);
TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetSuccessors")
.set_body_method<BlockScope>(&BlockScopeNode::GetSuccessors);
TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsBySrc")
.set_body_method<BlockScope>(&BlockScopeNode::GetDepsBySrc);
TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsByDst")
.set_body_method<BlockScope>(&BlockScopeNode::GetDepsByDst);
TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefStmt")
.set_body_typed([](StmtSRef sref) -> Optional<Stmt> {
return sref->stmt != nullptr ? GetRef<Stmt>(sref->stmt) : Optional<Stmt>(NullOpt);
Expand Down
20 changes: 10 additions & 10 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,19 +89,19 @@ struct SRefTranslator {
return result;
}

/*! \brief Translate Array<DepEdge> */
Array<DepEdge> Trans(const Array<DepEdge>& list) {
Array<DepEdge> result;
/*! \brief Translate Array<Dependency> */
Array<Dependency> Trans(const Array<Dependency>& list) {
Array<Dependency> result;
result.reserve(list.size());
for (const DepEdge& elem : list) {
result.push_back(DepEdge(Trans(elem->dst), elem->type));
for (const Dependency& elem : list) {
result.push_back(Dependency(Trans(elem->src), Trans(elem->dst), elem->kind));
}
return result;
}

/*! \brief Translate SMap<StmtSRef, Array<DepEdge>> */
SMap<StmtSRef, Array<DepEdge>> Trans(const SMap<StmtSRef, Array<DepEdge>>& map) {
SMap<StmtSRef, Array<DepEdge>> result;
/*! \brief Translate SMap<StmtSRef, Array<Dependency>> */
SMap<StmtSRef, Array<Dependency>> Trans(const SMap<StmtSRef, Array<Dependency>>& map) {
SMap<StmtSRef, Array<Dependency>> result;
result.reserve(map.size());
for (const auto& kv : map) {
result[Trans(kv.first)] = Trans(kv.second);
Expand All @@ -126,8 +126,8 @@ struct SRefTranslator {
const StmtSRef& old_sref = kv.first;
const BlockScope& old_scope = kv.second;
ObjectPtr<BlockScopeNode> scope = make_object<BlockScopeNode>();
scope->forward_edges = Trans(old_scope->forward_edges);
scope->backward_edges = Trans(old_scope->backward_edges);
scope->src2deps = Trans(old_scope->src2deps);
scope->dst2deps = Trans(old_scope->dst2deps);
scope->buffer_writers = Trans(old_scope->buffer_writers);
result.Set(Trans(old_sref), BlockScope(std::move(scope)));
}
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/primitives/bind_annotate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ void DoubleBuffer(ScheduleState self, const StmtSRef& block_sref) {
<< "ValueError: 'double_buffer' expects 'block' with only one write buffer";
Block new_block =
WithAnnotation(block_ptr, tir::attr::double_buffer_scope, IntImm(DataType::Int(32), 1));
self->Replace(block_sref, new_block, {{new_block, GetRef<Block>(block_ptr)}});
self->Replace(block_sref, new_block, {{GetRef<Block>(block_ptr), new_block}});
}

} // namespace schedule
Expand Down
Loading

0 comments on commit e46b917

Please sign in to comment.