Skip to content

Commit

Permalink
move BindParams function to cc file
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Mar 11, 2022
1 parent efeccea commit 4a5e4aa
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 47 deletions.
50 changes: 50 additions & 0 deletions src/relay/backend/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,56 @@ std::vector<int64_t> ShapeToJSON(tvm::Array<IndexExpr> shape) {
return ret;
}

relay::Function BindParamsByName(relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, ObjectPtrHash, ObjectPtrEqual> repeat_var;
for (auto arg : func->params) {
const auto& name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(name_dict[name]);
} else {
name_dict[name] = arg;
}
}

std::unordered_map<relay::Var, Expr, ObjectPtrHash, ObjectPtrEqual> bind_dict;
for (auto& kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
auto arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
}
bind_dict[arg] = Constant(kv.second);
}
Expr bound_expr = relay::Bind(func, bind_dict);
Function ret = Downcast<Function>(bound_expr);
ICHECK(ret.defined()) << "The returning type is expected to be a Relay Function."
<< "\n";
return ret;
}

void BindParamsInModule(IRModule mod,
const std::unordered_map<std::string, runtime::NDArray>& params) {
if (!params.empty()) {
BaseFunc base_func = mod->Lookup("main");
ICHECK(base_func->IsInstance<FunctionNode>());
auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params);
auto gvar = mod->GetGlobalVar("main");
mod->Add(gvar, f);
}
}

void BindParamsInModule(IRModule mod, Map<String, Constant> params) {
std::unordered_map<std::string, runtime::NDArray> params_tmp;
for (const auto& kv : params) {
params_tmp[kv.first] = kv.second->data;
}
BindParamsInModule(mod, params_tmp);
}

} // namespace backend
} // namespace relay
} // namespace tvm
52 changes: 5 additions & 47 deletions src/relay/backend/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,55 +386,13 @@ inline std::string DType2String(const tvm::DataType dtype) {
* \param params params dict
* \return relay::Function
*/
inline relay::Function BindParamsByName(
relay::Function func, const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, ObjectPtrHash, ObjectPtrEqual> repeat_var;
for (auto arg : func->params) {
const auto& name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(name_dict[name]);
} else {
name_dict[name] = arg;
}
}

std::unordered_map<relay::Var, Expr, ObjectPtrHash, ObjectPtrEqual> bind_dict;
for (auto& kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
auto arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
}
bind_dict[arg] = Constant(kv.second);
}
Expr bound_expr = relay::Bind(func, bind_dict);
Function ret = Downcast<Function>(bound_expr);
ICHECK(ret.defined()) << "The returning type is expected to be a Relay Function."
<< "\n";
return ret;
}
relay::Function BindParamsByName(relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params);

inline void BindParamsInModule(IRModule mod,
const std::unordered_map<std::string, runtime::NDArray>& params) {
if (!params.empty()) {
BaseFunc base_func = mod->Lookup("main");
ICHECK(base_func->IsInstance<FunctionNode>());
auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params);
auto gvar = mod->GetGlobalVar("main");
mod->Add(gvar, f);
}
}
void BindParamsInModule(IRModule mod,
const std::unordered_map<std::string, runtime::NDArray>& params);

inline void BindParamsInModule(IRModule mod, Map<String, Constant> params) {
std::unordered_map<std::string, runtime::NDArray> params_tmp;
for (const auto& kv : params) {
params_tmp[kv.first] = kv.second->data;
}
BindParamsInModule(mod, params_tmp);
}
void BindParamsInModule(IRModule mod, Map<String, Constant> params);

/*!
* \brief Extract the shape from a Relay tensor type.
Expand Down

0 comments on commit 4a5e4aa

Please sign in to comment.