Skip to content

Commit

Permalink
recursive bind
Browse files Browse the repository at this point in the history
  • Loading branch information
wweic committed Nov 5, 2019
1 parent 91d4c69 commit dd0f1d5
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 32 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1438,6 +1438,7 @@ def _from_mxnet_impl(mod, symbol, shape_dict, dtype_info):
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
if isinstance(outputs, _expr.Function):
return outputs
print("Free vars are {}".format(analysis.free_vars(outputs)))
func = _expr.Function(analysis.free_vars(outputs), outputs)
return func

Expand Down
95 changes: 65 additions & 30 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -853,36 +853,67 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) {
params_[name] = data_in;
}

relay::Function VMCompiler::BindParamsByName(
relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, NodeHash, NodeEqual> repeat_var;
for (auto arg : func->params) {
const auto &name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(arg);
} else {
name_dict[name] = arg;
class BindParams : public ExprMutator {
public:
explicit BindParams(const std::unordered_map<std::string, runtime::NDArray>& params)
: params_(params) {}

relay::Function BindParamsByName(
relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::unordered_map<std::string, relay::Var> name_dict;
std::unordered_set<relay::Var, NodeHash, NodeEqual> repeat_var;
for (auto arg : func->params) {
const auto &name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(arg);
} else {
name_dict[name] = arg;
}
}
}
std::unordered_map<relay::Var, Expr, NodeHash, NodeEqual> bind_dict;
for (auto &kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
std::unordered_map<relay::Var, Expr, NodeHash, NodeEqual> bind_dict;
for (auto &kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
auto arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
}
bind_dict[arg] = ConstantNode::make(kv.second);
}
auto arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
Expr bound_expr = relay::Bind(func, bind_dict);
Function ret = Downcast<Function>(bound_expr);
CHECK(ret.defined())
<< "The returning type is expected to be a Relay Function."
<< "\n";
return ret;
}


Expr VisitExpr_(const FunctionNode* op) {
auto f = GetRef<Function>(op);
auto it = visited.find(f);
if (it != visited.end()) {
return it->second;
}
bind_dict[arg] = ConstantNode::make(kv.second);

auto func = ExprMutator::VisitExpr(f);
auto res = BindParamsByName(Downcast<Function>(func), params_);
visited[f] = res;
return res;
}
Expr bound_expr = relay::Bind(func, bind_dict);
Function ret = Downcast<Function>(bound_expr);
CHECK(ret.defined())
<< "The returning type is expected to be a Relay Function."
<< "\n";
return ret;

private:
std::unordered_map<std::string, runtime::NDArray> params_;
std::unordered_map<Expr, Expr, NodeHash> visited;
};

relay::Function VMCompiler::BindParamsByName(
relay::Function func,
const std::unordered_map<std::string, runtime::NDArray>& params) {
std::cout << "BindParams " << func << "\n";
return Downcast<Function>(BindParams(params).Mutate(func));
}

void VMCompiler::Compile(Module mod,
Expand All @@ -891,9 +922,14 @@ void VMCompiler::Compile(Module mod,
CHECK_EQ(targets.size(), 1)
<< "Currently VM compiler doesn't support heterogeneous compilation";
if (params_.size()) {
auto f = BindParamsByName(mod->Lookup("main"), params_);
auto gvar = mod->GetGlobalVar("main");
mod->Add(gvar, f);
auto functions = mod->functions;
for (auto p : functions) {
std::cout << "Binding " << p.first << "\n";
auto f = BindParamsByName(p.second, params_);
mod->Add(p.first, f);
}
} else {
std::cout << "Params empty\n";
}

InitVM();
Expand All @@ -902,7 +938,6 @@ void VMCompiler::Compile(Module mod,

// Run the optimizations necessary to target the VM.
context_.module = OptimizeModule(mod, targets_);

// Populate the global map.
//
// This maps global variables to a global index
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,8 +668,8 @@ def helper_rnn(cell_type):
inputs = [relay.var('data')]
for i in range(num_states):
inputs.append(relay.var('state%s' % i))
# for name, _ in params:
# inputs.append(relay.var(name))
for name, _ in params:
inputs.append(relay.var(name))
mod['main'] = relay.Function(inputs, relay.Call(relay_net, inputs))

l = 5
Expand Down

0 comments on commit dd0f1d5

Please sign in to comment.