diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 7662018e4f71..9883fe85c253 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -308,6 +308,56 @@ std::vector ShapeToJSON(tvm::Array shape) { return ret; } +relay::Function BindParamsByName(relay::Function func, + const std::unordered_map& params) { + std::unordered_map name_dict; + std::unordered_set repeat_var; + for (auto arg : func->params) { + const auto& name = arg->name_hint(); + if (name_dict.count(name)) { + repeat_var.insert(name_dict[name]); + } else { + name_dict[name] = arg; + } + } + + std::unordered_map bind_dict; + for (auto& kv : params) { + if (name_dict.count(kv.first) == 0) { + continue; + } + auto arg = name_dict.at(kv.first); + if (repeat_var.count(arg)) { + LOG(FATAL) << "Multiple args in the function have name " << kv.first; + } + bind_dict[arg] = Constant(kv.second); + } + Expr bound_expr = relay::Bind(func, bind_dict); + Function ret = Downcast(bound_expr); + ICHECK(ret.defined()) << "The returning type is expected to be a Relay Function." + << "\n"; + return ret; +} + +void BindParamsInModule(IRModule mod, + const std::unordered_map& params) { + if (!params.empty()) { + BaseFunc base_func = mod->Lookup("main"); + ICHECK(base_func->IsInstance()); + auto f = relay::backend::BindParamsByName(Downcast(base_func), params); + auto gvar = mod->GetGlobalVar("main"); + mod->Add(gvar, f); + } +} + +void BindParamsInModule(IRModule mod, Map params) { + std::unordered_map params_tmp; + for (const auto& kv : params) { + params_tmp[kv.first] = kv.second->data; + } + BindParamsInModule(mod, params_tmp); +} + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index f15ae4765add..cfbf0a900777 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -386,55 +386,13 @@ inline std::string DType2String(const tvm::DataType dtype) { * \param params params dict * \return relay::Function */ -inline relay::Function BindParamsByName( - relay::Function func, const std::unordered_map& params) { - std::unordered_map name_dict; - std::unordered_set repeat_var; - for (auto arg : func->params) { - const auto& name = arg->name_hint(); - if (name_dict.count(name)) { - repeat_var.insert(name_dict[name]); - } else { - name_dict[name] = arg; - } - } - - std::unordered_map bind_dict; - for (auto& kv : params) { - if (name_dict.count(kv.first) == 0) { - continue; - } - auto arg = name_dict.at(kv.first); - if (repeat_var.count(arg)) { - LOG(FATAL) << "Multiple args in the function have name " << kv.first; - } - bind_dict[arg] = Constant(kv.second); - } - Expr bound_expr = relay::Bind(func, bind_dict); - Function ret = Downcast(bound_expr); - ICHECK(ret.defined()) << "The returning type is expected to be a Relay Function." - << "\n"; - return ret; -} +relay::Function BindParamsByName(relay::Function func, + const std::unordered_map& params); -inline void BindParamsInModule(IRModule mod, - const std::unordered_map& params) { - if (!params.empty()) { - BaseFunc base_func = mod->Lookup("main"); - ICHECK(base_func->IsInstance()); - auto f = relay::backend::BindParamsByName(Downcast(base_func), params); - auto gvar = mod->GetGlobalVar("main"); - mod->Add(gvar, f); - } -} +void BindParamsInModule(IRModule mod, + const std::unordered_map& params); -inline void BindParamsInModule(IRModule mod, Map params) { - std::unordered_map params_tmp; - for (const auto& kv : params) { - params_tmp[kv.first] = kv.second->data; - } - BindParamsInModule(mod, params_tmp); -} +void BindParamsInModule(IRModule mod, Map params); /*! * \brief Extract the shape from a Relay tensor type.