diff --git a/src/relay/transforms/eliminate_common_subexpr.cc b/src/relay/transforms/eliminate_common_subexpr.cc index dc3f77e3180c..92cc64dedba6 100644 --- a/src/relay/transforms/eliminate_common_subexpr.cc +++ b/src/relay/transforms/eliminate_common_subexpr.cc @@ -37,13 +37,13 @@ namespace tvm { namespace relay { -class CommonSubexprEliminator : public ExprMutator { +class CommonSubexprEliminator : public MixedModeMutator { public: explicit CommonSubexprEliminator(runtime::TypedPackedFunc fskip) : fskip_(fskip) {} - Expr VisitExpr_(const CallNode* call) final { + Expr Rewrite_(const CallNode* call, const Expr& post) final { static auto op_stateful = Op::GetAttrMap("TOpIsStateful"); - Expr new_expr = ExprMutator::VisitExpr_(call); + Expr new_expr = post; const CallNode* new_call = new_expr.as(); CHECK(new_call); const OpNode* op = new_call->op.as(); @@ -80,8 +80,8 @@ class CommonSubexprEliminator : public ExprMutator { return new_expr; } - Expr VisitExpr_(const TupleGetItemNode* op) final { - Expr new_expr = ExprMutator::VisitExpr_(op); + Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final { + Expr new_expr = post; const TupleGetItemNode* new_tuple_item = new_expr.as(); CHECK(new_tuple_item);