diff --git a/python/tvm/relay/backend/profiler_vm.py b/python/tvm/relay/backend/profiler_vm.py index 3adbecaa2531..544722ceecb5 100644 --- a/python/tvm/relay/backend/profiler_vm.py +++ b/python/tvm/relay/backend/profiler_vm.py @@ -21,6 +21,7 @@ Provides extra APIs for profiling vm execution. """ import tvm +from tvm import autotvm from . import vm, _vm def _update_target(target): @@ -49,8 +50,9 @@ def __init__(self): self.mod = _vm._VMCompilerProfiler() self._compile = self.mod["compile"] self._get_vm = self.mod["get_vm"] + self._set_params_func = self.mod["set_params"] - def compile(self, mod, target=None, target_host=None): + def compile(self, mod, target=None, target_host=None, params=None): """ Parameters ---------- @@ -71,13 +73,38 @@ def compile(self, mod, target=None, target_host=None): By default, llvm is used if it is enabled, otherwise a stackvm intepreter is used. + params : dict of str to NDArray + Input parameters to the graph that do not change + during inference time. Used for constant folding. + Returns ------- vm : VirtualMachineProfiler The profile VM runtime. 
""" target = _update_target(target) - self._compile(mod, target, target_host) + if params: + self.set_params(params) + + target_host = None if target_host == "" else target_host + if not target_host: + for device_type, tgt in target.items(): + if device_type.value == tvm.nd.cpu(0).device_type: + target_host = tgt + break + if not target_host: + target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm" + target_host = tvm.target.create(target_host) + + # If current dispatch context is fallback context (the default root context), + # then load pre-tuned parameters from TopHub + if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext): + tophub_context = autotvm.tophub.context(list(target.values())) + else: + tophub_context = autotvm.util.EmptyContext() + + with tophub_context: + self._compile(mod, target, target_host) return VirtualMachineProfiler(self._get_vm()) class VirtualMachineProfiler(vm.VirtualMachine): diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index a6cb91c2dfde..563423ae3471 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -25,6 +25,7 @@ import tvm from tvm import autotvm from tvm._ffi.runtime_ctypes import TVMByteArray +from tvm.relay import expr as _expr from . import _vm from . 
import vmobj as _obj from .interpreter import Executor @@ -150,8 +151,17 @@ def __init__(self): self.mod = _vm._VMCompiler() self._compile = self.mod["compile"] self._get_vm = self.mod["get_vm"] + self._set_params_func = self.mod["set_params"] - def compile(self, mod, target=None, target_host=None): + def set_params(self, params): + inputs = {} + for name, param in params.items(): + if isinstance(param, np.ndarray): + param = _nd.array(param) + inputs[name] = _expr.const(param) + self._set_params_func(inputs) + + def compile(self, mod, target=None, target_host=None, params=None): """ Parameters ---------- @@ -172,6 +182,10 @@ def compile(self, mod, target=None, target_host=None): By default, llvm is used if it is enabled, otherwise a stackvm intepreter is used. + params : dict of str to NDArray + Input parameters to the graph that do not change + during inference time. Used for constant folding. + Returns ------- vm : VirtualMachine @@ -188,6 +202,9 @@ def compile(self, mod, target=None, target_host=None): target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm" target_host = tvm.target.create(target_host) + if params: + self.set_params(params) + # If current dispatch context is fallback context (the default root context), # then load pre-tuned parameters from TopHub if isinstance(autotvm.DispatchContext.current, autotvm.FallbackContext): diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 49079fbc107e..c88843206029 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -780,23 +780,73 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, if (name == "compile") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.num_args, 3); - this->Compile(args[0], args[1], args[2]); + Module mod = args[0]; + this->Compile(mod, args[1], args[2]); }); } else if (name == "get_vm") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = 
runtime::Module(vm_); }); + } else if (name == "set_params") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + Map params = args[0]; + for (const auto& kv : params) { + this->SetParam(kv.first, kv.second->data); + } + }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); } } -void VMCompiler::Compile(const Module& mod_ref, +void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) { + params_[name] = data_in; +} + +relay::Function VMCompiler::BindParamsByName( + relay::Function func, + const std::unordered_map& params) { + std::unordered_map name_dict; + std::unordered_set repeat_var; + for (auto arg : func->params) { + const auto &name = arg->name_hint(); + if (name_dict.count(name)) { + repeat_var.insert(arg); + } else { + name_dict[name] = arg; + } + } + std::unordered_map bind_dict; + for (auto &kv : params) { + if (name_dict.count(kv.first) == 0) { + continue; + } + auto arg = name_dict.at(kv.first); + if (repeat_var.count(arg)) { + LOG(FATAL) << "Multiple args in the function have name " << kv.first; + } + bind_dict[arg] = ConstantNode::make(kv.second); + } + Expr bound_expr = relay::Bind(func, bind_dict); + Function ret = Downcast(bound_expr); + CHECK(ret.defined()) + << "The returning type is expected to be a Relay Function." + << "\n"; + return ret; +} + + +void VMCompiler::Compile(Module mod, const TargetsMap& targets, const tvm::Target& target_host) { CHECK_EQ(targets.size(), 1) << "Currently VM compiler doesn't support heterogeneous compilation"; + if (params_.size()) { + auto f = BindParamsByName(mod->Lookup("main"), params_); + auto gvar = mod->GetGlobalVar("main"); + mod->Add(gvar, f); + } InitVM(); targets_ = targets; @@ -804,7 +854,7 @@ void VMCompiler::Compile(const Module& mod_ref, // Run some optimizations first, this code should // be moved to pass manager. 
- context_.module = OptimizeModule(mod_ref, targets_); + context_.module = OptimizeModule(mod, targets_); // Populate the global map. // diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 14a5035b20dc..1738112cdfdf 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -100,11 +100,29 @@ class VMCompiler : public runtime::ModuleNode { vm_ = std::make_shared(); } - void Compile(const Module& mod_ref, + /*! + * \brief Set the parameters + * + * \param name name of parameter + * \param data_in NDArray holding the parameter value + */ + void SetParam(const std::string& name, runtime::NDArray data_in); + + void Compile(Module mod, const TargetsMap& targets, const tvm::Target& target_host); protected: + /*! + * \brief Bind params to function by using name + * \param func Relay function + * \param params params dict + * \return relay::Function + */ + relay::Function BindParamsByName( + relay::Function func, + const std::unordered_map& params); + Module OptimizeModule(const Module& mod, const TargetsMap& targets); void PopulateGlobalMap(); @@ -120,6 +138,8 @@ class VMCompiler : public runtime::ModuleNode { VMCompilerContext context_; /*! \brief Compiled virtual machine. */ std::shared_ptr vm_; + /*! 
\brief parameters */ + std::unordered_map params_; }; } // namespace vm diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index 1d3ac836925a..5f59f6ed7f48 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -98,6 +98,11 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, Index output_size, const std::vector& args) { auto ctx = VirtualMachine::GetParamsContext(); + // warmup: run the op once untimed and sync before the measured invocation + VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, + args); + TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); + auto op_begin = std::chrono::high_resolution_clock::now(); VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, args); diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 02ea3a42b156..d48f7373a5d8 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -589,18 +589,9 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, }); CHECK(it != functions.end()) << "Cannot find function " << func_name << "\n"; - CHECK_EQ(func_args.size() + params_.size(), it->params.size()) + CHECK_EQ(func_args.size(), it->params.size()) << "The number of provided parameters doesn't match the number of arguments" << "\n"; - if (!params_.empty()) { - for (const auto& p : it->params) { - const auto& pit = params_.find(p); - if (pit != params_.end()) { - func_args.push_back(pit->second); - } - } - CHECK_EQ(func_args.size(), it->params.size()); - } *rv = this->Invoke(func_name, func_args); }); } else if (name == "init") {