diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index fad4fb781b5a8..2f96c4b52b428 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -161,6 +161,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { current->extern_ref = true; } } + void AddNode(const tvm::Node* key) { auto it = graph_.node_map.find(key); CHECK(it != graph_.node_map.end()) @@ -173,7 +174,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { } // Post order tree - void VisitExpr_(const FunctionNode* op) { + void VisitExpr_(const FunctionNode* op) final { for (auto param : op->params) { this->Update(param, nullptr, kOpaque); } @@ -181,7 +182,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ExprVisitor::VisitExpr_(op); } - void VisitExpr_(const ConstantNode* op) { + void VisitExpr_(const ConstantNode* op) final { this->AddNode(op); Node* node = graph_.node_map.at(op); DataType dtype = TVMType2Type(op->data->dtype); @@ -201,7 +202,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { } } - void VisitExpr_(const CallNode* call) { + void VisitExpr_(const CallNode* call) final { CHECK(graph_.node_map.count(call)); Node* node = graph_.node_map.at(call); static auto fpattern = @@ -231,7 +232,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->AddNode(call); } - void VisitExpr_(const TupleNode* op) { + void VisitExpr_(const TupleNode* op) final { CHECK(graph_.node_map.count(op)); Node* tuple_node = graph_.node_map.at(op); tuple_node->pattern = kInjective; @@ -246,7 +247,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->AddNode(op); } - void VisitExpr_(const TupleGetItemNode* op) { + void VisitExpr_(const TupleGetItemNode* op) final { CHECK(graph_.node_map.count(op)); Node* node = graph_.node_map.at(op); this->Update(op->tuple, node, kOpaque); @@ -254,11 +255,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->AddNode(op); } - void VisitExpr_(const VarNode* op) { + void VisitExpr_(const VarNode* op) final { this->AddNode(op); } - void VisitExpr_(const LetNode* op) { + void VisitExpr_(const LetNode* op) final { // do not fuse through let. this->Update(op->var, nullptr, kOpaque); this->Update(op->value, nullptr, kOpaque); @@ -267,7 +268,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->AddNode(op); } - void VisitExpr_(const IfNode* op) { + void VisitExpr_(const IfNode* op) final { // do not fuse through if. this->Update(op->cond, nullptr, kOpaque); this->Update(op->true_branch, nullptr, kOpaque); @@ -275,6 +276,25 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ExprVisitor::VisitExpr_(op); this->AddNode(op); } + + void VisitExpr_(const RefNewNode* op) final { + this->Update(op->value, nullptr, kOpaque); + ExprVisitor::VisitExpr_(op); + this->AddNode(op); + } + + void VisitExpr_(const RefReadNode* op) final { + this->Update(op->ref, nullptr, kOpaque); + ExprVisitor::VisitExpr_(op); + this->AddNode(op); + } + + void VisitExpr_(const RefWriteNode* op) final { + this->Update(op->ref, nullptr, kOpaque); + this->Update(op->value, nullptr, kOpaque); + ExprVisitor::VisitExpr_(op); + this->AddNode(op); + } }; IndexedForwardGraph IndexedForwardGraph::Create(