diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 2319f8baec00a..c72612791b521 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -415,6 +415,32 @@ class TupleGetItemNode : public ExprNode { RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr); +/*! + * \brief Base class of the temporary expression. + * + * TempExprs are pass specific expression that can be + * useful to define intermediate result in the + * rewriting pass such as layout or type transformation. + * + * Subclass TempExprNode allows us to pattern match on + * specific kind TempExpr and use them for expression rewriting. + * + * TempExpr should only be used within a pass, + */ +class TempExprNode : public ExprNode { + public: + /*! + * \brief Convert the expression to a normal(non-temp) Expr. + * \return The corresponding normal(non-temp) expression. + */ + virtual Expr Realize() const = 0; + + static constexpr const char* _type_key = "relay.TempExpr"; + TVM_DECLARE_BASE_NODE_INFO(TempExprNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr); + // implementataions template inline const TTypeNode* ExprNode::type_as() const { diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index ad447ad13cee4..d3c5edd314615 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -276,6 +276,16 @@ class GenericOpMap { */ template inline ValueType get(const Op& op, ValueType def_value) const; + /*! + * \brief get the corresponding value element at op with default value. + * \param expr The key to the map + * \param def_value The default value when the key does not exist + * or if expr is not an Op. + * \return the const reference to the content value. + * \tparam ValueType The content value type. + */ + template + inline ValueType get(const Expr& expr, ValueType def_value) const; private: friend class OpRegistry; @@ -313,6 +323,14 @@ class OpMap { * \return the const reference to the content value. */ inline ValueType get(const Op& op, ValueType def_value) const; + /*! + * \brief get the corresponding value element at op with default value. + * \param expr The key to the map + * \param def_value The default value when the key does not exist + * or if expr is not an Op. + * \return the const reference to the content value. + */ + inline ValueType get(const Expr& expr, ValueType def_value) const; private: friend class Op; @@ -496,6 +514,21 @@ inline ValueType GenericOpMap::get(const Op& op, ValueType value) const { } } +template +inline ValueType GenericOpMap::get(const Expr& expr, ValueType value) const { + CHECK(expr.defined()); + if (const OpNode* op = expr.as()) { + const uint32_t idx = op->index_; + if (idx < data_.size() && data_[idx].second != 0) { + return data_[idx].first; + } else { + return value; + } + } else { + return value; + } +} + template inline int OpMap::count(const Op& op) const { return map_.count(op); @@ -505,12 +538,19 @@ template inline ValueType OpMap::operator[](const Op& op) const { return map_[op]; } + template inline ValueType OpMap::get(const Op& op, ValueType def_value) const { return map_.get(op, def_value); } +template +inline ValueType OpMap::get(const Expr& expr, + ValueType def_value) const { + return map_.get(expr, def_value); +} + /*! * \brief Check that an expression is a "primtive operator". * diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index f80d51772ae21..3d9fa56855c36 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -85,6 +85,25 @@ using FTVMSchedule = runtime::TypedPackedFunc< Schedule(const Attrs& attrs, const Array& outs, const Target& target)>; + +/*! + * \brief Forward rewriting rule for a specific op. + * + * \param ref_call The reference old call type to be rewritten. + * We can make use of the op and type information. + * \param new_args The new arguments (some of them could be TempExpr). + * \param ctx Optional context information about ref_call. + * \return The rewriten result call, can also return nullptr, + * which indicate the rewriter should use the default fallback + * rule that realizes all its input and compose the call. + * + * \note When we register the function, we can register + * a different signature with ctx to be a specific node type. + */ +using FForwardRewrite = runtime::TypedPackedFunc< + Expr(const Call& ref_call, + const Array& new_args, + const NodeRef& ctx)>; } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_ATTR_TYPES_H_ diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 3ca81ebd027da..4410ed0d0de13 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -158,6 +158,17 @@ Expr FoldConstant(const Expr& expr); */ Expr FuseOps(const Expr& expr, int fuse_opt_level); +/*! + * \brief Apply rewrite rules to rewrite the expr in post DFS order. + * \param expr The expression. + * \param rewrite_map_attr_name The Op's attr name which corresponds to the rewrite + * rule function. + * \param fcontext Additional callback to provide context argument for each call node. + * \return The rewritten expression. + */ +Expr ForwardRewrite(const Expr& expr, + const std::string& rewrite_map_attr_name, + std::function fcontext = nullptr); /*! \brief A hashing structure in the style of std::hash. */ struct StructuralHash { diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index f25785d39eeb1..0aeb7f2b15135 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -73,6 +73,8 @@ class PackedFunc { using FType = std::function; /*! \brief default constructor */ PackedFunc() {} + /*! \brief constructor from null */ + PackedFunc(std::nullptr_t null) {} // NOLINT(*) /*! * \brief constructing a packed function from a std::function. * \param body the internal container of packed function. diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index 038f34df57608..d3f7043088ebe 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -87,23 +87,6 @@ AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) { return ret; } -/*! - * \param Get function from op_map. - * \param op_map The OpMap. - * \param op The operator being called. - * \tparam ValueType the content value type. - * \return The result value map. - */ -template -ValueType GetFunc(const OpMap& op_map, - const Expr& op) { - if (const OpNode* opnode = op.as()) { - return op_map.get(GetRef(opnode), ValueType()); - } else { - return ValueType(); - } -} - /*! * \brief Preparation function for pass scale forward. * \param call The call node. @@ -114,7 +97,7 @@ using FForwardPrep = runtime::TypedPackedFunc< Array (const Call& call, const AxesSet& out_scale_axes)>; /*! \brief Axis scale tuple. */ -class STupleNode : public Node { +class ScaledExprNode : public TempExprNode { public: /*! \brief The value */ Expr value; @@ -123,29 +106,26 @@ class STupleNode : public Node { /*! \brief The scaling factor */ Expr scale = NullValue(); + Expr Realize() const final { + CHECK(!axes.defined()) + << "outstanding scale"; + return value; + } + void VisitAttrs(AttrVisitor* v) final { v->Visit("value", &value); v->Visit("axes", &axes); v->Visit("scale", &scale); } - static constexpr const char* _type_key = "relay.fold_scale_axis.STupleNode"; - TVM_DECLARE_NODE_TYPE_INFO(STupleNode, Node); + static constexpr const char* _type_key = "relay.fold_scale_axis.ScaledExpr"; + TVM_DECLARE_NODE_TYPE_INFO(ScaledExprNode, TempExprNode); }; -RELAY_DEFINE_NODE_REF(STuple, STupleNode, NodeRef); - -/*! - * \brief The transform function, transform an old call to - * a new one given the new args. - * \param ref_call Reference call node that represent the op and the types. - * \param expected_out_axes The scale axes allowed in the output. - * \param sargs The input arguments. - */ -using FForwardTransform = TypedPackedFunc< - STuple(const Call& ref_call, - const AxesSet& expected_out_axes, - const Array& sargs)>; +using FForwardRewrite = TypedPackedFunc< + Expr(const Call& ref_call, + const Array& new_args, + const AxesSet& expeced_out_axes)>; //---------------------------------------------- // Generic Visitors for FScaleAxisForward @@ -219,7 +199,7 @@ class ForwardPrep : private ExprVisitor { out_axes = NullValue(); } // pass the message back to all the children it references. - auto f = GetFunc(fprep, call->op); + auto f = fprep.get(call->op, nullptr); if (f != nullptr) { Array in_axes = f(GetRef(call), out_axes); CHECK_EQ(in_axes.size(), call->args.size()); @@ -261,87 +241,6 @@ class ForwardPrep : private ExprVisitor { } }; -class ForwardTransformer : private ExprMutator { - public: - // Transform expression. - Expr Fold(Expr expr) { - expected_scale_axes_ = - ForwardPrep().Prepare(expr); - return this->Mutate(expr); - } - - private: - // Valid axes on each node. - std::unordered_map expected_scale_axes_; - std::unordered_map scale_memo_; - // If user simply call mutate, - // then only Expr is returned and we cannot - // accept outstanding scales. - Expr VisitExpr(const Expr& expr) final { - Expr res = ExprMutator::VisitExpr(expr); - CHECK(!scale_memo_.count(expr.get())) - << "Outstanding scale"; - return res; - } - - STuple GetSTuple(const Expr& expr) { - Expr res = ExprMutator::VisitExpr(expr); - auto it = scale_memo_.find(expr.get()); - if (it != scale_memo_.end()) { - CHECK(it->second->value.same_as(res)); - return it->second; - } else { - auto node = make_node(); - node->value = res; - return STuple(node); - } - } - - Expr VisitExpr_(const CallNode* call_node) final { - static const auto& ftransform = - Op::GetAttr("FScaleAxisForwardTransform"); - auto new_op = this->Mutate(call_node->op); - bool has_scale = false; - bool unchanged = call_node->op.same_as(new_op); - - Array call_sargs; - Array call_args; - for (auto arg : call_node->args) { - STuple new_sarg = this->GetSTuple(arg); - unchanged &= new_sarg->value.same_as(arg); - if (new_sarg->axes.defined()) has_scale = true; - call_sargs.push_back(new_sarg); - call_args.push_back(new_sarg->value); - } - - // get expected scale axes. - AxesSet expected_out_axes; - auto axis_it = expected_scale_axes_.find(call_node); - if (axis_it != expected_scale_axes_.end()) { - expected_out_axes = axis_it->second; - } - // propagation function - auto f = GetFunc(ftransform, call_node->op); - if (f != nullptr) { - STuple sret = f(GetRef(call_node), expected_out_axes, call_sargs); - if (sret.defined()) { - if (sret->axes.defined()) { - scale_memo_[call_node] = sret; - } - return sret->value; - } - } - // normal path - CHECK(!has_scale) << "Outstanding scale, on op=" << call_node->op; - if (unchanged) { - return GetRef(call_node); - } else { - return CallNode::make( - new_op, call_args, call_node->attrs, call_node->type_args); - } - } -}; - //---------------------------------------------- // Per operator defs for FScaleAxisForward //---------------------------------------------- @@ -351,30 +250,31 @@ Array ReluForwardPrep(const Call& call, AxesSet out) { return {out}; } -STuple ReluForwardTransform(const Call& ref_call, - const AxesSet& expected_axes, - const Array& sargs) { - if (!sargs[0]->axes.defined()) return STuple(); +Expr ReluForwardRewrite(const Call& ref_call, + const Array& new_args, + const AxesSet& expected_axes) { + const auto* input = new_args[0].as(); + if (input == nullptr) return Expr(nullptr); // return transformed conv2d - auto rnode = make_node(); + auto rnode = make_node(); rnode->value = CallNode::make( - ref_call->op, {sargs[0]->value}, ref_call->attrs, ref_call->type_args); - rnode->scale = sargs[0]->scale; - rnode->axes = sargs[0]->axes; - return STuple(rnode); + ref_call->op, {input->value}, ref_call->attrs, ref_call->type_args); + rnode->scale = input->scale; + rnode->axes = input->axes; + return Expr(rnode); } RELAY_REGISTER_OP("nn.relu") .set_attr("FScaleAxisForwardPrep", ReluForwardPrep); RELAY_REGISTER_OP("nn.relu") -.set_attr("FScaleAxisForwardTransform", ReluForwardTransform); +.set_attr("FScaleAxisForwardRewrite", ReluForwardRewrite); RELAY_REGISTER_OP("nn.leaky_relu") .set_attr("FScaleAxisForwardPrep", ReluForwardPrep); RELAY_REGISTER_OP("nn.leaky_relu") -.set_attr("FScaleAxisForwardTransform", ReluForwardTransform); +.set_attr("FScaleAxisForwardRewrite", ReluForwardRewrite); // AddSub Array AddSubForwardPrep(const Call& call, AxesSet out_axes) { @@ -391,69 +291,69 @@ Array AddSubForwardPrep(const Call& call, AxesSet out_axes) { } } -STuple AddSubForwardTransform(const Call& ref_call, - const AxesSet& expected_out_axes, - const Array& sargs) { - if (!sargs[0]->axes.defined() && !sargs[1]->axes.defined()) { - return STuple(); - } +Expr AddSubForwardRewrite(const Call& ref_call, + const Array& new_args, + const AxesSet& expected_out_axes) { + const auto* slhs = new_args[0].as(); + const auto* srhs = new_args[1].as(); + if (!slhs && !srhs) return Expr(); const auto* tlhs = ref_call->args[0]->type_as(); const auto* trhs = ref_call->args[1]->type_as(); + auto rnode = make_node(); - auto rnode = make_node(); - if (sargs[0]->axes.defined()) { - CHECK(!sargs[1]->axes.defined()); - CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, sargs[0]->axes)); + if (slhs != nullptr) { + CHECK(srhs == nullptr); + CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes)); Expr scale = ExpandBiasToMatchAxis( - sargs[0]->scale, tlhs->shape.size(), sargs[0]->axes); - Expr rhs = Divide(sargs[1]->value, scale); - rnode->value = CallNode::make(ref_call->op, {sargs[0]->value, rhs}, + slhs->scale, tlhs->shape.size(), slhs->axes); + Expr rhs = Divide(new_args[1], scale); + rnode->value = CallNode::make(ref_call->op, {slhs->value, rhs}, ref_call->attrs, ref_call->type_args); - rnode->scale = sargs[0]->scale; - rnode->axes = sargs[0]->axes; + rnode->scale = slhs->scale; + rnode->axes = slhs->axes; } else { - CHECK(sargs[1]->axes.defined()); - CHECK(sargs[0]->axes.defined()); - CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, sargs[1]->axes)); + CHECK(slhs != nullptr); + CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes)); Expr scale = ExpandBiasToMatchAxis( - sargs[1]->scale, trhs->shape.size(), sargs[1]->axes); - Expr lhs = Divide(sargs[0]->value, scale); - rnode->value = CallNode::make(ref_call->op, {lhs, sargs[1]->value}, + srhs->scale, trhs->shape.size(), srhs->axes); + Expr lhs = Divide(new_args[0], scale); + rnode->value = CallNode::make(ref_call->op, {lhs, srhs->value}, ref_call->attrs, ref_call->type_args); - rnode->scale = sargs[1]->scale; - rnode->axes = sargs[1]->axes; + rnode->scale = srhs->scale; + rnode->axes = srhs->axes; } - return STuple(rnode); + return Expr(rnode); } RELAY_REGISTER_OP("add") .set_attr("FScaleAxisForwardPrep", AddSubForwardPrep); RELAY_REGISTER_OP("add") -.set_attr("FScaleAxisForwardTransform", AddSubForwardTransform); +.set_attr("FScaleAxisForwardRewrite", AddSubForwardRewrite); RELAY_REGISTER_OP("subtract") .set_attr("FScaleAxisForwardPrep", AddSubForwardPrep); RELAY_REGISTER_OP("subtract") -.set_attr("FScaleAxisForwardTransform", AddSubForwardTransform); +.set_attr("FScaleAxisForwardRewrite", AddSubForwardRewrite); // Producer operators // Multiply produces the scale-axis pair. -STuple MultiplyForwardTransform(const Call& ref_call, - const AxesSet& expected_out_axes, - const Array& sargs) { - if (!expected_out_axes.defined()) return STuple(); +Expr MultiplyForwardRewrite(const Call& ref_call, + const Array& new_args, + const AxesSet& expected_out_axes) { + if (!expected_out_axes.defined()) return Expr(); // TODO(tvm-team) allow same axes accumulation // not as important because it is less common in nn. - CHECK(!sargs[0]->axes.defined()); - CHECK(!sargs[1]->axes.defined()); + const auto* slhs = new_args[0].as(); + const auto* srhs = new_args[1].as(); + CHECK(!slhs && !srhs); + const auto* tlhs = ref_call->args[0]->type_as(); const auto* trhs = ref_call->args[1]->type_as(); - - Expr lhs = sargs[0]->value; - Expr rhs = sargs[1]->value; - auto rnode = make_node(); + Expr lhs = new_args[0]; + Expr rhs = new_args[1]; + auto rnode = make_node(); if (MatchBroadcastToLeftAxes(tlhs, trhs, expected_out_axes, &rhs)) { rnode->value = lhs; rnode->scale = rhs; @@ -463,11 +363,11 @@ STuple MultiplyForwardTransform(const Call& ref_call, rnode->scale = lhs; rnode->axes = expected_out_axes; } - return STuple(rnode); + return Expr(rnode); } RELAY_REGISTER_OP("multiply") -.set_attr("FScaleAxisForwardTransform", MultiplyForwardTransform); +.set_attr("FScaleAxisForwardRewrite", MultiplyForwardRewrite); // Consumer operators // Conv2D send out requirement of axis folding. @@ -500,13 +400,14 @@ Array Conv2DForwardPrep(const Call& call, AxesSet out) { } // Conv2D consumes the scale axis during transformation. -STuple Conv2DForwardTransform(const Call& ref_call, - const AxesSet& expected_axes, - const Array& sargs) { +Expr Conv2DForwardRewrite(const Call& ref_call, + const Array& new_args, + const AxesSet& expected_axes) { // if data do not have scale, normal transform path. - STuple sdata = sargs[0]; - if (!sdata->scale.defined()) return STuple(); - CHECK(sdata->axes.defined()); + const auto* sdata = new_args[0].as(); + const auto* sweight = new_args[1].as(); + if (sdata == nullptr) return Expr(); + if (sweight != nullptr) return Expr(); const auto* param = ref_call->attrs.as(); CHECK(param != nullptr); Layout data_layout(param->data_layout); @@ -524,7 +425,8 @@ STuple Conv2DForwardTransform(const Call& ref_call, // Check it must be depthwise or full conv2d. bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, weight_layout); CHECK(param->groups == 1 || is_depthwise_conv2d); - Expr weight = sargs[1]->value; + + Expr weight = new_args[1]; // match the ic_axis if (is_depthwise_conv2d) { @@ -537,21 +439,30 @@ STuple Conv2DForwardTransform(const Call& ref_call, weight = Multiply(weight, scale); } // return transformed conv2d - auto rnode = make_node(); - rnode->value = CallNode::make( + return CallNode::make( ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args); - return STuple(rnode); } RELAY_REGISTER_OP("nn.conv2d") .set_attr("FScaleAxisForwardPrep", Conv2DForwardPrep); RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FScaleAxisForwardTransform", Conv2DForwardTransform); +.set_attr("FScaleAxisForwardRewrite", Conv2DForwardRewrite); Expr ForwardFoldScaleAxis(Expr data) { - return ForwardTransformer().Fold(data); + auto expected_scale_axes = + ForwardPrep().Prepare(data); + auto fcontext = [&](const Call& call) -> NodeRef{ + auto it = expected_scale_axes.find(call.get()); + if (it != expected_scale_axes.end()) { + return it->second; + } else { + return NodeRef(nullptr); + } + }; + return ForwardRewrite( + data, "FScaleAxisForwardRewrite", fcontext); } // Expose the FoldScaleAxisFoward @@ -602,7 +513,7 @@ class BackwardPrep : private ExprVisitor { ExprVisitor::VisitExpr_(call); static const auto& fprep = Op::GetAttr("FScaleAxisBackwardPrep"); - auto f = GetFunc(fprep, call->op); + auto f = fprep.get(call->op, nullptr); if (f == nullptr) return; auto rit = ref_counter_.find(call); CHECK(rit != ref_counter_.end()); @@ -705,7 +616,7 @@ Expr BackwardTransformerNode::Transform( const CallNode* call_node, AxesSet axes, Expr scale) { static const auto& ftransform = Op::GetAttr("FScaleAxisBackwardTransform"); - auto f = GetFunc(ftransform, call_node->op); + auto f = ftransform.get(call_node->op, nullptr); if (f != nullptr) { return f(GetRef(call_node), axes, diff --git a/src/relay/pass/forward_rewrite.cc b/src/relay/pass/forward_rewrite.cc new file mode 100644 index 0000000000000..9c1e35782e926 --- /dev/null +++ b/src/relay/pass/forward_rewrite.cc @@ -0,0 +1,132 @@ +/*! + * Copyright (c) 2018 by Contributors + * + * \file forward_rewrite.cc + * \brief Apply rewriting rules in a forward fashion. + */ +#include +#include +#include + +namespace tvm { +namespace relay { + +// Realizer class that realizes the expression +// Note that we can take benefit of its internal memo +// so that calling realize repeatively won't hurt perf. +class TempRealizer : private ExprMutator { + public: + Expr Realize(Expr expr) { + return VisitExpr(expr); + } + + private: + Expr VisitExpr(const Expr& expr) final { + auto it = memo_.find(expr); + if (it != memo_.end()) { + return it->second; + } else { + Expr res; + if (const auto* temp = expr.as_derived()) { + res = temp->Realize(); + + } else { + res = ExprFunctor::VisitExpr(expr); + } + memo_[res] = res; + return res; + } + } +}; + +class ForwardRewriter : private ExprMutator { + public: + ForwardRewriter(const OpMap& rewrite_map, + std::function fcontext) + : rewrite_map_(rewrite_map), + fcontext_(fcontext) { + } + + // Transform expression. + Expr Rewrite(Expr expr) { + return this->VisitExpr(expr); + } + + private: + // The rewrite rule. + const OpMap& rewrite_map_; + // The context. + std::function fcontext_{nullptr}; + // internal realizer + TempRealizer realizer_; + + Expr VisitExpr(const Expr& expr) final { + // by default always realize. + return realizer_.Realize(ExprMutator::VisitExpr(expr)); + } + + // Visit and allow non-realized version. + Expr GetTempExpr(const Expr& expr) { + return ExprMutator::VisitExpr(expr); + } + + // Automatic fold TupleGetItem. + Expr VisitExpr_(const TupleGetItemNode* op) final { + Expr tuple = this->GetTempExpr(op->tuple); + if (const auto* ptuple = tuple.as()) { + return ptuple->fields[op->index]; + } else { + if (tuple.same_as(op->tuple)) { + return GetRef(op); + } else { + return TupleGetItemNode::make(tuple, op->index); + } + } + } + + Expr VisitExpr_(const CallNode* call_node) final { + const Call& ref_call = GetRef(call_node); + PackedFunc frewrite = rewrite_map_.get(call_node->op, nullptr); + + auto new_op = this->Mutate(call_node->op); + bool unchanged = call_node->op.same_as(new_op); + + Array call_args; + for (auto arg : call_node->args) { + Expr new_arg = this->GetTempExpr(arg); + if (frewrite == nullptr) { + new_arg = realizer_.Realize(new_arg); + } + unchanged &= new_arg.same_as(arg); + call_args.push_back(new_arg); + } + // try to rewrite. + if (frewrite != nullptr) { + Expr res = frewrite( + ref_call, call_args, + fcontext_ != nullptr ? fcontext_(ref_call) : NodeRef(nullptr)); + if (res.defined()) return res; + // abort, use old rule + for (size_t i = 0; i < call_args.size(); ++i) { + Expr arg = call_args[i]; + Expr new_arg = realizer_.Realize(arg); + if (!arg.same_as(new_arg)) { + call_args.Set(i, new_arg); + unchanged = false; + } + } + } + if (unchanged) return ref_call; + return CallNode::make( + new_op, call_args, call_node->attrs, call_node->type_args); + } +}; + +Expr ForwardRewrite(const Expr& expr, + const std::string& rewrite_map_name, + std::function fcontext) { + auto rewrite_map = Op::GetAttr(rewrite_map_name); + return ForwardRewriter(rewrite_map, fcontext).Rewrite(expr); +} +} // namespace relay +} // namespace tvm