Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame committed Jan 23, 2019
1 parent ec8327d commit 7226df5
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
current->extern_ref = true;
}
}

void AddNode(const tvm::Node* key) {
auto it = graph_.node_map.find(key);
CHECK(it != graph_.node_map.end())
Expand All @@ -173,15 +174,15 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
}

// Post order tree
void VisitExpr_(const FunctionNode* op) {
void VisitExpr_(const FunctionNode* op) final {
for (auto param : op->params) {
this->Update(param, nullptr, kOpaque);
}
this->Update(op->body, nullptr, kOpaque);
ExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const ConstantNode* op) {
void VisitExpr_(const ConstantNode* op) final {
this->AddNode(op);
Node* node = graph_.node_map.at(op);
DataType dtype = TVMType2Type(op->data->dtype);
Expand All @@ -201,7 +202,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
}
}

void VisitExpr_(const CallNode* call) {
void VisitExpr_(const CallNode* call) final {
CHECK(graph_.node_map.count(call));
Node* node = graph_.node_map.at(call);
static auto fpattern =
Expand Down Expand Up @@ -231,7 +232,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
this->AddNode(call);
}

void VisitExpr_(const TupleNode* op) {
void VisitExpr_(const TupleNode* op) final {
CHECK(graph_.node_map.count(op));
Node* tuple_node = graph_.node_map.at(op);
tuple_node->pattern = kInjective;
Expand All @@ -246,19 +247,19 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
this->AddNode(op);
}

void VisitExpr_(const TupleGetItemNode* op) {
void VisitExpr_(const TupleGetItemNode* op) final {
CHECK(graph_.node_map.count(op));
Node* node = graph_.node_map.at(op);
this->Update(op->tuple, node, kOpaque);
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}

void VisitExpr_(const VarNode* op) {
void VisitExpr_(const VarNode* op) final {
this->AddNode(op);
}

void VisitExpr_(const LetNode* op) {
void VisitExpr_(const LetNode* op) final {
// do not fuse through let.
this->Update(op->var, nullptr, kOpaque);
this->Update(op->value, nullptr, kOpaque);
Expand All @@ -267,14 +268,33 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
this->AddNode(op);
}

void VisitExpr_(const IfNode* op) {
void VisitExpr_(const IfNode* op) final {
// do not fuse through if.
this->Update(op->cond, nullptr, kOpaque);
this->Update(op->true_branch, nullptr, kOpaque);
this->Update(op->false_branch, nullptr, kOpaque);
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}

void VisitExpr_(const RefNewNode* op) final {
this->Update(op->value, nullptr, kOpaque);
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}

void VisitExpr_(const RefReadNode* op) final {
this->Update(op->ref, nullptr, kOpaque);
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}

void VisitExpr_(const RefWriteNode* op) final {
this->Update(op->ref, nullptr, kOpaque);
this->Update(op->value, nullptr, kOpaque);
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}
};

IndexedForwardGraph IndexedForwardGraph::Create(
Expand Down

0 comments on commit 7226df5

Please sign in to comment.