diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 3959b877c344..8113064b40ff 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -1438,6 +1438,7 @@ def _from_mxnet_impl(mod, symbol, shape_dict, dtype_info): outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) if isinstance(outputs, _expr.Function): return outputs + print("Free vars are {}".format(analysis.free_vars(outputs))) func = _expr.Function(analysis.free_vars(outputs), outputs) return func diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index dee97e42c061..277df592fcee 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -853,36 +853,67 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) { params_[name] = data_in; } -relay::Function VMCompiler::BindParamsByName( - relay::Function func, - const std::unordered_map& params) { - std::unordered_map name_dict; - std::unordered_set repeat_var; - for (auto arg : func->params) { - const auto &name = arg->name_hint(); - if (name_dict.count(name)) { - repeat_var.insert(arg); - } else { - name_dict[name] = arg; +class BindParams : public ExprMutator { + public: + explicit BindParams(const std::unordered_map& params) + : params_(params) {} + + relay::Function BindParamsByName( + relay::Function func, + const std::unordered_map& params) { + std::unordered_map name_dict; + std::unordered_set repeat_var; + for (auto arg : func->params) { + const auto &name = arg->name_hint(); + if (name_dict.count(name)) { + repeat_var.insert(arg); + } else { + name_dict[name] = arg; + } } - } - std::unordered_map bind_dict; - for (auto &kv : params) { - if (name_dict.count(kv.first) == 0) { - continue; + std::unordered_map bind_dict; + for (auto &kv : params) { + if (name_dict.count(kv.first) == 0) { + continue; + } + auto arg = name_dict.at(kv.first); + if (repeat_var.count(arg)) { + LOG(FATAL) << "Multiple args in the function have name " << kv.first; + } + bind_dict[arg] = ConstantNode::make(kv.second); } - auto arg = name_dict.at(kv.first); - if (repeat_var.count(arg)) { - LOG(FATAL) << "Multiple args in the function have name " << kv.first; + Expr bound_expr = relay::Bind(func, bind_dict); + Function ret = Downcast(bound_expr); + CHECK(ret.defined()) + << "The returning type is expected to be a Relay Function." + << "\n"; + return ret; + } + + + Expr VisitExpr_(const FunctionNode* op) { + auto f = GetRef(op); + auto it = visited.find(f); + if (it != visited.end()) { + return it->second; } - bind_dict[arg] = ConstantNode::make(kv.second); + + auto func = ExprMutator::VisitExpr(f); + auto res = BindParamsByName(Downcast(func), params_); + visited[f] = res; + return res; } - Expr bound_expr = relay::Bind(func, bind_dict); - Function ret = Downcast(bound_expr); - CHECK(ret.defined()) - << "The returning type is expected to be a Relay Function." - << "\n"; - return ret; + + private: + std::unordered_map params_; + std::unordered_map visited; +}; + +relay::Function VMCompiler::BindParamsByName( + relay::Function func, + const std::unordered_map& params) { + std::cout << "BindParams " << func << "\n"; + return Downcast(BindParams(params).Mutate(func)); } void VMCompiler::Compile(Module mod, @@ -891,9 +922,14 @@ void VMCompiler::Compile(Module mod, CHECK_EQ(targets.size(), 1) << "Currently VM compiler doesn't support heterogeneous compilation"; if (params_.size()) { - auto f = BindParamsByName(mod->Lookup("main"), params_); - auto gvar = mod->GetGlobalVar("main"); - mod->Add(gvar, f); + auto functions = mod->functions; + for (auto p : functions) { + std::cout << "Binding " << p.first << "\n"; + auto f = BindParamsByName(p.second, params_); + mod->Add(p.first, f); + } + } else { + std::cout << "Params empty\n"; } InitVM(); @@ -902,7 +938,6 @@ void VMCompiler::Compile(Module mod, // Run the optimizations necessary to target the VM. context_.module = OptimizeModule(mod, targets_); - // Populate the global map. // // This maps global variables to a global index diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 5cdaa11621c0..1eab448d8e91 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -668,8 +668,8 @@ def helper_rnn(cell_type): inputs = [relay.var('data')] for i in range(num_states): inputs.append(relay.var('state%s' % i)) - # for name, _ in params: - # inputs.append(relay.var(name)) + for name, _ in params: + inputs.append(relay.var(name)) mod['main'] = relay.Function(inputs, relay.Call(relay_net, inputs)) l = 5