Skip to content

Commit

Permalink
Dedup BindParamByName function in VM compiler (apache#4793)
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi authored and alexwong committed Feb 26, 2020
1 parent 4438f1e commit 3662b0a
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 83 deletions.
39 changes: 1 addition & 38 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,43 +41,6 @@ using tir::LoweredFunc;
using TargetsMap = Map<tvm::Integer, tvm::Target>;
using namespace tvm::relay::transform;

/*!
* \brief Bind params to function by using name
* \param func Relay function
* \param params params dict
* \return relay::Function
*/
relay::Function BindParamsByName(relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
for (auto arg : func->params) {
const auto& name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(arg);
} else {
name_dict[name] = arg;
}
}

std::unordered_map<relay::Var, Expr, ObjectHash, ObjectEqual> bind_dict;
for (auto& kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
auto arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
}
bind_dict[arg] = ConstantNode::make(kv.second);
}
Expr bound_expr = relay::Bind(func, bind_dict);
Function ret = Downcast<Function>(bound_expr);
CHECK(ret.defined()) << "The returning type is expected to be a Relay Function."
<< "\n";
return ret;
}

/*!
* \brief Output of building module
*
Expand Down Expand Up @@ -527,7 +490,7 @@ TVM_REGISTER_GLOBAL("relay.build_module.BindParamsByName")
for (const auto& kv : params) {
params_[kv.first] = kv.second->data;
}
*rv = BindParamsByName(args[0], params_);
*rv = relay::backend::BindParamsByName(args[0], params_);
});

} // namespace backend
Expand Down
41 changes: 41 additions & 0 deletions src/relay/backend/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@
#include <dmlc/json.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/relay/transform.h>
#include <tvm/driver/driver_api.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/te/operation.h>

#include <typeinfo>
#include <string>
#include <unordered_map>
#include <unordered_set>

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -81,6 +84,44 @@ inline std::string DType2String(const tvm::DataType dtype) {
return os.str();
}

/*!
* \brief Bind params to function by using name
* \param func Relay function
* \param params params dict
* \return relay::Function
*/
inline relay::Function
BindParamsByName(relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
for (auto arg : func->params) {
const auto& name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(arg);
} else {
name_dict[name] = arg;
}
}

std::unordered_map<relay::Var, Expr, ObjectHash, ObjectEqual> bind_dict;
for (auto& kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
auto arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
}
bind_dict[arg] = ConstantNode::make(kv.second);
}
Expr bound_expr = relay::Bind(func, bind_dict);
Function ret = Downcast<Function>(bound_expr);
CHECK(ret.defined()) << "The returning type is expected to be a Relay Function."
<< "\n";
return ret;
}

} // namespace backend
} // namespace relay
} // namespace tvm
Expand Down
37 changes: 2 additions & 35 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "../utils.h"
#include "../../backend/compile_engine.h"
#include "../../pass/pass_util.h"
#include "../../op/op_common.h"
Expand Down Expand Up @@ -783,38 +782,6 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) {
params_[name] = data_in;
}

relay::Function VMCompiler::BindParamsByName(
relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
for (auto arg : func->params) {
const auto &name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(arg);
} else {
name_dict[name] = arg;
}
}
std::unordered_map<relay::Var, Expr, ObjectHash, ObjectEqual> bind_dict;
for (auto &kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
auto arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
}
bind_dict[arg] = ConstantNode::make(kv.second);
}
Expr bound_expr = relay::Bind(func, bind_dict);
Function ret = Downcast<Function>(bound_expr);
CHECK(ret.defined())
<< "The returning type is expected to be a Relay Function."
<< "\n";
return ret;
}

void VMCompiler::Lower(IRModule mod,
const TargetsMap& targets,
const tvm::Target& target_host) {
Expand All @@ -824,7 +791,7 @@ void VMCompiler::Lower(IRModule mod,
BaseFunc base_func = mod->Lookup("main");
CHECK(base_func->IsInstance<FunctionNode>())
<< "VM compiler expects to compile relay::Function";
auto f = BindParamsByName(Downcast<Function>(base_func), params_);
auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params_);
auto gvar = mod->GetGlobalVar("main");
mod->Add(gvar, f);
}
Expand Down
10 changes: 0 additions & 10 deletions src/relay/backend/vm/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,6 @@ class VMCompiler : public runtime::ModuleNode {
void Codegen();

protected:
/*!
* \brief Bind params to function by using name
* \param func Relay function
* \param params params dict
* \return relay::Function
*/
relay::Function BindParamsByName(
relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params);

IRModule OptimizeModule(const IRModule& mod, const TargetsMap& targets);

void PopulateGlobalMap();
Expand Down

0 comments on commit 3662b0a

Please sign in to comment.