Skip to content

Commit

Permalink
Dedup BindParamByName function in VM compiler (#4793)
Browse files Browse the repository at this point in the history
masahi authored Jan 30, 2020

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 24126b4 commit 10f85d0
Showing 4 changed files with 44 additions and 83 deletions.
39 changes: 1 addition & 38 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
@@ -41,43 +41,6 @@ using tir::LoweredFunc;
using TargetsMap = Map<tvm::Integer, tvm::Target>;
using namespace tvm::relay::transform;

/*!
* \brief Bind params to function by using name
* \param func Relay function
* \param params params dict
* \return relay::Function
*/
relay::Function BindParamsByName(relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
for (auto arg : func->params) {
const auto& name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(arg);
} else {
name_dict[name] = arg;
}
}

std::unordered_map<relay::Var, Expr, ObjectHash, ObjectEqual> bind_dict;
for (auto& kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
auto arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
}
bind_dict[arg] = ConstantNode::make(kv.second);
}
Expr bound_expr = relay::Bind(func, bind_dict);
Function ret = Downcast<Function>(bound_expr);
CHECK(ret.defined()) << "The returning type is expected to be a Relay Function."
<< "\n";
return ret;
}

/*!
* \brief Output of building module
*
@@ -527,7 +490,7 @@ TVM_REGISTER_GLOBAL("relay.build_module.BindParamsByName")
for (const auto& kv : params) {
params_[kv.first] = kv.second->data;
}
*rv = BindParamsByName(args[0], params_);
*rv = relay::backend::BindParamsByName(args[0], params_);
});

} // namespace backend
41 changes: 41 additions & 0 deletions src/relay/backend/utils.h
Original file line number Diff line number Diff line change
@@ -27,13 +27,16 @@
#include <dmlc/json.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/relay/transform.h>
#include <tvm/driver/driver_api.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/te/operation.h>

#include <typeinfo>
#include <string>
#include <unordered_map>
#include <unordered_set>

namespace tvm {
namespace relay {
@@ -81,6 +84,44 @@ inline std::string DType2String(const tvm::DataType dtype) {
return os.str();
}

/*!
* \brief Bind params to function by using name
* \param func Relay function
* \param params params dict
* \return relay::Function
*/
inline relay::Function
BindParamsByName(relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
for (auto arg : func->params) {
const auto& name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(arg);
} else {
name_dict[name] = arg;
}
}

std::unordered_map<relay::Var, Expr, ObjectHash, ObjectEqual> bind_dict;
for (auto& kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
auto arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
}
bind_dict[arg] = ConstantNode::make(kv.second);
}
Expr bound_expr = relay::Bind(func, bind_dict);
Function ret = Downcast<Function>(bound_expr);
CHECK(ret.defined()) << "The returning type is expected to be a Relay Function."
<< "\n";
return ret;
}

} // namespace backend
} // namespace relay
} // namespace tvm
37 changes: 2 additions & 35 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
@@ -37,9 +37,8 @@
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "../utils.h"
#include "../../backend/compile_engine.h"
#include "../../pass/pass_util.h"
#include "../../op/op_common.h"
@@ -783,38 +782,6 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) {
params_[name] = data_in;
}

relay::Function VMCompiler::BindParamsByName(
relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
for (auto arg : func->params) {
const auto &name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(arg);
} else {
name_dict[name] = arg;
}
}
std::unordered_map<relay::Var, Expr, ObjectHash, ObjectEqual> bind_dict;
for (auto &kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
auto arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
}
bind_dict[arg] = ConstantNode::make(kv.second);
}
Expr bound_expr = relay::Bind(func, bind_dict);
Function ret = Downcast<Function>(bound_expr);
CHECK(ret.defined())
<< "The returning type is expected to be a Relay Function."
<< "\n";
return ret;
}

void VMCompiler::Lower(IRModule mod,
const TargetsMap& targets,
const tvm::Target& target_host) {
@@ -824,7 +791,7 @@ void VMCompiler::Lower(IRModule mod,
BaseFunc base_func = mod->Lookup("main");
CHECK(base_func->IsInstance<FunctionNode>())
<< "VM compiler expects to compile relay::Function";
auto f = BindParamsByName(Downcast<Function>(base_func), params_);
auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params_);
auto gvar = mod->GetGlobalVar("main");
mod->Add(gvar, f);
}
10 changes: 0 additions & 10 deletions src/relay/backend/vm/compiler.h
Original file line number Diff line number Diff line change
@@ -115,16 +115,6 @@ class VMCompiler : public runtime::ModuleNode {
void Codegen();

protected:
/*!
* \brief Bind params to function by using name
* \param func Relay function
* \param params params dict
* \return relay::Function
*/
relay::Function BindParamsByName(
relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params);

IRModule OptimizeModule(const IRModule& mod, const TargetsMap& targets);

void PopulateGlobalMap();

0 comments on commit 10f85d0

Please sign in to comment.