From c381ec306837488e145893947417c718686b2338 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Sat, 2 Nov 2019 21:01:59 +0000 Subject: [PATCH] [VM] add a few more API to vm --- include/tvm/runtime/vm.h | 8 +++- python/tvm/relay/backend/vm.py | 66 +++++++++++++++++++++++++++++--- src/runtime/vm/executable.cc | 35 +++++++++++++++++ src/runtime/vm/memory_manager.cc | 8 ++-- src/runtime/vm/vm.cc | 50 ++++++++++++++++-------- 5 files changed, 139 insertions(+), 28 deletions(-) diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index a196afdee2f3..fbfb1f3a0a54 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -534,7 +534,7 @@ class Executable : public ModuleNode { */ std::string GetBytecode() const; -/*! + /*! * \brief Print the detailed statistics of the given code, i.e. number of * globls and constants, etc. */ @@ -547,6 +547,10 @@ class Executable : public ModuleNode { */ runtime::Module GetLib() const { return lib; } + int GetFunctionArity(std::string func) const; + + std::string GetFunctionParameterName(std::string func, uint32_t index) const; + virtual ~Executable() {} const char* type_key() const final { @@ -773,7 +777,7 @@ class VirtualMachine : public runtime::ModuleNode { void InvokeGlobal(const VMFunction& func, const std::vector& args); /*! \brief The parameter name to data mapping. */ - std::unordered_map params_; + std::unordered_map> inputs_; /*! * \brief The constant pool for runtime. It caches the device dependent diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index e190e3f1eb41..dd0c921b1af7 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -57,10 +57,13 @@ class Executable(object): """Relay VM executable""" def __init__(self, mod): self.mod = mod + self._function_params = {} self._save = self.mod["save"] self._get_lib = self.mod["get_lib"] self._get_bytecode = self.mod["get_bytecode"] self._get_stats = self.mod["get_stats"] + self._get_function_arity = self.mod["get_function_arity"] + self._get_function_param_name = self.mod["get_function_param_name"] def save(self): """Save the Relay VM Executable. @@ -239,6 +242,19 @@ def module(self): """Return the runtime module contained in a virtual machine executable.""" return self.mod + def get_function_params(self, func_name): + if func_name in self._function_params: + return self._function_params[func_name] + arity = self._get_function_arity(func_name) + assert arity >= 0 + params = [] + for i in range(arity): + p = self._get_function_param_name(func_name, i) + assert p + params.append(p) + self._function_params[func_name] = params + return params + class VirtualMachine(object): """Relay VM runtime.""" @@ -248,8 +264,10 @@ def __init__(self, mod): "tvm.Module, but received {}".format(type(mod))) m = mod.module if isinstance(mod, Executable) else mod self.mod = _vm._VirtualMachine(m) + self._exec = mod self._init = self.mod["init"] self._invoke = self.mod["invoke"] + self._set_inputs = self.mod["set_inputs"] def init(self, ctx): """Initialize the context in the VM. @@ -262,7 +280,27 @@ def init(self, ctx): args = [ctx.device_type, ctx.device_id] self._init(*args) - def invoke(self, func_name, *args): + def set_inputs(self, func_name, *args, **kwargs): + new_args = [] + if kwargs: + func_params = self._exec.get_function_params(func_name) + new_args = [None] * len(func_params) + assert len(args) + len(kwargs) == len(func_params) + for k in kwargs: + idx = func_params.index(k) + new_args[idx] = kwargs[k] + idx = 0 + for i in range(len(new_args)): + if new_args[i] is None: + new_args[i] = args[idx] + idx += 1 + if idx == len(args): + break + args = new_args + cargs = convert(args) + self._set_inputs(func_name, *cargs) + + def invoke(self, func_name, *args, **kwargs): """Invoke a function. Parameters @@ -278,10 +316,26 @@ def invoke(self, func_name, *args): result : Object The output. """ - cargs = convert(args) - return self._invoke(func_name, *cargs) - - def run(self, *args): + # if kwargs: + # func_params = self._exec.get_function_params(func_name) + # new_args = [None] * len(func_params) + # assert len(args) + len(kwargs) == len(func_params) + # for k in kwargs: + # idx = func_params.index(k) + # new_args[idx] = kwargs[k] + # idx = 0 + # for i in range(len(new_args)): + # if new_args[i] is None: + # new_args[i] = args[idx] + # idx += 1 + # if idx == len(args): + # break + # args = new_args + # if args: + # cargs = convert(args) + return self._invoke(func_name) #, *cargs) + + def run(self, *args, **kwargs): """Run the main function. Parameters @@ -294,7 +348,7 @@ def run(self, *args): result : Object The output. """ - return self.invoke("main", *args) + return self.invoke("main", *args, **kwargs) def compile(mod, target=None, target_host=None, params=None): diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 4c4554cc6c86..581a6713d3e6 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -68,12 +68,47 @@ PackedFunc Executable::GetFunction(const std::string& name, return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Save(); }); + } else if (name == "get_function_arity") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string func_name = args[0]; + *rv = this->GetFunctionArity(func_name); + }); + } else if (name == "get_function_param_name") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::string func_name = args[0]; + int index = args[1]; + *rv = this->GetFunctionParameterName(func_name, index); + }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc(nullptr); } } +int Executable::GetFunctionArity(std::string func_name) const { + auto it = global_map.find(func_name); + if (it == global_map.end()) { + LOG(ERROR) << "Cannot find function " << func_name << " in executable"; + return -1; + } + const auto& func = functions[it->second]; + return func.params.size(); +} + +std::string Executable::GetFunctionParameterName(std::string func_name, uint32_t index) const { + auto it = global_map.find(func_name); + if (it == global_map.end()) { + LOG(ERROR) << "Cannot find function " << func_name << " in executable"; + return ""; + } + const auto& func = functions[it->second]; + if (index > func.params.size()) { + LOG(ERROR) << "Invalid parameter index"; + return ""; + } + return func.params[index]; +} + std::string Executable::GetBytecode() const { std::ostringstream oss; diff --git a/src/runtime/vm/memory_manager.cc b/src/runtime/vm/memory_manager.cc index 1c7e029de2ca..453944bc2cbc 100644 --- a/src/runtime/vm/memory_manager.cc +++ b/src/runtime/vm/memory_manager.cc @@ -80,10 +80,10 @@ NDArray StorageObj::AllocNDArray(size_t offset, std::vector shape, DLDa VerifyDataType(dtype); NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, this->buffer.ctx); container->deleter = StorageObj::Deleter; - size_t needed_size = GetDataSize(container->dl_tensor); + // size_t needed_size = GetDataSize(container->dl_tensor); // TODO(@jroesch): generalize later to non-overlapping allocations. - CHECK(needed_size == this->buffer.size) - << "size mistmatch required " << needed_size << " found " << this->buffer.size; + // CHECK(needed_size == this->buffer.size) + // << "size mistmatch required " << needed_size << " found " << this->buffer.size; this->IncRef(); container->manager_ctx = reinterpret_cast(this); container->dl_tensor.data = this->buffer.data; @@ -100,7 +100,7 @@ Allocator* MemoryManager::GetAllocator(TVMContext ctx) { if (allocators_.find(ctx) == allocators_.end()) { DLOG(INFO) << "New allocator for " << DeviceName(ctx.device_type) << "(" << ctx.device_id << ")"; - std::unique_ptr alloc(new NaiveAllocator(ctx)); + std::unique_ptr alloc(new PooledAllocator(ctx)); allocators_.emplace(ctx, std::move(alloc)); } return allocators_.at(ctx).get(); diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 05935b7833a5..a6ead3f78aa4 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -639,25 +639,24 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, const auto& param_names = vm_func.params; auto ctx = this->GetParamsContext(); - // Prepare the func args - std::vector func_args(param_names.size()); - std::vector empty_slots; + std::vector func_args; - for (size_t i = 0; i < param_names.size(); ++i) { - const auto& pit = params_.find(param_names[i]); - if (pit != params_.end()) { - func_args[i] = pit->second; - } else { - empty_slots.push_back(i); + if (args.size() == 1) { + if (param_names.size() > 0) { + auto argit = inputs_.find(func_name); + CHECK(argit != inputs_.end()) << "No arguments are set for " << func_name; + func_args = argit->second; + } + } else { + CHECK_EQ(args.size() - 1, param_names.size()) << + "The number of provided parameters doesn't match the number of arguments"; + // Prepare the func args + std::vector func_args(param_names.size()); + for (int i = 1; i < args.size(); ++i) { + ObjectRef obj = CopyTo(args[i], ctx); + func_args[i - 1] = obj; } } - CHECK_EQ(empty_slots.size(), args.size() - 1) - << "The number of provided parameters doesn't match the number of arguments"; - for (int i = 1; i < args.size(); ++i) { - ObjectRef obj = CopyTo(args[i], ctx); - func_args[empty_slots[i - 1]] = obj; - } - *rv = this->Invoke(vm_func, func_args); }); } else if (name == "init") { @@ -673,6 +672,25 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, } this->Init(contexts); }); + } else if (name == "set_inputs") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK(exec) << "The executable is not created yet."; + std::string func_name = args[0]; + auto gvit = exec->global_map.find(func_name); + CHECK(gvit != exec->global_map.end()) << "Cannot find function " << func_name; + auto func_index = gvit->second; + const auto& vm_func = exec->functions[func_index]; + const auto& param_names = vm_func.params; + TVMContext ctx = ctxs[0]; + CHECK_EQ(args.size() - 1, param_names.size()) << + "The number of provided parameters doesn't match the number of arguments"; + std::vector func_args(param_names.size()); + for (int i = 1; i < args.size(); ++i) { + ObjectRef obj = CopyTo(args[i], ctx); + func_args[i - 1] = obj; + } + inputs_.emplace(func_name, func_args); + }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});