diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 035ab1ba5beeb..ff64d4a3acbb3 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -41,43 +41,6 @@ using tir::LoweredFunc; using TargetsMap = Map; using namespace tvm::relay::transform; -/*! - * \brief Bind params to function by using name - * \param func Relay function - * \param params params dict - * \return relay::Function - */ -relay::Function BindParamsByName(relay::Function func, - const std::unordered_map& params) { - std::unordered_map name_dict; - std::unordered_set repeat_var; - for (auto arg : func->params) { - const auto& name = arg->name_hint(); - if (name_dict.count(name)) { - repeat_var.insert(arg); - } else { - name_dict[name] = arg; - } - } - - std::unordered_map bind_dict; - for (auto& kv : params) { - if (name_dict.count(kv.first) == 0) { - continue; - } - auto arg = name_dict.at(kv.first); - if (repeat_var.count(arg)) { - LOG(FATAL) << "Multiple args in the function have name " << kv.first; - } - bind_dict[arg] = ConstantNode::make(kv.second); - } - Expr bound_expr = relay::Bind(func, bind_dict); - Function ret = Downcast(bound_expr); - CHECK(ret.defined()) << "The returning type is expected to be a Relay Function." - << "\n"; - return ret; -} - /*! * \brief Output of building module * @@ -527,7 +490,7 @@ TVM_REGISTER_GLOBAL("relay.build_module.BindParamsByName") for (const auto& kv : params) { params_[kv.first] = kv.second->data; } - *rv = BindParamsByName(args[0], params_); + *rv = relay::backend::BindParamsByName(args[0], params_); }); } // namespace backend diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 24e338e680289..cccd4badbb3af 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -34,6 +35,8 @@ #include #include +#include +#include namespace tvm { namespace relay { @@ -81,6 +84,44 @@ inline std::string DType2String(const tvm::DataType dtype) { return os.str(); } +/*! + * \brief Bind params to function by using name + * \param func Relay function + * \param params params dict + * \return relay::Function + */ +inline relay::Function +BindParamsByName(relay::Function func, + const std::unordered_map& params) { + std::unordered_map name_dict; + std::unordered_set repeat_var; + for (auto arg : func->params) { + const auto& name = arg->name_hint(); + if (name_dict.count(name)) { + repeat_var.insert(arg); + } else { + name_dict[name] = arg; + } + } + + std::unordered_map bind_dict; + for (auto& kv : params) { + if (name_dict.count(kv.first) == 0) { + continue; + } + auto arg = name_dict.at(kv.first); + if (repeat_var.count(arg)) { + LOG(FATAL) << "Multiple args in the function have name " << kv.first; + } + bind_dict[arg] = ConstantNode::make(kv.second); + } + Expr bound_expr = relay::Bind(func, bind_dict); + Function ret = Downcast(bound_expr); + CHECK(ret.defined()) << "The returning type is expected to be a Relay Function." + << "\n"; + return ret; +} + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index c899644c4dd31..cc5d6bc02a476 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -37,9 +37,8 @@ #include #include #include -#include -#include #include +#include "../utils.h" #include "../../backend/compile_engine.h" #include "../../pass/pass_util.h" #include "../../op/op_common.h" @@ -783,38 +782,6 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) { params_[name] = data_in; } -relay::Function VMCompiler::BindParamsByName( - relay::Function func, - const std::unordered_map& params) { - std::unordered_map name_dict; - std::unordered_set repeat_var; - for (auto arg : func->params) { - const auto &name = arg->name_hint(); - if (name_dict.count(name)) { - repeat_var.insert(arg); - } else { - name_dict[name] = arg; - } - } - std::unordered_map bind_dict; - for (auto &kv : params) { - if (name_dict.count(kv.first) == 0) { - continue; - } - auto arg = name_dict.at(kv.first); - if (repeat_var.count(arg)) { - LOG(FATAL) << "Multiple args in the function have name " << kv.first; - } - bind_dict[arg] = ConstantNode::make(kv.second); - } - Expr bound_expr = relay::Bind(func, bind_dict); - Function ret = Downcast(bound_expr); - CHECK(ret.defined()) - << "The returning type is expected to be a Relay Function." - << "\n"; - return ret; -} - void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) { @@ -824,7 +791,7 @@ void VMCompiler::Lower(IRModule mod, BaseFunc base_func = mod->Lookup("main"); CHECK(base_func->IsInstance()) << "VM compiler expects to compile relay::Function"; - auto f = BindParamsByName(Downcast(base_func), params_); + auto f = relay::backend::BindParamsByName(Downcast(base_func), params_); auto gvar = mod->GetGlobalVar("main"); mod->Add(gvar, f); } diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 602e6cceb3dd7..19e1ee8781a82 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -115,16 +115,6 @@ class VMCompiler : public runtime::ModuleNode { void Codegen(); protected: - /*! - * \brief Bind params to function by using name - * \param func Relay function - * \param params params dict - * \return relay::Function - */ - relay::Function BindParamsByName( - relay::Function func, - const std::unordered_map& params); - IRModule OptimizeModule(const IRModule& mod, const TargetsMap& targets); void PopulateGlobalMap();