diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index d0f8a36e267ef..e0267ae526d66 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -58,9 +58,11 @@ #include #include +#include #include #include #include +#include #include #include #include @@ -301,7 +303,7 @@ Pass CreateModulePass( * \return The created function pass. */ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< - Function(Function, Module, PassContext)>& pass_func, + Function(Function, Module, PassContext)>& pass_func, int opt_level, const std::string& name, const tvm::Array& required); @@ -441,7 +443,22 @@ TVM_DLL Pass InferType(); * * \return The pass. */ -TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip); +TVM_DLL Pass EliminateCommonSubexpr( + PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + Expr expr = args[0]; + if (expr.as()) { + auto call_node = expr.as(); + auto op_node = call_node->op.as(); + if (op_node->name == "cast") { + auto attrs = call_node->attrs.as(); + if (attrs->dtype == HalideIR::Int(32)) { + *rv = true; + } + } + } + *rv = false; + }) +); /*! * \brief Combine parallel 2d convolutions into a single convolution if the diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index d76f9c786a879..4c39cfa1f91ea 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -23,9 +23,7 @@ */ #include #include -#include #include -#include #include #include @@ -307,24 +305,7 @@ class RelayBuildModule : public runtime::ModuleNode { const std::unordered_map& params) { Array pass_seqs; pass_seqs.push_back(transform::SimplifyInference()); - - // Can we move to the pass implementation file and make it as default? - auto fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - Expr expr = args[0]; - if (expr.as()) { - auto call_node = expr.as(); - auto op_node = call_node->op.as(); - if (op_node->name == "cast") { - auto attrs = call_node->attrs.as(); - if (attrs->dtype == HalideIR::Int(32)) { - *rv = true; - } - } - } - *rv = false; - }); - - pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); + pass_seqs.push_back(transform::EliminateCommonSubexpr()); pass_seqs.push_back(transform::CombineParallelConv2D(3)); pass_seqs.push_back(transform::FoldScaleAxis()); pass_seqs.push_back(transform::CanonicalizeOps()); diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index a7a403078b897..3a7054bb9a288 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -289,7 +289,7 @@ class SequentialNode : public PassNode { * * \return true if the pass is enabled. Otherwise, false. */ - bool pass_enabled(const std::string& pass_name) const; + bool PassEnabled(const std::string& pass_name) const; /*! * \brief Resolve the pass dependency. It globs all required passes by @@ -473,7 +473,7 @@ std::unordered_set SequentialNode::RequiredPasses( return ret; } -bool SequentialNode::pass_enabled(const std::string& pass_name) const { +bool SequentialNode::PassEnabled(const std::string& pass_name) const { PassContext ctx = PassContext::Current(); auto required = RequiredPasses(ctx->required_pass); @@ -512,7 +512,7 @@ Module SequentialNode::operator()(const Module& module, PassInfo info = pass->Info(); const auto& pass_name = info->name; // Execute the pass if it is enabled. - if (pass_enabled(pass_name)) { + if (PassEnabled(pass_name)) { const auto* pn = pass.operator->(); mod = (*pn)(mod, pass_ctx); } diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential.cc index 40011a278009a..a6ed0069c1e01 100644 --- a/tests/cpp/relay_transform_sequential.cc +++ b/tests/cpp/relay_transform_sequential.cc @@ -69,8 +69,7 @@ TEST(Relay, Sequential) { tvm::Array pass_seqs; pass_seqs.push_back(relay::transform::InferType()); pass_seqs.push_back(relay::transform::DeadCodeElimination()); - pass_seqs.push_back( - relay::transform::EliminateCommonSubexpr(tvm::PackedFunc(nullptr))); + pass_seqs.push_back(relay::transform::EliminateCommonSubexpr()); pass_seqs.push_back(relay::transform::AlterOpLayout()); relay::GlobalVar var = relay::GlobalVarNode::make("main");