From 3fbf640965bd73217fe3e57ecc04ef8597a1d8aa Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Tue, 9 Jun 2020 19:43:46 +0800 Subject: [PATCH 01/29] Support Module based interface runtime --- python/tvm/contrib/graph_runtime.py | 9 + python/tvm/relay/build_module.py | 8 +- python/tvm/runtime/graph_runtime_factory.py | 153 +++++++ python/tvm/runtime/module.py | 23 ++ .../graph/debug/graph_runtime_debug.cc | 4 +- src/runtime/graph/graph_runtime.cc | 10 +- src/runtime/graph/graph_runtime.h | 61 ++- src/runtime/graph/graph_runtime_factory.cc | 222 +++++++++++ src/runtime/graph/graph_runtime_factory.h | 136 +++++++ .../unittest/test_module_runtime_interface.py | 376 ++++++++++++++++++ 10 files changed, 991 insertions(+), 11 deletions(-) create mode 100644 python/tvm/runtime/graph_runtime_factory.py create mode 100644 src/runtime/graph/graph_runtime_factory.cc create mode 100644 src/runtime/graph/graph_runtime_factory.h create mode 100644 tests/python/unittest/test_module_runtime_interface.py diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 9b714a84b541..3bf09b79306b 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -63,6 +63,15 @@ def create(graph_json_str, libmod, ctx): return GraphModule(fcreate(graph_json_str, libmod, *device_type_id)) +# TODO (FrozenGene): rename +def create4unified(libmod, ctx): + ctx, num_rpc_ctx, device_type_id = get_device_ctx(libmod, ctx) + if num_rpc_ctx == len(ctx): + fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime_factory.runtime_create") + else: + fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_factory.runtime_create") + + return GraphModule(fcreate(libmod, *device_type_id)) def get_device_ctx(libmod, ctx): """Parse and validate all the device context(s). diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index a28ab853dd9d..5d7996a38a9d 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -26,6 +26,7 @@ from tvm.tir import expr as tvm_expr from .. import nd as _nd, target as _target, autotvm from ..contrib import graph_runtime as _graph_rt +from ..runtime import graph_runtime_factory as _graph_runtime_factory from . import _build_module from . import ty as _ty from . import expr as _expr @@ -181,7 +182,7 @@ def get_params(self): return ret -def build(mod, target=None, target_host=None, params=None): +def build(mod, target=None, target_host=None, params=None, mod_name='default', export_graph_module=False): """Helper function that builds a Relay function to run on TVM graph runtime. @@ -249,7 +250,10 @@ def build(mod, target=None, target_host=None, params=None): with tophub_context: bld_mod = BuildModule() graph_json, mod, params = bld_mod.build(mod, target, target_host, params) - return graph_json, mod, params + if export_graph_module: + mod = _graph_runtime_factory.create("graph", graph_json, mod, params, mod_name) + return mod + return graph_json, mod, params def optimize(mod, target=None, params=None): diff --git a/python/tvm/runtime/graph_runtime_factory.py b/python/tvm/runtime/graph_runtime_factory.py new file mode 100644 index 000000000000..d14afa919cad --- /dev/null +++ b/python/tvm/runtime/graph_runtime_factory.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Graph runtime factory.""" +import numpy as np +import warnings +from tvm._ffi.base import string_types +from tvm._ffi.registry import get_global_func +from tvm._ffi.runtime_ctypes import TVMContext +from tvm.contrib.graph_runtime import get_device_ctx +from .packed_func import _set_class_module +from tvm.rpc import base as rpc_base +from .module import Module +from . import ndarray + + +def create(graph_runtime_kind, graph_json_str, libmod, params, module_name='default'): + """Create a runtime executor module given a graph and module. + Parameters + ---------- + graph_runtime_kind: str + The kind of graph runtime. Like graphruntime, vm and so on. + graph_json_str : str or graph class + The graph to be deployed in json format output by nnvm graph. + The graph can only contain one operator(tvm_op) that + points to the name of PackedFunc in the libmod. + libmod : tvm.Module + The module of the corresponding function + Returns + ------- + graph_module : GraphModule + Runtime graph module that can be used to execute the graph. + """ + if not isinstance(graph_json_str, string_types): + try: + graph_json_str = graph_json_str._tvm_graph_json() + except AttributeError: + raise ValueError("Type %s is not supported" % type(graph_json_str)) + fcreate = get_global_func("tvm.graph_runtime_factory.create") + args = [] + for k, v in params.items(): + args.append(k) + args.append(ndarray.array(v)) + return GraphRuntimeFactoryModule(fcreate(graph_runtime_kind, graph_json_str, libmod, module_name, *args)) + + +class GraphRuntimeFactoryModule(Module): + """Graph runtime factory module. + + This is a module of graph runtime factory + + Parameters + ---------- + module : Module + The interal tvm module that holds the actual graph functions. + + Attributes + ---------- + module : Module + The interal tvm module that holds the actual graph functions. + """ + + def __init__(self, module): + self.module = module + self._select_module = module.get_function("select_module") + self._import_module = module.get_function("import_module") + self.selected_module = None + self.graph_json = module.get_function("get_json")() + self.lib = module.get_function("get_lib")() + self.params = {} + for k, v in module.get_function("get_params")().items(): + self.params[k] = v + self.iter_cnt = 0 + super(GraphRuntimeFactoryModule, self).__init__(self.module.handle) + + def __del__(self): + pass + + def runtime_create(self, ctx): + """Create the runtime using ctx + + Parameters + ---------- + ctx : TVMContext or list of TVMContext + """ + ctx, num_rpc_ctx, device_type_id = get_device_ctx(self.selected_module, ctx) + if num_rpc_ctx == len(ctx): + fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime_factory.runtime_create") + else: + fcreate = get_global_func("tvm.graph_runtime_factory.runtime_create") + return fcreate(self.selected_module, *device_type_id) + + def import_module(self, mod, mod_name): + """Create the runtime using ctx + + Parameters + ---------- + mod : GraphRuntimeFactoryModule + The graph runtime factory module we want to import + mod_name: str + The module name + """ + return self._import_module(mod, mod_name) + + def __getitem__(self, key='default'): + """Get specific module + + Parameters + ---------- + key : str + The key of module. + """ + self.selected_module = self._select_module(key) + self.selected_module._entry = self.runtime_create + return self.selected_module + + def __iter__(self): + warnings.warn( + "legacy graph runtime behaviour of producing json / lib / params will be removed in the next release ", + DeprecationWarning, 2) + return self + + + def __next__(self): + if self.iter_cnt > 2: + raise StopIteration + + objs = [self.graph_json, self.lib, self.params] + obj = objs[self.iter_cnt] + self.iter_cnt += 1 + return obj + + def get_json(self): + return self.graph_json + + def get_lib(self): + return self.lib + + def get_params(self): + return self.params \ No newline at end of file diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 3cdb28f8c496..c1def255b019 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -41,6 +41,10 @@ def __init__(self, handle): self.handle = handle self._entry = None self.entry_name = "__tvm_main__" + # TODO:(FrozenGene): support rpc + if self.type_key == 'GraphRuntimeFactory': + #from tvm.runtime.graph_runtime_factory import GraphRuntimeFactoryModule + self._entry = self.runtime_create def __del__(self): check_call(_LIB.TVMModFree(self.handle)) @@ -99,6 +103,8 @@ def import_module(self, module): check_call(_LIB.TVMModImport(self.handle, module.handle)) def __getitem__(self, name): + if self.type_key == 'GraphRuntimeFactory': + return self.get_function("select_module")(name) if not isinstance(name, string_types): raise ValueError("Can only take string as function name") return self.get_function(name) @@ -112,6 +118,23 @@ def __call__(self, *args): def __repr__(self): return "Module(%s, %x)" % (self.type_key, self.handle.value) + # TODO (FrozenGene): remove + def runtime_create(self, ctx): + """Create the runtime using ctx + + Parameters + ---------- + ctx : TVMContext or list of TVMContext + """ + from tvm.contrib.graph_runtime import get_device_ctx + from tvm._ffi.registry import get_global_func + ctx, num_rpc_ctx, device_type_id = get_device_ctx(self, ctx) + if num_rpc_ctx == len(ctx): + fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime_factory.runtime_create") + else: + fcreate = get_global_func("tvm.graph_runtime_factory.runtime_create") + return fcreate(self, *device_type_id) + @property def type_key(self): """Get type key of the module.""" diff --git a/src/runtime/graph/debug/graph_runtime_debug.cc b/src/runtime/graph/debug/graph_runtime_debug.cc index 5439be9109f9..70c027d6d0fe 100644 --- a/src/runtime/graph/debug/graph_runtime_debug.cc +++ b/src/runtime/graph/debug/graph_runtime_debug.cc @@ -24,12 +24,10 @@ #include #include #include - +#include #include #include -#include "../graph_runtime.h" - namespace tvm { namespace runtime { diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index e984861769a0..4fa72986d2de 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -20,7 +20,7 @@ /*! * \file graph_runtime.cc */ -#include "graph_runtime.h" +//#include "graph_runtime.h" #include #include @@ -28,7 +28,6 @@ #include #include #include - #include #include #include @@ -37,7 +36,7 @@ #include #include #include - +#include "./graph_runtime.h" namespace tvm { namespace runtime { namespace details { @@ -66,7 +65,10 @@ void GraphRuntime::Run() { * executed on. */ void GraphRuntime::Init(const std::string& graph_json, tvm::runtime::Module module, - const std::vector& ctxs) { + const std::vector& ctxs, + const std::unordered_map& params) { + graph_json_ = graph_json; + params_ = params; std::istringstream is(graph_json); dmlc::JSONReader reader(&is); this->Load(&reader); diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index d0c982281b34..d08026eac855 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -32,11 +32,13 @@ #include #include +#include #include #include #include #include +#include "./graph_runtime_factory.h" namespace tvm { namespace runtime { @@ -64,7 +66,7 @@ struct TVMOpParam { * This runtime can be acccesibly in various language via * TVM runtime PackedFunc API. */ -class TVM_DLL GraphRuntime : public ModuleNode { +class TVM_DLL GraphRuntime : public GraphRuntimeFactory { struct OpArgs { std::vector args; std::vector arg_values; @@ -94,10 +96,12 @@ class TVM_DLL GraphRuntime : public ModuleNode { * processor. * \param ctxs The context of the host and devices where graph nodes will be * executed on. + * \param params The params of graph. */ void Init(const std::string& graph_json, tvm::runtime::Module module, - const std::vector& ctxs); + const std::vector& ctxs, + const std::unordered_map& params = {}); /*! * \brief Get the input index given the name of input. @@ -171,6 +175,55 @@ class TVM_DLL GraphRuntime : public ModuleNode { std::string GetNodeName(uint32_t nid) const { return nodes_[nid].name; } + /*! + * \brief Set graph json value. + * \param graph_json The graph json value we want to set. + */ + void SetGraphJson(const std::string& graph_json) { graph_json_ = graph_json; } + + /*! + * \brief Get the graph json. + * \return The graph json. + */ + std::string GetGraphJson() const { return graph_json_; } + + /*! + * \brief Set the graph params. + * \param params The graph params value we want to set. + */ + void SetParams(const std::unordered_map& params) { + params_ = params; + + // upload big arrays first to avoid memory issue in rpc mode + std::vector keys; + for (const auto& p : params_) { + keys.emplace_back(p.first); + } + std::sort(std::begin(keys), std::end(keys), + [this](const std::string& lhs, const std::string& rhs) -> bool { + auto lhs_shape = params_[lhs].Shape(); + auto rhs_shape = params_[rhs].Shape(); + auto lhs_prod = std::accumulate(std::begin(lhs_shape), std::end(lhs_shape), 1, + std::multiplies()); + auto rhs_prod = std::accumulate(std::begin(rhs_shape), std::end(rhs_shape), 1, + std::multiplies()); + return lhs_prod > rhs_prod; + }); + + for (const auto& key : keys) { + int in_idx = this->GetInputIndex(key); + if (in_idx >= 0) { + this->SetInput(in_idx, const_cast(params_[key].operator->())); + } + } + } + + /*! + * \brief Get the graph params. + * \return The graph params. + */ + std::unordered_map GetParams() const { return params_; } + protected: // Memory pool entry. struct PoolEntry { @@ -389,6 +442,10 @@ class TVM_DLL GraphRuntime : public ModuleNode { std::vector outputs_; /*! \brief Additional graph attributes. */ GraphAttr attrs_; + /*! \brief The execution graph. */ + std::string graph_json_; + /*! \brief The params. */ + std::unordered_map params_; /*! \brief The code module that contains both host and device code. */ tvm::runtime::Module module_; /*! \brief Execution context of all devices including the host. */ diff --git a/src/runtime/graph/graph_runtime_factory.cc b/src/runtime/graph/graph_runtime_factory.cc new file mode 100644 index 000000000000..79b6c9a0885f --- /dev/null +++ b/src/runtime/graph/graph_runtime_factory.cc @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file graph_runtime_factory.cc + * \brief Graph runtime factory implementations + */ + +#include +#include +#include +#include +#include +#include "./graph_runtime_factory.h" +#include "./graph_runtime.h" + +namespace tvm { +namespace runtime { + +void GraphRuntimeFactory::Init(const std::string& kind, + const std::string& graph_json, + const std::unordered_map& params) { + kind_ = kind; + graph_json_ = graph_json; + params_ = params; +} + +void GraphRuntimeFactory::ImportModule(Module other, std::string module_name) { + this->Import(other); + module_names_.push_back(module_name); +} + +PackedFunc GraphRuntimeFactory::GetFunction(const std::string& name, + const tvm::runtime::ObjectPtr& sptr_to_self) { + if (name == "runtime_create") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + std::vector contexts; + TVMContext ctx; + // arg is: module, ctxs + CHECK_EQ((args.size() - 1) % 2, 0); + for (int i = 1; i < args.num_args; i += 2) { + int dev_type = args[i]; + ctx.device_type = static_cast(dev_type); + ctx.device_id = args[i + 1]; + contexts.push_back(ctx); + } + *rv = this->RuntimeCreate(args[0], contexts); + }); + } else if (name == "import_module") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.size(), 2); + this->ImportModule(args[0], args[1]); + }); + } else if (name == "select_module") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.size(), 1); + *rv = this->SelectModule(args[0]); + }); + } else if (name == "get_json") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->graph_json_; + }); + } else if (name == "get_lib") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_GT(this->imports().size(), 0); + *rv = this->imports_[0]; + }); + } else if (name == "get_params") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + Map ret; + for (const auto& kv : this->params_) { + ret.Set(kv.first, kv.second); + } + *rv = ret; + }); + } else { + return PackedFunc(); + } +} + +void GraphRuntimeFactory::SaveToBinary(dmlc::Stream* stream) { + stream->Write(module_names_); + stream->Write(kind_); + stream->Write(graph_json_); + std::vector names; + std::vector arrays; + for (const auto& v : params_) { + names.emplace_back(v.first); + arrays.emplace_back(const_cast(v.second.operator->())); + } + stream->Write(names); + uint64_t sz = arrays.size(); + CHECK(sz == names.size()); + stream->Write(sz); + for (size_t i = 0; i < sz; ++i) { + tvm::runtime::SaveDLTensor(stream, arrays[i]); + } +} + +Module GraphRuntimeFactory::RuntimeCreate(Module module, const std::vector &ctxs) { + auto factory_module = module.as(); + CHECK(factory_module != nullptr); + if (factory_module->GetKind() == "graph") { + auto exec = make_object(); + exec->Init(factory_module->GetJson(), factory_module->GetLib(), ctxs); + exec->SetParams(factory_module->GetParams()); + return Module(exec); + } + + return Module(); +} + +Module GraphRuntimeFactory::SelectModule(const std::string &name) { + CHECK(std::find(module_names_.begin(), module_names_.end(), name) != module_names_.end()); + auto iter = std::find(module_names_.begin(), module_names_.end(), name); + CHECK(iter != module_names_.end()); + if (iter == module_names_.begin()) { + auto exec = make_object(); + exec->Init(this->kind_, this->graph_json_, this->params_); + exec->ImportModule(this->imports_[0], *iter); + return Module(exec); + } else { + return this->imports_[std::distance(module_names_.begin(), iter)]; + } +} + +Module GraphRuntimeFactoryModuleLoadBinary(void* strm) { + dmlc::Stream* stream = static_cast(strm); + std::vector module_names; + std::string kind; + std::string graph_json; + CHECK(stream->Read(&module_names)); + CHECK(stream->Read(&kind)); + CHECK(stream->Read(&graph_json)); + std::vector names; + CHECK(stream->Read(&names)); + uint64_t sz; + CHECK(stream->Read(&sz)); + CHECK(sz == names.size()); + std::unordered_map params; + for (size_t i = 0; i < sz; ++i) { + tvm::runtime::NDArray temp; + temp.Load(stream); + params[names[i]] = temp; + } + + auto exec = make_object(); + exec->Init(kind, graph_json, params); + exec->SetModuleNames(module_names); + return Module(exec); +} + +Module RuntimeCreate(Module module, const std::vector &ctxs) { + auto mod = module.as(); + CHECK(mod != nullptr); + if (mod->GetKind() == "graph") { + auto exec = make_object(); + exec->Init(mod->GetJson(), mod->GetLib(), ctxs); + exec->SetParams(mod->GetParams()); + return Module(exec); + } else { + LOG(ERROR) << "Doesn't support graph kind of " << mod->GetKind(); + } + + return Module(); +} + +TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.create") +.set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK_GE(args.num_args, 4) << "The expected number of arguments for " + "graph_runtime_factory.create needs at least 3, " + "but it has " + << args.num_args; + auto exec = make_object(); + // The argument order is graph_runtime_kind, graph_json, module, module_name, params. + CHECK_EQ((args.size() - 4) % 2, 0); + std::unordered_map params; + for (size_t i = 4; i < static_cast(args.size()); i += 2) { + std::string name = args[i].operator String(); + params[name] = args[i + 1].operator tvm::runtime::NDArray(); + } + exec->Init(args[0], args[1], params); + exec->ImportModule(args[2], args[3]); + *rv = Module(exec); + }); + +TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.runtime_create") +.set_body([](TVMArgs args, TVMRetValue* rv) { + std::vector contexts; + TVMContext ctx; + // arg is: module, ctxs + CHECK_EQ((args.size() - 1) % 2, 0); + for (int i = 1; i < args.num_args; i += 2) { + int dev_type = args[i]; + ctx.device_type = static_cast(dev_type); + ctx.device_id = args[i + 1]; + contexts.push_back(ctx); + } + *rv = RuntimeCreate(args[0], contexts); +}); + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_GraphRuntimeFactory") +.set_body_typed(GraphRuntimeFactoryModuleLoadBinary); + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/graph/graph_runtime_factory.h b/src/runtime/graph/graph_runtime_factory.h new file mode 100644 index 000000000000..0bc8d70d6570 --- /dev/null +++ b/src/runtime/graph/graph_runtime_factory.h @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/graph_runtime_factory.h + * \brief Graph runtime factory creating graph runtime. + */ + +#ifndef TVM_RUNTIME_GRAPH_RUNTIME_FACTORY_H_ +#define TVM_RUNTIME_GRAPH_RUNTIME_FACTORY_H_ + +#include +#include +#include +#include +#include +#include +#include + + +namespace tvm { +namespace runtime { + +class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { + + public: + + /*! + * \brief Initialize the GraphRuntimeFactory with graph and context. + * \param graph_json The execution graph. + * \param params The params of graph. + * \param kind The runtime kind to be created. + */ + + void Init(const std::string& kind, + const std::string& graph_json, + const std::unordered_map& params); + + void ImportModule(Module other, std::string module_name); + + /*! + * \brief Get member function to front-end + * \param name The name of the function. + * \param sptr_to_self The pointer to the module node. + * \return The corresponding member function. + */ + virtual PackedFunc GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self); + + /*! + * \return The type key of the executor. + */ + const char* type_key() const override { + return "GraphRuntimeFactory"; + } + + /*! + * \brief Save the module to binary stream. + * \param stream The binary stream to save to. + */ + void SaveToBinary(dmlc::Stream* stream) override; + + + /*! + * \brief Create a specific runtime module + * \param module The module we will be used for creating runtime + * \param ctxs The context of the host and devices where graph nodes will be + * executed on. + * \return created runtime module + */ + Module RuntimeCreate(Module module, const std::vector& ctxs); + + /*! + * \brief Select the specific module + * \param name The name of the module + * \return selected module + */ + Module SelectModule(const std::string& name); + + inline std::string GetJson() const { + return graph_json_; + } + + inline std::unordered_map GetParams() const { + return params_; + } + + inline Module GetLib() const { + CHECK_GT(this->imports().size(), 0); + return this->imports_[0]; + } + + inline std::string GetKind() const { + return kind_; + } + + inline std::vector GetModuleNames() const { + return module_names_; + } + + inline void SetModuleNames(const std::vector& module_names) { + module_names_ = module_names; + } + + protected: + /*! \brief The execution graph. */ + std::string graph_json_; + /*! \brief The params. */ + std::unordered_map params_; + /*! \brief runtime kind */ + std::string kind_; + /*! \brief module names list */ + std::vector module_names_; + +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_GRAPH_RUNTIME_FACTORY_H_ \ No newline at end of file diff --git a/tests/python/unittest/test_module_runtime_interface.py b/tests/python/unittest/test_module_runtime_interface.py new file mode 100644 index 000000000000..8b964734335f --- /dev/null +++ b/tests/python/unittest/test_module_runtime_interface.py @@ -0,0 +1,376 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import copy +import numpy as np +from tvm import relay +from tvm.relay import testing +import tvm +from tvm.contrib import graph_runtime +from tvm.runtime import graph_runtime_factory + +def get_workload(num_layers=18): + mod, params = relay.testing.resnet.get_workload(num_layers=num_layers) + return mod, params + +def verify(data, num_layers=18): + mod, params = get_workload(num_layers) + with relay.build_config(opt_level=3): + graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params) + + ctx = tvm.cpu() + module = graph_runtime.create(graph, lib, ctx) + module.set_input("data", data) + module.set_input(**graph_params) + module.run() + out = module.get_output(0).asnumpy() + + return out + +def test_legacy_compatibility(): + mod, params = get_workload() + with relay.build_config(opt_level=3): + graph, lib, graph_params = relay.build_module.build( + mod, "llvm", params=params, export_graph_module=True) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = tvm.cpu() + module = graph_runtime.create(graph, lib, ctx) + module.set_input("data", data) + module.set_input(**graph_params) + module.run() + out = module.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + +def test_cpu(): + mod, params = get_workload() + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build( + mod, "llvm", params=params, export_graph_module=True) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + # raw api + ctx = tvm.cpu() + gmod = complied_graph_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", tvm.nd.array(data)) + run() + out = get_output(0).asnumpy() + + # graph runtime + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + gmod = graph_runtime.create4unified(complied_graph_lib['default'], ctx) + gmod.set_input("data", data) + gmod.run() + out = gmod.get_output(0).asnumpy() + + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + +def test_gpu(): + mod, params = get_workload() + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build( + mod, "cuda", params=params, export_graph_module=True) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = tvm.gpu() + gmod = complied_graph_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", tvm.nd.array(data)) + run() + out = get_output(0).asnumpy() + + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + +def test_multi_models(): + resnet18_mod, resnet18_params = get_workload() + resnet50_mod, resnet50_params = get_workload(50) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build( + resnet18_mod, "llvm", params=resnet18_params, mod_name='resnet18', export_graph_module=True) + with relay.build_config(opt_level=3): + resnet50_gpu_lib = relay.build_module.build( + resnet50_mod, "cuda", params=resnet50_params, mod_name='resnet50', export_graph_module=True) + complied_graph_lib.import_module(resnet50_gpu_lib, "resnet50") + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + # resnet18 + cpu_ctx = tvm.cpu() + gmod = complied_graph_lib['resnet18'](cpu_ctx) + set_input = gmod["set_input"] + get_input = gmod["get_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", tvm.nd.array(data)) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # resnet50 + gpu_ctx = tvm.gpu() + gmod = complied_graph_lib['resnet50'](gpu_ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", tvm.nd.array(data)) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data, 50), atol=1e-5) + +def test_cpu_export(format=".so"): + mod, params = get_workload() + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build( + mod, "llvm", params=params, export_graph_module=True) + + from tvm.contrib import util + temp = util.tempdir() + if format == ".so": + file_name = "deploy_lib.so" + else: + assert format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib.export_library(path_lib) + loaded_lib = tvm.runtime.load_module(path_lib) + ctx = tvm.cpu(0) + gmod = loaded_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + set_input("data", tvm.nd.array(data)) + run() + out = get_output(0).asnumpy() + + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + +def test_gpu_export(format=".so"): + mod, params = get_workload() + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build( + mod, "cuda", params=params, export_graph_module=True) + + from tvm.contrib import util + temp = util.tempdir() + if format == ".so": + file_name = "deploy_lib.so" + else: + assert format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib.export_library(path_lib) + loaded_lib = tvm.runtime.load_module(path_lib) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = tvm.gpu() + gmod = loaded_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", tvm.nd.array(data)) + run() + out = get_output(0).asnumpy() + + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) +# +def test_previous_cpu_export(format=".so"): + mod, params = get_workload() + with relay.build_config(opt_level=3): + graph, lib, graph_params = relay.build_module.build( + mod, "llvm --system-lib", params=params, export_graph_module=True) + + from tvm.contrib import util + temp = util.tempdir() + if format == ".so": + file_name = "deploy_lib.so" + else: + assert format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + lib.export_library(path_lib) + with open(temp.relpath("deploy_graph.json"), "w") as fo: + fo.write(graph) + with open(temp.relpath("deploy_param.params"), "wb") as fo: + fo.write(relay.save_param_dict(graph_params)) + loaded_json = open(temp.relpath("deploy_graph.json")).read() + #loaded_lib = tvm.module.load(path_lib) + import ctypes + # Load dll, will trigger system library registration + dll = ctypes.CDLL(path_lib) + # Load the system wide library + loaded_lib = tvm.runtime.system_lib() + loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = tvm.cpu() + module = graph_runtime.create(loaded_json, loaded_lib, ctx) + module.load_params(loaded_params) + module.set_input("data", data) + module.run() + out = module.get_output(0).asnumpy() + + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) +# +# def test_previous_gpu_export(format=".so"): +# #mod, params = get_workload() +# net, params = nnvm.testing.resnet.get_workload(num_layers=18) +# with nnvm.compiler.build_config(opt_level=3): +# graph, lib, graph_params = nnvm.compiler.build( +# net, "opencl", shape={'data': (1,3,224,224)}, params=params) +# +# from tvm.contrib import util +# temp = "tvm_deploy/" +# if format == ".so": +# file_name = "deploy_lib.so" +# else: +# assert format == ".tar" +# file_name = "deploy_lib.tar" +# path_lib = temp + file_name +# lib.export_library(path_lib) +# with open("tvm_deploy/deploy_graph.json", "w") as fo: +# fo.write(graph.json()) +# with open("tvm_deploy/deploy_param.params", "wb") as fo: +# fo.write(nnvm.compiler.save_param_dict(graph_params)) +# # loaded_json = open(temp.relpath("deploy_graph.json")).read() +# # loaded_lib = tvm.module.load(path_lib) +# # loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) +# # data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") +# # ctx = tvm.gpu() +# # module = graph_runtime.create(loaded_json, loaded_lib, ctx) +# # module.load_params(loaded_params) +# # module.set_input("data", data) +# # module.run() +# # out = module.get_output(0).asnumpy() +# # +# # tvm.testing.assert_allclose(out, verify(data), atol=1e-5) +# +def test_rpc_export(format=".so"): + mod, params = get_workload() + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build( + mod, "llvm", params=params, export_graph_module=True) + + from tvm.contrib import util + temp = util.tempdir() + if format == ".so": + file_name = "deploy_lib.so" + else: + assert format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib.export_library(path_lib) + + from tvm import rpc + server = rpc.Server("localhost", use_popen=True) + remote = rpc.connect(server.host, server.port) + remote.upload(path_lib) + loaded_lib = remote.load_module(path_lib) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = remote.cpu() + gmod = loaded_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", data) + run() + out = get_output(0).asnumpy() + + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) +# +# def test_previous_rpc_export(format=".so"): +# mod, params = get_workload() +# with relay.build_config(opt_level=3): +# graph, lib, graph_params = relay.build_module.build( +# mod, "llvm", params=params, export_graph_module=False) +# +# from tvm.contrib import util +# temp = util.tempdir() +# if format == ".so": +# file_name = "deploy_lib.so" +# else: +# assert format == ".tar" +# file_name = "deploy_lib.tar" +# path_lib = temp.relpath(file_name) +# lib.export_library(path_lib) +# with open(temp.relpath("deploy_graph.json"), "w") as fo: +# fo.write(graph) +# with open(temp.relpath("deploy_param.params"), "wb") as fo: +# fo.write(relay.save_param_dict(graph_params)) +# +# from tvm import rpc +# server = rpc.Server("localhost", use_popen=True) +# remote = rpc.connect(server.host, server.port) +# remote.upload(path_lib) +# loaded_json = open(temp.relpath("deploy_graph.json")).read() +# loaded_lib = remote.load_module(path_lib) +# loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) +# data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") +# ctx = remote.cpu() +# module = graph_runtime.create(loaded_json, loaded_lib, ctx) +# module.load_params(loaded_params) +# module.set_input("data", data) +# module.run() +# out = module.get_output(0).asnumpy() +# +# tvm.testing.assert_allclose(out, verify(data), atol=1e-5) +# +# +# def test_previous_gpu_load(): +# loaded_json = open("tvm_deploy/deploy_graph.json").read() +# loaded_lib = tvm.module.load("tvm_deploy/deploy_lib.so") +# loaded_params = bytearray(open("tvm_deploy/deploy_param.params", "rb").read()) +# data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") +# ctx = tvm.gpu() +# module = graph_runtime.create(loaded_json, loaded_lib, ctx) +# module.load_params(loaded_params) +# module.set_input("data", data) +# module.run() +# out = module.get_output(0).asnumpy() +# +# tvm.testing.assert_allclose(out, verify(data), atol=1e-5) +# +# def test_previous_cpu_load(): +# loaded_json = open("tvm_deploy/deploy_cpu_graph.json").read() +# loaded_lib = tvm.module.load("tvm_deploy/deploy_cpu_lib.so") +# loaded_params = bytearray(open("tvm_deploy/deploy_cpu_param.params", "rb").read()) +# data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") +# ctx = tvm.cpu() +# module = graph_runtime.create(loaded_json, loaded_lib, ctx) +# module.load_params(loaded_params) +# module.set_input("data", data) +# module.run() +# out = module.get_output(0).asnumpy() +# +# tvm.testing.assert_allclose(out, verify(data), atol=1e-5) +if __name__ == "__main__": + test_legacy_compatibility() + test_cpu() + test_gpu() + test_multi_models() + test_cpu_export(".so") + test_cpu_export(".tar") + # test_gpu() + test_gpu_export(".so") + # test_gpu_export(".tar") + # test_rpc_export(".so") + # test_rpc_export(".tar") + # test_previous_cpu_export(".so") + # test_previous_cpu_export(".tar") + #test_previous_gpu_export(".so") + # test_previous_gpu_export(".tar") + # test_previous_rpc_export(".so") + # test_previous_rpc_export(".tar") + #test_previous_gpu_load() + #test_previous_cpu_load() \ No newline at end of file From 3fbd368ca367e51676caa5d56188d282bd18ceda Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Fri, 19 Jun 2020 14:51:35 +0800 Subject: [PATCH 02/29] remove unnecessary comment --- python/tvm/runtime/module.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index c1def255b019..7e993993dd55 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -43,7 +43,6 @@ def __init__(self, handle): self.entry_name = "__tvm_main__" # TODO:(FrozenGene): support rpc if self.type_key == 'GraphRuntimeFactory': - #from tvm.runtime.graph_runtime_factory import GraphRuntimeFactoryModule self._entry = self.runtime_create def __del__(self): From cb101acdf8bb7fc28743911d20d6bb3a76550b4f Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Fri, 19 Jun 2020 20:14:25 +0800 Subject: [PATCH 03/29] support rpc (except params issue) --- python/tvm/rpc/client.py | 7 ++++- python/tvm/runtime/graph_runtime_factory.py | 13 +++++---- python/tvm/runtime/module.py | 29 ++++--------------- .../unittest/test_module_runtime_interface.py | 18 ++++++------ 4 files changed, 28 insertions(+), 39 deletions(-) diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index 2f96c9b62976..971ce4d06019 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -160,7 +160,12 @@ def load_module(self, path): m : Module The remote module containing remote function. """ - return _ffi_api.LoadRemoteModule(self._sess, path) + module = _ffi_api.LoadRemoteModule(self._sess, path) + type_key = self.get_function("runtime.ModuleGetTypeKey")(module) + if type_key == "GraphRuntimeFactory": + from tvm.runtime.graph_runtime_factory import GraphRuntimeFactoryModule + return GraphRuntimeFactoryModule(module) + return module def cpu(self, dev_id=0): """Construct CPU device.""" diff --git a/python/tvm/runtime/graph_runtime_factory.py b/python/tvm/runtime/graph_runtime_factory.py index d14afa919cad..7e675af08b40 100644 --- a/python/tvm/runtime/graph_runtime_factory.py +++ b/python/tvm/runtime/graph_runtime_factory.py @@ -75,14 +75,15 @@ class GraphRuntimeFactoryModule(Module): def __init__(self, module): self.module = module - self._select_module = module.get_function("select_module") - self._import_module = module.get_function("import_module") + self._select_module = module["select_module"] + self._import_module = module["import_module"] self.selected_module = None - self.graph_json = module.get_function("get_json")() - self.lib = module.get_function("get_lib")() + self.graph_json = module["get_json"]() + self.lib = module["get_lib"]() self.params = {} - for k, v in module.get_function("get_params")().items(): - self.params[k] = v + # TODO (FrozenGene): Enable it + # for k, v in module["get_params"]().items(): + # self.params[k] = v self.iter_cnt = 0 super(GraphRuntimeFactoryModule, self).__init__(self.module.handle) diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 7e993993dd55..f86fcdd24f30 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -41,9 +41,6 @@ def __init__(self, handle): self.handle = handle self._entry = None self.entry_name = "__tvm_main__" - # TODO:(FrozenGene): support rpc - if self.type_key == 'GraphRuntimeFactory': - self._entry = self.runtime_create def __del__(self): check_call(_LIB.TVMModFree(self.handle)) @@ -102,8 +99,6 @@ def import_module(self, module): check_call(_LIB.TVMModImport(self.handle, module.handle)) def __getitem__(self, name): - if self.type_key == 'GraphRuntimeFactory': - return self.get_function("select_module")(name) if not isinstance(name, string_types): raise ValueError("Can only take string as function name") return self.get_function(name) @@ -117,23 +112,6 @@ def __call__(self, *args): def __repr__(self): return "Module(%s, %x)" % (self.type_key, self.handle.value) - # TODO (FrozenGene): remove - def runtime_create(self, ctx): - """Create the runtime using ctx - - Parameters - ---------- - ctx : TVMContext or list of TVMContext - """ - from tvm.contrib.graph_runtime import get_device_ctx - from tvm._ffi.registry import get_global_func - ctx, num_rpc_ctx, device_type_id = get_device_ctx(self, ctx) - if num_rpc_ctx == len(ctx): - fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime_factory.runtime_create") - else: - fcreate = get_global_func("tvm.graph_runtime_factory.runtime_create") - return fcreate(self, *device_type_id) - @property def type_key(self): """Get type key of the module.""" @@ -424,7 +402,12 @@ def load_module(path, fmt=""): elif path.endswith(".obj"): fmt = "micro_dev" # Redirect to the load API - return _ffi_api.ModuleLoadFromFile(path, fmt) + module = _ffi_api.ModuleLoadFromFile(path, fmt) + if module.type_key == 'GraphRuntimeFactory': + from tvm.runtime.graph_runtime_factory import GraphRuntimeFactoryModule + return GraphRuntimeFactoryModule(module) + return module + def enabled(target): diff --git a/tests/python/unittest/test_module_runtime_interface.py b/tests/python/unittest/test_module_runtime_interface.py index 8b964734335f..83d63f787e2d 100644 --- a/tests/python/unittest/test_module_runtime_interface.py +++ b/tests/python/unittest/test_module_runtime_interface.py @@ -283,7 +283,7 @@ def test_rpc_export(format=".so"): set_input = gmod["set_input"] run = gmod["run"] get_output = gmod["get_output"] - set_input("data", data) + set_input("data", tvm.nd.array(data, ctx=ctx)) run() out = get_output(0).asnumpy() @@ -355,16 +355,16 @@ def test_rpc_export(format=".so"): # # tvm.testing.assert_allclose(out, verify(data), atol=1e-5) if __name__ == "__main__": - test_legacy_compatibility() - test_cpu() - test_gpu() - test_multi_models() - test_cpu_export(".so") - test_cpu_export(".tar") + # test_legacy_compatibility() + # test_cpu() # test_gpu() - test_gpu_export(".so") + # test_multi_models() + # test_cpu_export(".so") + # test_cpu_export(".tar") + # test_gpu() + # test_gpu_export(".so") # test_gpu_export(".tar") - # test_rpc_export(".so") + test_rpc_export(".so") # test_rpc_export(".tar") # test_previous_cpu_export(".so") # test_previous_cpu_export(".tar") From 95803d0074376ddc9ee4230351d9014160111946 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Fri, 19 Jun 2020 20:42:20 +0800 Subject: [PATCH 04/29] solve rpc issue --- python/tvm/runtime/graph_runtime_factory.py | 11 +++++---- .../unittest/test_module_runtime_interface.py | 23 +++++++++++-------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/python/tvm/runtime/graph_runtime_factory.py b/python/tvm/runtime/graph_runtime_factory.py index 7e675af08b40..ecae4f03c46b 100644 --- a/python/tvm/runtime/graph_runtime_factory.py +++ b/python/tvm/runtime/graph_runtime_factory.py @@ -78,12 +78,9 @@ def __init__(self, module): self._select_module = module["select_module"] self._import_module = module["import_module"] self.selected_module = None - self.graph_json = module["get_json"]() - self.lib = module["get_lib"]() + self.graph_json = None + self.lib = None self.params = {} - # TODO (FrozenGene): Enable it - # for k, v in module["get_params"]().items(): - # self.params[k] = v self.iter_cnt = 0 super(GraphRuntimeFactoryModule, self).__init__(self.module.handle) @@ -132,6 +129,10 @@ def __iter__(self): warnings.warn( "legacy graph runtime behaviour of producing json / lib / params will be removed in the next release ", DeprecationWarning, 2) + self.graph_json = self.module["get_json"]() + self.lib = self.module["get_lib"]() + for k, v in self.module["get_params"]().items(): + self.params[k] = v return self diff --git a/tests/python/unittest/test_module_runtime_interface.py b/tests/python/unittest/test_module_runtime_interface.py index 83d63f787e2d..24864d2365e6 100644 --- a/tests/python/unittest/test_module_runtime_interface.py +++ b/tests/python/unittest/test_module_runtime_interface.py @@ -286,6 +286,12 @@ def test_rpc_export(format=".so"): set_input("data", tvm.nd.array(data, ctx=ctx)) run() out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + gmod = graph_runtime.create4unified(loaded_lib['default'], ctx) + gmod.set_input("data", data) + gmod.run() + out = gmod.get_output(0).asnumpy() tvm.testing.assert_allclose(out, verify(data), atol=1e-5) # @@ -355,15 +361,14 @@ def test_rpc_export(format=".so"): # # tvm.testing.assert_allclose(out, verify(data), atol=1e-5) if __name__ == "__main__": - # test_legacy_compatibility() - # test_cpu() - # test_gpu() - # test_multi_models() - # test_cpu_export(".so") - # test_cpu_export(".tar") - # test_gpu() - # test_gpu_export(".so") - # test_gpu_export(".tar") + test_legacy_compatibility() + test_cpu() + test_gpu() + test_multi_models() + test_cpu_export(".so") + test_cpu_export(".tar") + test_gpu_export(".so") + test_gpu_export(".tar") test_rpc_export(".so") # test_rpc_export(".tar") # test_previous_cpu_export(".so") From 166e0099da3f3f06e8070aad16af7c071b910e3d Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Mon, 22 Jun 2020 19:38:44 +0800 Subject: [PATCH 05/29] support package params --- python/tvm/runtime/graph_runtime_factory.py | 11 +- python/tvm/runtime/module.py | 38 ++- src/runtime/graph/graph_runtime_factory.cc | 56 ++-- src/runtime/graph/graph_runtime_factory.h | 2 + .../unittest/test_module_runtime_interface.py | 284 ++++++++++-------- 5 files changed, 233 insertions(+), 158 deletions(-) diff --git a/python/tvm/runtime/graph_runtime_factory.py b/python/tvm/runtime/graph_runtime_factory.py index ecae4f03c46b..b2568d8111d0 100644 --- a/python/tvm/runtime/graph_runtime_factory.py +++ b/python/tvm/runtime/graph_runtime_factory.py @@ -143,13 +143,4 @@ def __next__(self): objs = [self.graph_json, self.lib, self.params] obj = objs[self.iter_cnt] self.iter_cnt += 1 - return obj - - def get_json(self): - return self.graph_json - - def get_lib(self): - return self.lib - - def get_params(self): - return self.params \ No newline at end of file + return obj \ No newline at end of file diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index f86fcdd24f30..d3afa18041f5 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -19,6 +19,7 @@ """Runtime Module namespace.""" import ctypes import struct +import os from collections import namedtuple import tvm._ffi @@ -222,29 +223,31 @@ def evaluator(*args): except NameError: raise NameError("time_evaluate is only supported when RPC is enabled") - def _collect_dso_modules(self): - """Helper function to collect dso modules, then return it.""" - visited, stack, dso_modules = set(), [], [] + def _collect_modules(self, module_type_keys): + """Helper function to collect specifit modules, then return it.""" + visited, stack, modules = set(), [], [] + type_keys = module_type_keys if isinstance(module_type_keys, (list, tuple)) else [module_type_keys] # append root module visited.add(self) stack.append(self) while stack: module = stack.pop() - if module._dso_exportable(): - dso_modules.append(module) + if module.type_key in type_keys: + modules.append(module) for m in module.imported_modules: if m not in visited: visited.add(m) stack.append(m) - return dso_modules + return modules def _dso_exportable(self): - return self.type_key == "llvm" or self.type_key == "c" + return ["llvm", "c"] def export_library(self, file_name, fcompile=None, addons=None, + package_params=True, **kwargs): """Export the module and its imported device code one library. @@ -261,6 +264,13 @@ def export_library(self, If fcompile has attribute object_format, will compile host library to that format. Otherwise, will use default format "o". + addons : str, optional + Extra files needed to be passed to compiler. + + package_params: bool, optional. + Whether we will package params into library. + The default value is True. + kwargs : dict, optional Additional arguments passed to fcompile """ @@ -282,7 +292,19 @@ def export_library(self, self.save(file_name) return - modules = self._collect_dso_modules() + graph_runtime_factory_modules = self._collect_modules("GraphRuntimeFactory") + for index, module in enumerate(graph_runtime_factory_modules): + if not package_params: + module.get_function("diable_package_params")() + path_params = os.path.join(os.path.dirname(file_name), "deploy_" + str(index) + ".params") + from tvm import relay + with open(path_params, "wb") as fo: + graph_params = {} + for k, v in module.get_function("get_params")().items(): + graph_params[k] = v + fo.write(relay.save_param_dict(graph_params)) + + modules = self._collect_modules(self._dso_exportable()) temp = _util.tempdir() files = addons if addons else [] is_system_lib = False diff --git a/src/runtime/graph/graph_runtime_factory.cc b/src/runtime/graph/graph_runtime_factory.cc index 79b6c9a0885f..beb22a2a087a 100644 --- a/src/runtime/graph/graph_runtime_factory.cc +++ b/src/runtime/graph/graph_runtime_factory.cc @@ -89,6 +89,10 @@ PackedFunc GraphRuntimeFactory::GetFunction(const std::string& name, } *rv = ret; }); + } else if (name == "diable_package_params") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + this->package_params_ = false; + }); } else { return PackedFunc(); } @@ -98,18 +102,21 @@ void GraphRuntimeFactory::SaveToBinary(dmlc::Stream* stream) { stream->Write(module_names_); stream->Write(kind_); stream->Write(graph_json_); - std::vector names; - std::vector arrays; - for (const auto& v : params_) { - names.emplace_back(v.first); - arrays.emplace_back(const_cast(v.second.operator->())); - } - stream->Write(names); - uint64_t sz = arrays.size(); - CHECK(sz == names.size()); - stream->Write(sz); - for (size_t i = 0; i < sz; ++i) { - tvm::runtime::SaveDLTensor(stream, arrays[i]); + stream->Write(package_params_); + if (package_params_) { + std::vector names; + std::vector arrays; + for (const auto& v : params_) { + names.emplace_back(v.first); + arrays.emplace_back(const_cast(v.second.operator->())); + } + uint64_t sz = arrays.size(); + CHECK(sz == names.size()); + stream->Write(sz); + stream->Write(names); + for (size_t i = 0; i < sz; ++i) { + tvm::runtime::SaveDLTensor(stream, arrays[i]); + } } } @@ -145,21 +152,24 @@ Module GraphRuntimeFactoryModuleLoadBinary(void* strm) { std::vector module_names; std::string kind; std::string graph_json; + bool package_params; + std::unordered_map params; CHECK(stream->Read(&module_names)); CHECK(stream->Read(&kind)); CHECK(stream->Read(&graph_json)); - std::vector names; - CHECK(stream->Read(&names)); - uint64_t sz; - CHECK(stream->Read(&sz)); - CHECK(sz == names.size()); - std::unordered_map params; - for (size_t i = 0; i < sz; ++i) { - tvm::runtime::NDArray temp; - temp.Load(stream); - params[names[i]] = temp; + CHECK(stream->Read(&package_params)); + if (package_params) { + uint64_t sz; + CHECK(stream->Read(&sz)); + std::vector names; + CHECK(stream->Read(&names)); + CHECK(sz == names.size()); + for (size_t i = 0; i < sz; ++i) { + tvm::runtime::NDArray temp; + temp.Load(stream); + params[names[i]] = temp; + } } - auto exec = make_object(); exec->Init(kind, graph_json, params); exec->SetModuleNames(module_names); diff --git a/src/runtime/graph/graph_runtime_factory.h b/src/runtime/graph/graph_runtime_factory.h index 0bc8d70d6570..b142fcb217bd 100644 --- a/src/runtime/graph/graph_runtime_factory.h +++ b/src/runtime/graph/graph_runtime_factory.h @@ -127,6 +127,8 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { std::string kind_; /*! \brief module names list */ std::vector module_names_; + /*! \brief whether to package params */ + bool package_params_ = true; }; diff --git a/tests/python/unittest/test_module_runtime_interface.py b/tests/python/unittest/test_module_runtime_interface.py index 24864d2365e6..ea2ef1102386 100644 --- a/tests/python/unittest/test_module_runtime_interface.py +++ b/tests/python/unittest/test_module_runtime_interface.py @@ -185,12 +185,12 @@ def test_gpu_export(format=".so"): out = get_output(0).asnumpy() tvm.testing.assert_allclose(out, verify(data), atol=1e-5) -# + def test_previous_cpu_export(format=".so"): mod, params = get_workload() with relay.build_config(opt_level=3): graph, lib, graph_params = relay.build_module.build( - mod, "llvm --system-lib", params=params, export_graph_module=True) + mod, "llvm", params=params, export_graph_module=True) from tvm.contrib import util temp = util.tempdir() @@ -206,12 +206,7 @@ def test_previous_cpu_export(format=".so"): with open(temp.relpath("deploy_param.params"), "wb") as fo: fo.write(relay.save_param_dict(graph_params)) loaded_json = open(temp.relpath("deploy_graph.json")).read() - #loaded_lib = tvm.module.load(path_lib) - import ctypes - # Load dll, will trigger system library registration - dll = ctypes.CDLL(path_lib) - # Load the system wide library - loaded_lib = tvm.runtime.system_lib() + loaded_lib = tvm.runtime.load_module(path_lib) loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") ctx = tvm.cpu() @@ -222,40 +217,39 @@ def test_previous_cpu_export(format=".so"): out = module.get_output(0).asnumpy() tvm.testing.assert_allclose(out, verify(data), atol=1e-5) -# -# def test_previous_gpu_export(format=".so"): -# #mod, params = get_workload() -# net, params = nnvm.testing.resnet.get_workload(num_layers=18) -# with nnvm.compiler.build_config(opt_level=3): -# graph, lib, graph_params = nnvm.compiler.build( -# net, "opencl", shape={'data': (1,3,224,224)}, params=params) -# -# from tvm.contrib import util -# temp = "tvm_deploy/" -# if format == ".so": -# file_name = "deploy_lib.so" -# else: -# assert format == ".tar" -# file_name = "deploy_lib.tar" -# path_lib = temp + file_name -# lib.export_library(path_lib) -# with open("tvm_deploy/deploy_graph.json", "w") as fo: -# fo.write(graph.json()) -# with open("tvm_deploy/deploy_param.params", "wb") as fo: -# fo.write(nnvm.compiler.save_param_dict(graph_params)) -# # loaded_json = open(temp.relpath("deploy_graph.json")).read() -# # loaded_lib = tvm.module.load(path_lib) -# # loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) -# # data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") -# # ctx = tvm.gpu() -# # module = graph_runtime.create(loaded_json, loaded_lib, ctx) -# # module.load_params(loaded_params) -# # module.set_input("data", data) -# # module.run() -# # out = module.get_output(0).asnumpy() -# # -# # tvm.testing.assert_allclose(out, verify(data), atol=1e-5) -# + +def test_previous_gpu_export(format=".so"): + mod, params = get_workload() + with relay.build_config(opt_level=3): + graph, lib, graph_params = relay.build_module.build( + mod, "cuda", params=params, export_graph_module=True) + + from tvm.contrib import util + temp = util.tempdir() + if format == ".so": + file_name = "deploy_lib.so" + else: + assert format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + lib.export_library(path_lib) + with open(temp.relpath("deploy_graph.json"), "w") as fo: + fo.write(graph) + with open(temp.relpath("deploy_param.params"), "wb") as fo: + fo.write(relay.save_param_dict(graph_params)) + loaded_json = open(temp.relpath("deploy_graph.json")).read() + loaded_lib = tvm.runtime.load_module(path_lib) + loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = tvm.gpu() + module = graph_runtime.create(loaded_json, loaded_lib, ctx) + module.load_params(loaded_params) + module.set_input("data", data) + module.run() + out = module.get_output(0).asnumpy() + + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + def test_rpc_export(format=".so"): mod, params = get_workload() with relay.build_config(opt_level=3): @@ -294,72 +288,128 @@ def test_rpc_export(format=".so"): out = gmod.get_output(0).asnumpy() tvm.testing.assert_allclose(out, verify(data), atol=1e-5) -# -# def test_previous_rpc_export(format=".so"): -# mod, params = get_workload() -# with relay.build_config(opt_level=3): -# graph, lib, graph_params = relay.build_module.build( -# mod, "llvm", params=params, export_graph_module=False) -# -# from tvm.contrib import util -# temp = util.tempdir() -# if format == ".so": -# file_name = "deploy_lib.so" -# else: -# assert format == ".tar" -# file_name = "deploy_lib.tar" -# path_lib = temp.relpath(file_name) -# lib.export_library(path_lib) -# with open(temp.relpath("deploy_graph.json"), "w") as fo: -# fo.write(graph) -# with open(temp.relpath("deploy_param.params"), "wb") as fo: -# fo.write(relay.save_param_dict(graph_params)) -# -# from tvm import rpc -# server = rpc.Server("localhost", use_popen=True) -# remote = rpc.connect(server.host, server.port) -# remote.upload(path_lib) -# loaded_json = open(temp.relpath("deploy_graph.json")).read() -# loaded_lib = remote.load_module(path_lib) -# loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) -# data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") -# ctx = remote.cpu() -# module = graph_runtime.create(loaded_json, loaded_lib, ctx) -# module.load_params(loaded_params) -# module.set_input("data", data) -# module.run() -# out = module.get_output(0).asnumpy() -# -# tvm.testing.assert_allclose(out, verify(data), atol=1e-5) -# -# -# def test_previous_gpu_load(): -# loaded_json = open("tvm_deploy/deploy_graph.json").read() -# loaded_lib = tvm.module.load("tvm_deploy/deploy_lib.so") -# loaded_params = bytearray(open("tvm_deploy/deploy_param.params", "rb").read()) -# data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") -# ctx = tvm.gpu() -# module = graph_runtime.create(loaded_json, loaded_lib, ctx) -# module.load_params(loaded_params) -# module.set_input("data", data) -# module.run() -# out = module.get_output(0).asnumpy() -# -# tvm.testing.assert_allclose(out, verify(data), atol=1e-5) -# -# def test_previous_cpu_load(): -# loaded_json = open("tvm_deploy/deploy_cpu_graph.json").read() -# loaded_lib = tvm.module.load("tvm_deploy/deploy_cpu_lib.so") -# loaded_params = bytearray(open("tvm_deploy/deploy_cpu_param.params", "rb").read()) -# data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") -# ctx = tvm.cpu() -# module = graph_runtime.create(loaded_json, loaded_lib, ctx) -# module.load_params(loaded_params) -# module.set_input("data", data) -# module.run() -# out = module.get_output(0).asnumpy() -# -# tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + +def test_previous_rpc_export(format=".so"): + mod, params = get_workload() + with relay.build_config(opt_level=3): + graph, lib, graph_params = relay.build_module.build( + mod, "llvm", params=params, export_graph_module=True) + + from tvm.contrib import util + temp = util.tempdir() + if format == ".so": + file_name = "deploy_lib.so" + else: + assert format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + lib.export_library(path_lib) + with open(temp.relpath("deploy_graph.json"), "w") as fo: + fo.write(graph) + with open(temp.relpath("deploy_param.params"), "wb") as fo: + fo.write(relay.save_param_dict(graph_params)) + + from tvm import rpc + server = rpc.Server("localhost", use_popen=True) + remote = rpc.connect(server.host, server.port) + remote.upload(path_lib) + loaded_json = open(temp.relpath("deploy_graph.json")).read() + loaded_lib = remote.load_module(path_lib) + loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = remote.cpu() + module = graph_runtime.create(loaded_json, loaded_lib, ctx) + module.load_params(loaded_params) + module.set_input("data", data) + module.run() + out = module.get_output(0).asnumpy() + + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + +def test_package_params(format=".so"): + mod, params = get_workload() + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build( + mod, "llvm", params=params, export_graph_module=True) + + from tvm.contrib import util + temp = util.tempdir() + if format == ".so": + file_name = "deploy_lib.so" + else: + assert format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib.export_library(path_lib, package_params=False) + loaded_lib = tvm.runtime.load_module(path_lib) + ctx = tvm.cpu(0) + gmod = loaded_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + load_params = gmod["load_params"] + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + loaded_params = bytearray(open(temp.relpath("deploy_0.params"), "rb").read()) + set_input("data", tvm.nd.array(data)) + load_params(loaded_params) + run() + out = get_output(0).asnumpy() + + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + +def test_multi_models_package_params(format=".so"): + resnet18_mod, resnet18_params = get_workload() + resnet50_mod, resnet50_params = get_workload(50) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build( + resnet18_mod, "llvm", params=resnet18_params, mod_name='resnet18', export_graph_module=True) + with relay.build_config(opt_level=3): + resnet50_gpu_lib = relay.build_module.build( + resnet50_mod, "cuda", params=resnet50_params, mod_name='resnet50', export_graph_module=True) + complied_graph_lib.import_module(resnet50_gpu_lib, "resnet50") + + from tvm.contrib import util + temp = util.tempdir() + if format == ".so": + file_name = "deploy_lib.so" + else: + assert format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib.export_library(path_lib, package_params=True) + loaded_lib = tvm.runtime.load_module(path_lib) + + # resnet18 + ctx = tvm.cpu(0) + gmod = loaded_lib['resnet18'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + load_params = gmod["load_params"] + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + # loaded_params = bytearray(open(temp.relpath("deploy_0.params"), "rb").read()) + set_input("data", tvm.nd.array(data)) + # load_params(loaded_params) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + print("CPU PASS") + + # resnet50 + ctx = tvm.gpu() + gmod = loaded_lib['resnet50'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + load_params = gmod["load_params"] + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + # loaded_params = bytearray(open(temp.relpath("deploy_1.params"), "rb").read()) + set_input("data", tvm.nd.array(data)) + # load_params(loaded_params) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data, num_layers=50), atol=1e-5) + if __name__ == "__main__": test_legacy_compatibility() test_cpu() @@ -370,12 +420,12 @@ def test_rpc_export(format=".so"): test_gpu_export(".so") test_gpu_export(".tar") test_rpc_export(".so") - # test_rpc_export(".tar") - # test_previous_cpu_export(".so") - # test_previous_cpu_export(".tar") - #test_previous_gpu_export(".so") - # test_previous_gpu_export(".tar") - # test_previous_rpc_export(".so") - # test_previous_rpc_export(".tar") - #test_previous_gpu_load() - #test_previous_cpu_load() \ No newline at end of file + test_rpc_export(".tar") + test_previous_cpu_export(".so") + test_previous_cpu_export(".tar") + test_previous_gpu_export(".so") + test_previous_gpu_export(".tar") + test_previous_rpc_export(".so") + test_previous_rpc_export(".tar") + test_package_params(".so") + # test_multi_models_package_params(".so") \ No newline at end of file From 52adf7958ab9aa081248144f5d01b17b767620f7 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Tue, 23 Jun 2020 00:02:35 +0800 Subject: [PATCH 06/29] [Complete all the functionality] Support multi models of package params --- src/runtime/module.cc | 16 +++++-- .../unittest/test_module_runtime_interface.py | 47 +++++++++---------- 2 files changed, 36 insertions(+), 27 deletions(-) diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 46ef6fab082b..769e9d9719f6 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -66,9 +66,19 @@ PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) PackedFunc pf = self->GetFunction(name, GetObjectPtr(this)); if (pf != nullptr) return pf; if (query_imports) { - for (Module& m : self->imports_) { - pf = m->GetFunction(name, m.data_); - if (pf != nullptr) return pf; + std::unordered_set visited{self}; + std::vector stack{self}; + while (!stack.empty()) { + ModuleNode* n = stack.back(); + stack.pop_back(); + for (Module& m : n->imports_) { + ModuleNode* next = m.operator->(); + if (visited.count(next)) continue; + pf = m->GetFunction(name, m.data_); + if (pf != nullptr) return pf; + visited.insert(next); + stack.push_back(next); + } } } return pf; diff --git a/tests/python/unittest/test_module_runtime_interface.py b/tests/python/unittest/test_module_runtime_interface.py index ea2ef1102386..79abda09905e 100644 --- a/tests/python/unittest/test_module_runtime_interface.py +++ b/tests/python/unittest/test_module_runtime_interface.py @@ -376,7 +376,7 @@ def test_multi_models_package_params(format=".so"): assert format == ".tar" file_name = "deploy_lib.tar" path_lib = temp.relpath(file_name) - complied_graph_lib.export_library(path_lib, package_params=True) + complied_graph_lib.export_library(path_lib, package_params=False) loaded_lib = tvm.runtime.load_module(path_lib) # resnet18 @@ -387,13 +387,12 @@ def test_multi_models_package_params(format=".so"): get_output = gmod["get_output"] load_params = gmod["load_params"] data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") - # loaded_params = bytearray(open(temp.relpath("deploy_0.params"), "rb").read()) + loaded_params = bytearray(open(temp.relpath("deploy_0.params"), "rb").read()) set_input("data", tvm.nd.array(data)) - # load_params(loaded_params) + load_params(loaded_params) run() out = get_output(0).asnumpy() tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - print("CPU PASS") # resnet50 ctx = tvm.gpu() @@ -403,29 +402,29 @@ def test_multi_models_package_params(format=".so"): get_output = gmod["get_output"] load_params = gmod["load_params"] data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") - # loaded_params = bytearray(open(temp.relpath("deploy_1.params"), "rb").read()) + loaded_params = bytearray(open(temp.relpath("deploy_1.params"), "rb").read()) set_input("data", tvm.nd.array(data)) - # load_params(loaded_params) + load_params(loaded_params) run() out = get_output(0).asnumpy() tvm.testing.assert_allclose(out, verify(data, num_layers=50), atol=1e-5) if __name__ == "__main__": - test_legacy_compatibility() - test_cpu() - test_gpu() - test_multi_models() - test_cpu_export(".so") - test_cpu_export(".tar") - test_gpu_export(".so") - test_gpu_export(".tar") - test_rpc_export(".so") - test_rpc_export(".tar") - test_previous_cpu_export(".so") - test_previous_cpu_export(".tar") - test_previous_gpu_export(".so") - test_previous_gpu_export(".tar") - test_previous_rpc_export(".so") - test_previous_rpc_export(".tar") - test_package_params(".so") - # test_multi_models_package_params(".so") \ No newline at end of file + # test_legacy_compatibility() + # test_cpu() + # test_gpu() + # test_multi_models() + # test_cpu_export(".so") + # test_cpu_export(".tar") + # test_gpu_export(".so") + # test_gpu_export(".tar") + # test_rpc_export(".so") + # test_rpc_export(".tar") + # test_previous_cpu_export(".so") + # test_previous_cpu_export(".tar") + # test_previous_gpu_export(".so") + # test_previous_gpu_export(".tar") + # test_previous_rpc_export(".so") + # test_previous_rpc_export(".tar") + # test_package_params(".so") + test_multi_models_package_params(".so") \ No newline at end of file From 081af5ffa8303238c43917c893a439245fff031b Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Tue, 23 Jun 2020 19:32:38 +0800 Subject: [PATCH 07/29] refactor graph runtime module list --- python/tvm/runtime/graph_runtime_factory.py | 6 +-- src/runtime/graph/graph_runtime.cc | 5 +- src/runtime/graph/graph_runtime.h | 38 +++------------ src/runtime/graph/graph_runtime_factory.cc | 47 ++++++++++--------- src/runtime/graph/graph_runtime_factory.h | 30 +++++++----- .../unittest/test_module_runtime_interface.py | 39 ++++++++------- 6 files changed, 74 insertions(+), 91 deletions(-) diff --git a/python/tvm/runtime/graph_runtime_factory.py b/python/tvm/runtime/graph_runtime_factory.py index b2568d8111d0..7925b2541c9b 100644 --- a/python/tvm/runtime/graph_runtime_factory.py +++ b/python/tvm/runtime/graph_runtime_factory.py @@ -101,17 +101,15 @@ def runtime_create(self, ctx): fcreate = get_global_func("tvm.graph_runtime_factory.runtime_create") return fcreate(self.selected_module, *device_type_id) - def import_module(self, mod, mod_name): + def import_module(self, mod): """Create the runtime using ctx Parameters ---------- mod : GraphRuntimeFactoryModule The graph runtime factory module we want to import - mod_name: str - The module name """ - return self._import_module(mod, mod_name) + return self._import_module(mod) def __getitem__(self, key='default'): """Get specific module diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index 4fa72986d2de..62386b049e08 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -65,10 +65,7 @@ void GraphRuntime::Run() { * executed on. */ void GraphRuntime::Init(const std::string& graph_json, tvm::runtime::Module module, - const std::vector& ctxs, - const std::unordered_map& params) { - graph_json_ = graph_json; - params_ = params; + const std::vector& ctxs) { std::istringstream is(graph_json); dmlc::JSONReader reader(&is); this->Load(&reader); diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index d08026eac855..952bdc86cf4c 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -100,8 +100,7 @@ class TVM_DLL GraphRuntime : public GraphRuntimeFactory { */ void Init(const std::string& graph_json, tvm::runtime::Module module, - const std::vector& ctxs, - const std::unordered_map& params = {}); + const std::vector& ctxs); /*! * \brief Get the input index given the name of input. @@ -175,34 +174,21 @@ class TVM_DLL GraphRuntime : public GraphRuntimeFactory { std::string GetNodeName(uint32_t nid) const { return nodes_[nid].name; } - /*! - * \brief Set graph json value. - * \param graph_json The graph json value we want to set. - */ - void SetGraphJson(const std::string& graph_json) { graph_json_ = graph_json; } - - /*! - * \brief Get the graph json. - * \return The graph json. - */ - std::string GetGraphJson() const { return graph_json_; } - /*! * \brief Set the graph params. * \param params The graph params value we want to set. */ void SetParams(const std::unordered_map& params) { - params_ = params; - + std::unordered_map value = params; // upload big arrays first to avoid memory issue in rpc mode std::vector keys; - for (const auto& p : params_) { + for (const auto& p : value) { keys.emplace_back(p.first); } std::sort(std::begin(keys), std::end(keys), - [this](const std::string& lhs, const std::string& rhs) -> bool { - auto lhs_shape = params_[lhs].Shape(); - auto rhs_shape = params_[rhs].Shape(); + [&](const std::string& lhs, const std::string& rhs) -> bool { + auto lhs_shape = value[lhs].Shape(); + auto rhs_shape = value[rhs].Shape(); auto lhs_prod = std::accumulate(std::begin(lhs_shape), std::end(lhs_shape), 1, std::multiplies()); auto rhs_prod = std::accumulate(std::begin(rhs_shape), std::end(rhs_shape), 1, @@ -213,17 +199,11 @@ class TVM_DLL GraphRuntime : public GraphRuntimeFactory { for (const auto& key : keys) { int in_idx = this->GetInputIndex(key); if (in_idx >= 0) { - this->SetInput(in_idx, const_cast(params_[key].operator->())); + this->SetInput(in_idx, const_cast(value[key].operator->())); } } } - /*! - * \brief Get the graph params. - * \return The graph params. - */ - std::unordered_map GetParams() const { return params_; } - protected: // Memory pool entry. struct PoolEntry { @@ -442,10 +422,6 @@ class TVM_DLL GraphRuntime : public GraphRuntimeFactory { std::vector outputs_; /*! \brief Additional graph attributes. */ GraphAttr attrs_; - /*! \brief The execution graph. */ - std::string graph_json_; - /*! \brief The params. */ - std::unordered_map params_; /*! \brief The code module that contains both host and device code. */ tvm::runtime::Module module_; /*! \brief Execution context of all devices including the host. */ diff --git a/src/runtime/graph/graph_runtime_factory.cc b/src/runtime/graph/graph_runtime_factory.cc index beb22a2a087a..f376fb114909 100644 --- a/src/runtime/graph/graph_runtime_factory.cc +++ b/src/runtime/graph/graph_runtime_factory.cc @@ -35,15 +35,20 @@ namespace runtime { void GraphRuntimeFactory::Init(const std::string& kind, const std::string& graph_json, - const std::unordered_map& params) { + const std::unordered_map& params, + const std::string& module_name) { kind_ = kind; graph_json_ = graph_json; params_ = params; + module_name_ = module_name; + graph_runtime_factory_module_list_.push_back(module_name_); } -void GraphRuntimeFactory::ImportModule(Module other, std::string module_name) { +void GraphRuntimeFactory::ImportModule(Module other) { this->Import(other); - module_names_.push_back(module_name); + auto module = other.as(); + CHECK(module) << "should only import graph runtiem factory module"; + graph_runtime_factory_module_list_.push_back(module->GetModuleName()); } PackedFunc GraphRuntimeFactory::GetFunction(const std::string& name, @@ -64,8 +69,8 @@ PackedFunc GraphRuntimeFactory::GetFunction(const std::string& name, }); } else if (name == "import_module") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size(), 2); - this->ImportModule(args[0], args[1]); + CHECK_EQ(args.size(), 1); + this->ImportModule(args[0]); }); } else if (name == "select_module") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -74,17 +79,17 @@ PackedFunc GraphRuntimeFactory::GetFunction(const std::string& name, }); } else if (name == "get_json") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->graph_json_; + *rv = this->GetJson(); }); } else if (name == "get_lib") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_GT(this->imports().size(), 0); - *rv = this->imports_[0]; + *rv = this->GetLib(); }); } else if (name == "get_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Map ret; - for (const auto& kv : this->params_) { + for (const auto& kv : this->GetParams()) { ret.Set(kv.first, kv.second); } *rv = ret; @@ -99,7 +104,7 @@ PackedFunc GraphRuntimeFactory::GetFunction(const std::string& name, } void GraphRuntimeFactory::SaveToBinary(dmlc::Stream* stream) { - stream->Write(module_names_); + stream->Write(graph_runtime_factory_module_list_); stream->Write(kind_); stream->Write(graph_json_); stream->Write(package_params_); @@ -134,27 +139,27 @@ Module GraphRuntimeFactory::RuntimeCreate(Module module, const std::vector(); - exec->Init(this->kind_, this->graph_json_, this->params_); - exec->ImportModule(this->imports_[0], *iter); + exec->Init(this->GetKind(), this->GetJson(), this->GetParams()); + exec->Import(this->GetLib()); return Module(exec); } else { - return this->imports_[std::distance(module_names_.begin(), iter)]; + return this->imports()[std::distance(graph_runtime_factory_module_list_.begin(), iter)]; } } Module GraphRuntimeFactoryModuleLoadBinary(void* strm) { dmlc::Stream* stream = static_cast(strm); - std::vector module_names; + std::vector graph_runtime_factory_module_list; std::string kind; std::string graph_json; bool package_params; std::unordered_map params; - CHECK(stream->Read(&module_names)); + CHECK(stream->Read(&graph_runtime_factory_module_list)); CHECK(stream->Read(&kind)); CHECK(stream->Read(&graph_json)); CHECK(stream->Read(&package_params)); @@ -172,7 +177,7 @@ Module GraphRuntimeFactoryModuleLoadBinary(void* strm) { } auto exec = make_object(); exec->Init(kind, graph_json, params); - exec->SetModuleNames(module_names); + exec->SetGraphRuntimeFactoryModuleList(graph_runtime_factory_module_list); return Module(exec); } @@ -205,8 +210,8 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.create") std::string name = args[i].operator String(); params[name] = args[i + 1].operator tvm::runtime::NDArray(); } - exec->Init(args[0], args[1], params); - exec->ImportModule(args[2], args[3]); + exec->Init(args[0], args[1], params, args[3]); + exec->Import(args[2]); *rv = Module(exec); }); diff --git a/src/runtime/graph/graph_runtime_factory.h b/src/runtime/graph/graph_runtime_factory.h index b142fcb217bd..8fc51432f44b 100644 --- a/src/runtime/graph/graph_runtime_factory.h +++ b/src/runtime/graph/graph_runtime_factory.h @@ -50,9 +50,11 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { void Init(const std::string& kind, const std::string& graph_json, - const std::unordered_map& params); + const std::unordered_map& params, + const std::string& module_name = "default"); + + void ImportModule(Module other); - void ImportModule(Module other, std::string module_name); /*! * \brief Get member function to front-end @@ -93,7 +95,7 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { */ Module SelectModule(const std::string& name); - inline std::string GetJson() const { + const std::string& GetJson() const { return graph_json_; } @@ -106,16 +108,21 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { return this->imports_[0]; } - inline std::string GetKind() const { + const std::string& GetKind() const { return kind_; } - inline std::vector GetModuleNames() const { - return module_names_; + const std::string& GetModuleName() const { + return module_name_; + } + + const std::vector& GetGraphRuntimeFactoryModuleList() const { + return graph_runtime_factory_module_list_; } - inline void SetModuleNames(const std::vector& module_names) { - module_names_ = module_names; + void SetGraphRuntimeFactoryModuleList( + const std::vector& graph_runtime_factory_module_list) { + graph_runtime_factory_module_list_ = graph_runtime_factory_module_list; } protected: @@ -125,11 +132,12 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { std::unordered_map params_; /*! \brief runtime kind */ std::string kind_; - /*! \brief module names list */ - std::vector module_names_; + /*! \brief module name */ + std::string module_name_; /*! \brief whether to package params */ bool package_params_ = true; - + /*! \brief graph runtime factory module lists */ + std::vector graph_runtime_factory_module_list_; }; } // namespace runtime diff --git a/tests/python/unittest/test_module_runtime_interface.py b/tests/python/unittest/test_module_runtime_interface.py index 79abda09905e..35478a0a8884 100644 --- a/tests/python/unittest/test_module_runtime_interface.py +++ b/tests/python/unittest/test_module_runtime_interface.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import copy import numpy as np from tvm import relay from tvm.relay import testing @@ -105,7 +104,7 @@ def test_multi_models(): with relay.build_config(opt_level=3): resnet50_gpu_lib = relay.build_module.build( resnet50_mod, "cuda", params=resnet50_params, mod_name='resnet50', export_graph_module=True) - complied_graph_lib.import_module(resnet50_gpu_lib, "resnet50") + complied_graph_lib.import_module(resnet50_gpu_lib) data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") # resnet18 cpu_ctx = tvm.cpu() @@ -366,7 +365,7 @@ def test_multi_models_package_params(format=".so"): with relay.build_config(opt_level=3): resnet50_gpu_lib = relay.build_module.build( resnet50_mod, "cuda", params=resnet50_params, mod_name='resnet50', export_graph_module=True) - complied_graph_lib.import_module(resnet50_gpu_lib, "resnet50") + complied_graph_lib.import_module(resnet50_gpu_lib) from tvm.contrib import util temp = util.tempdir() @@ -410,21 +409,21 @@ def test_multi_models_package_params(format=".so"): tvm.testing.assert_allclose(out, verify(data, num_layers=50), atol=1e-5) if __name__ == "__main__": - # test_legacy_compatibility() - # test_cpu() - # test_gpu() - # test_multi_models() - # test_cpu_export(".so") - # test_cpu_export(".tar") - # test_gpu_export(".so") - # test_gpu_export(".tar") - # test_rpc_export(".so") - # test_rpc_export(".tar") - # test_previous_cpu_export(".so") - # test_previous_cpu_export(".tar") - # test_previous_gpu_export(".so") - # test_previous_gpu_export(".tar") - # test_previous_rpc_export(".so") - # test_previous_rpc_export(".tar") - # test_package_params(".so") + test_legacy_compatibility() + test_cpu() + test_gpu() + test_multi_models() + test_cpu_export(".so") + test_cpu_export(".tar") + test_gpu_export(".so") + test_gpu_export(".tar") + test_rpc_export(".so") + test_rpc_export(".tar") + test_previous_cpu_export(".so") + test_previous_cpu_export(".tar") + test_previous_gpu_export(".so") + test_previous_gpu_export(".tar") + test_previous_rpc_export(".so") + test_previous_rpc_export(".tar") + test_package_params(".so") test_multi_models_package_params(".so") \ No newline at end of file From 4b34dc2bb8699e32b97488ec356dff39dc421f0b Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Tue, 23 Jun 2020 19:33:39 +0800 Subject: [PATCH 08/29] header reorder --- src/runtime/graph/graph_runtime.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index 62386b049e08..69a27bd836d5 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -20,7 +20,7 @@ /*! * \file graph_runtime.cc */ -//#include "graph_runtime.h" +#include "graph_runtime.h" #include #include @@ -36,7 +36,7 @@ #include #include #include -#include "./graph_runtime.h" + namespace tvm { namespace runtime { namespace details { From 1acf52dc726970665db6701a9044409b8dd8dce4 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Tue, 23 Jun 2020 19:35:32 +0800 Subject: [PATCH 09/29] graph runtime debug --- src/runtime/graph/debug/graph_runtime_debug.cc | 4 +++- src/runtime/graph/graph_runtime.cc | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/runtime/graph/debug/graph_runtime_debug.cc b/src/runtime/graph/debug/graph_runtime_debug.cc index 70c027d6d0fe..5439be9109f9 100644 --- a/src/runtime/graph/debug/graph_runtime_debug.cc +++ b/src/runtime/graph/debug/graph_runtime_debug.cc @@ -24,10 +24,12 @@ #include #include #include -#include + #include #include +#include "../graph_runtime.h" + namespace tvm { namespace runtime { diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index 69a27bd836d5..e984861769a0 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -28,6 +28,7 @@ #include #include #include + #include #include #include From b3f28735814dc96557aed00ef1bdc6ca98da6201 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Tue, 23 Jun 2020 19:42:22 +0800 Subject: [PATCH 10/29] function signature --- src/runtime/graph/graph_runtime_factory.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/runtime/graph/graph_runtime_factory.h b/src/runtime/graph/graph_runtime_factory.h index 8fc51432f44b..9c85ad878cb9 100644 --- a/src/runtime/graph/graph_runtime_factory.h +++ b/src/runtime/graph/graph_runtime_factory.h @@ -99,11 +99,11 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { return graph_json_; } - inline std::unordered_map GetParams() const { + std::unordered_map GetParams() const { return params_; } - inline Module GetLib() const { + Module GetLib() const { CHECK_GT(this->imports().size(), 0); return this->imports_[0]; } From 46ff4e11f568adaeebd072e25e0483befb3cce07 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Tue, 23 Jun 2020 19:56:44 +0800 Subject: [PATCH 11/29] rebase to master and solve lint / clang-format error --- src/runtime/graph/graph_runtime.h | 2 + src/runtime/graph/graph_runtime_factory.cc | 97 +++++++++++----------- src/runtime/graph/graph_runtime_factory.h | 51 +++++------- 3 files changed, 69 insertions(+), 81 deletions(-) diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index 952bdc86cf4c..f9543c100cc3 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -31,6 +31,8 @@ #include #include +#include +#include #include #include #include diff --git a/src/runtime/graph/graph_runtime_factory.cc b/src/runtime/graph/graph_runtime_factory.cc index f376fb114909..ad79e23e39a7 100644 --- a/src/runtime/graph/graph_runtime_factory.cc +++ b/src/runtime/graph/graph_runtime_factory.cc @@ -22,19 +22,21 @@ * \brief Graph runtime factory implementations */ -#include +#include "./graph_runtime_factory.h" + #include -#include +#include + #include #include -#include "./graph_runtime_factory.h" +#include + #include "./graph_runtime.h" namespace tvm { namespace runtime { -void GraphRuntimeFactory::Init(const std::string& kind, - const std::string& graph_json, +void GraphRuntimeFactory::Init(const std::string& kind, const std::string& graph_json, const std::unordered_map& params, const std::string& module_name) { kind_ = kind; @@ -47,12 +49,12 @@ void GraphRuntimeFactory::Init(const std::string& kind, void GraphRuntimeFactory::ImportModule(Module other) { this->Import(other); auto module = other.as(); - CHECK(module) << "should only import graph runtiem factory module"; + CHECK(module) << "should only import graph runtime factory module"; graph_runtime_factory_module_list_.push_back(module->GetModuleName()); } -PackedFunc GraphRuntimeFactory::GetFunction(const std::string& name, - const tvm::runtime::ObjectPtr& sptr_to_self) { +PackedFunc GraphRuntimeFactory::GetFunction( + const std::string& name, const tvm::runtime::ObjectPtr& sptr_to_self) { if (name == "runtime_create") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { std::vector contexts; @@ -78,9 +80,8 @@ PackedFunc GraphRuntimeFactory::GetFunction(const std::string& name, *rv = this->SelectModule(args[0]); }); } else if (name == "get_json") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetJson(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetJson(); }); } else if (name == "get_lib") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_GT(this->imports().size(), 0); @@ -95,9 +96,8 @@ PackedFunc GraphRuntimeFactory::GetFunction(const std::string& name, *rv = ret; }); } else if (name == "diable_package_params") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->package_params_ = false; - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->package_params_ = false; }); } else { return PackedFunc(); } @@ -125,7 +125,7 @@ void GraphRuntimeFactory::SaveToBinary(dmlc::Stream* stream) { } } -Module GraphRuntimeFactory::RuntimeCreate(Module module, const std::vector &ctxs) { +Module GraphRuntimeFactory::RuntimeCreate(Module module, const std::vector& ctxs) { auto factory_module = module.as(); CHECK(factory_module != nullptr); if (factory_module->GetKind() == "graph") { @@ -138,7 +138,7 @@ Module GraphRuntimeFactory::RuntimeCreate(Module module, const std::vector &ctxs) { +Module RuntimeCreate(Module module, const std::vector& ctxs) { auto mod = module.as(); CHECK(mod != nullptr); if (mod->GetKind() == "graph") { @@ -196,42 +196,41 @@ Module RuntimeCreate(Module module, const std::vector &ctxs) { return Module(); } -TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.create") -.set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK_GE(args.num_args, 4) << "The expected number of arguments for " - "graph_runtime_factory.create needs at least 3, " - "but it has " - << args.num_args; - auto exec = make_object(); - // The argument order is graph_runtime_kind, graph_json, module, module_name, params. - CHECK_EQ((args.size() - 4) % 2, 0); - std::unordered_map params; - for (size_t i = 4; i < static_cast(args.size()); i += 2) { - std::string name = args[i].operator String(); - params[name] = args[i + 1].operator tvm::runtime::NDArray(); - } - exec->Init(args[0], args[1], params, args[3]); - exec->Import(args[2]); - *rv = Module(exec); - }); - -TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.runtime_create") -.set_body([](TVMArgs args, TVMRetValue* rv) { - std::vector contexts; - TVMContext ctx; - // arg is: module, ctxs - CHECK_EQ((args.size() - 1) % 2, 0); - for (int i = 1; i < args.num_args; i += 2) { - int dev_type = args[i]; - ctx.device_type = static_cast(dev_type); - ctx.device_id = args[i + 1]; - contexts.push_back(ctx); +TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.create").set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK_GE(args.num_args, 4) << "The expected number of arguments for " + "graph_runtime_factory.create needs at least 3, " + "but it has " + << args.num_args; + auto exec = make_object(); + // The argument order is graph_runtime_kind, graph_json, module, module_name, params. + CHECK_EQ((args.size() - 4) % 2, 0); + std::unordered_map params; + for (size_t i = 4; i < static_cast(args.size()); i += 2) { + std::string name = args[i].operator String(); + params[name] = args[i + 1].operator tvm::runtime::NDArray(); } - *rv = RuntimeCreate(args[0], contexts); + exec->Init(args[0], args[1], params, args[3]); + exec->Import(args[2]); + *rv = Module(exec); }); +TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.runtime_create") + .set_body([](TVMArgs args, TVMRetValue* rv) { + std::vector contexts; + TVMContext ctx; + // arg is: module, ctxs + CHECK_EQ((args.size() - 1) % 2, 0); + for (int i = 1; i < args.num_args; i += 2) { + int dev_type = args[i]; + ctx.device_type = static_cast(dev_type); + ctx.device_id = args[i + 1]; + contexts.push_back(ctx); + } + *rv = RuntimeCreate(args[0], contexts); + }); + TVM_REGISTER_GLOBAL("runtime.module.loadbinary_GraphRuntimeFactory") -.set_body_typed(GraphRuntimeFactoryModuleLoadBinary); + .set_body_typed(GraphRuntimeFactoryModuleLoadBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/graph/graph_runtime_factory.h b/src/runtime/graph/graph_runtime_factory.h index 9c85ad878cb9..1f59eb55e282 100644 --- a/src/runtime/graph/graph_runtime_factory.h +++ b/src/runtime/graph/graph_runtime_factory.h @@ -22,55 +22,51 @@ * \brief Graph runtime factory creating graph runtime. */ -#ifndef TVM_RUNTIME_GRAPH_RUNTIME_FACTORY_H_ -#define TVM_RUNTIME_GRAPH_RUNTIME_FACTORY_H_ +#ifndef TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_FACTORY_H_ +#define TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_FACTORY_H_ #include -#include #include #include +#include + #include -#include #include - +#include namespace tvm { namespace runtime { class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { - public: - /*! * \brief Initialize the GraphRuntimeFactory with graph and context. * \param graph_json The execution graph. * \param params The params of graph. * \param kind The runtime kind to be created. */ - - void Init(const std::string& kind, - const std::string& graph_json, + void Init(const std::string& kind, const std::string& graph_json, const std::unordered_map& params, const std::string& module_name = "default"); + /*! + * \brief Import other GraphRuntimeFactory module. + * \param other The GraphRuntimeFactory module we want to import. + */ void ImportModule(Module other); - /*! * \brief Get member function to front-end * \param name The name of the function. * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self); + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); /*! * \return The type key of the executor. */ - const char* type_key() const override { - return "GraphRuntimeFactory"; - } + const char* type_key() const override { return "GraphRuntimeFactory"; } /*! * \brief Save the module to binary stream. @@ -78,7 +74,6 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { */ void SaveToBinary(dmlc::Stream* stream) override; - /*! * \brief Create a specific runtime module * \param module The module we will be used for creating runtime @@ -95,26 +90,18 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { */ Module SelectModule(const std::string& name); - const std::string& GetJson() const { - return graph_json_; - } + const std::string& GetJson() const { return graph_json_; } - std::unordered_map GetParams() const { - return params_; - } + std::unordered_map GetParams() const { return params_; } Module GetLib() const { CHECK_GT(this->imports().size(), 0); return this->imports_[0]; } - const std::string& GetKind() const { - return kind_; - } + const std::string& GetKind() const { return kind_; } - const std::string& GetModuleName() const { - return module_name_; - } + const std::string& GetModuleName() const { return module_name_; } const std::vector& GetGraphRuntimeFactoryModuleList() const { return graph_runtime_factory_module_list_; @@ -140,7 +127,7 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { std::vector graph_runtime_factory_module_list_; }; -} // namespace runtime -} // namespace tvm +} // namespace runtime +} // namespace tvm -#endif // TVM_RUNTIME_GRAPH_RUNTIME_FACTORY_H_ \ No newline at end of file +#endif // TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_FACTORY_H_ From ba5b0c50f54a2eae1cf826fd0b9a010432eb57c5 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Tue, 23 Jun 2020 20:12:34 +0800 Subject: [PATCH 12/29] remove export_graph_mod --- python/tvm/relay/build_module.py | 8 ++-- .../unittest/test_module_runtime_interface.py | 38 +++++++------------ 2 files changed, 17 insertions(+), 29 deletions(-) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 5d7996a38a9d..a0dfa0eb69de 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -182,7 +182,7 @@ def get_params(self): return ret -def build(mod, target=None, target_host=None, params=None, mod_name='default', export_graph_module=False): +def build(mod, target=None, target_host=None, params=None, mod_name='default'): """Helper function that builds a Relay function to run on TVM graph runtime. @@ -250,10 +250,8 @@ def build(mod, target=None, target_host=None, params=None, mod_name='default', e with tophub_context: bld_mod = BuildModule() graph_json, mod, params = bld_mod.build(mod, target, target_host, params) - if export_graph_module: - mod = _graph_runtime_factory.create("graph", graph_json, mod, params, mod_name) - return mod - return graph_json, mod, params + mod = _graph_runtime_factory.create("graph", graph_json, mod, params, mod_name) + return mod def optimize(mod, target=None, params=None): diff --git a/tests/python/unittest/test_module_runtime_interface.py b/tests/python/unittest/test_module_runtime_interface.py index 35478a0a8884..d9331821bd8a 100644 --- a/tests/python/unittest/test_module_runtime_interface.py +++ b/tests/python/unittest/test_module_runtime_interface.py @@ -42,8 +42,7 @@ def verify(data, num_layers=18): def test_legacy_compatibility(): mod, params = get_workload() with relay.build_config(opt_level=3): - graph, lib, graph_params = relay.build_module.build( - mod, "llvm", params=params, export_graph_module=True) + graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params) data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") ctx = tvm.cpu() module = graph_runtime.create(graph, lib, ctx) @@ -56,8 +55,7 @@ def test_legacy_compatibility(): def test_cpu(): mod, params = get_workload() with relay.build_config(opt_level=3): - complied_graph_lib = relay.build_module.build( - mod, "llvm", params=params, export_graph_module=True) + complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") # raw api ctx = tvm.cpu() @@ -81,8 +79,7 @@ def test_cpu(): def test_gpu(): mod, params = get_workload() with relay.build_config(opt_level=3): - complied_graph_lib = relay.build_module.build( - mod, "cuda", params=params, export_graph_module=True) + complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") ctx = tvm.gpu() gmod = complied_graph_lib['default'](ctx) @@ -100,10 +97,10 @@ def test_multi_models(): resnet50_mod, resnet50_params = get_workload(50) with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build( - resnet18_mod, "llvm", params=resnet18_params, mod_name='resnet18', export_graph_module=True) + resnet18_mod, "llvm", params=resnet18_params, mod_name='resnet18') with relay.build_config(opt_level=3): resnet50_gpu_lib = relay.build_module.build( - resnet50_mod, "cuda", params=resnet50_params, mod_name='resnet50', export_graph_module=True) + resnet50_mod, "cuda", params=resnet50_params, mod_name='resnet50') complied_graph_lib.import_module(resnet50_gpu_lib) data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") # resnet18 @@ -132,8 +129,7 @@ def test_multi_models(): def test_cpu_export(format=".so"): mod, params = get_workload() with relay.build_config(opt_level=3): - complied_graph_lib = relay.build_module.build( - mod, "llvm", params=params, export_graph_module=True) + complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) from tvm.contrib import util temp = util.tempdir() @@ -160,8 +156,7 @@ def test_cpu_export(format=".so"): def test_gpu_export(format=".so"): mod, params = get_workload() with relay.build_config(opt_level=3): - complied_graph_lib = relay.build_module.build( - mod, "cuda", params=params, export_graph_module=True) + complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) from tvm.contrib import util temp = util.tempdir() @@ -188,8 +183,7 @@ def test_gpu_export(format=".so"): def test_previous_cpu_export(format=".so"): mod, params = get_workload() with relay.build_config(opt_level=3): - graph, lib, graph_params = relay.build_module.build( - mod, "llvm", params=params, export_graph_module=True) + graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params) from tvm.contrib import util temp = util.tempdir() @@ -220,8 +214,7 @@ def test_previous_cpu_export(format=".so"): def test_previous_gpu_export(format=".so"): mod, params = get_workload() with relay.build_config(opt_level=3): - graph, lib, graph_params = relay.build_module.build( - mod, "cuda", params=params, export_graph_module=True) + graph, lib, graph_params = relay.build_module.build(mod, "cuda", params=params) from tvm.contrib import util temp = util.tempdir() @@ -252,8 +245,7 @@ def test_previous_gpu_export(format=".so"): def test_rpc_export(format=".so"): mod, params = get_workload() with relay.build_config(opt_level=3): - complied_graph_lib = relay.build_module.build( - mod, "llvm", params=params, export_graph_module=True) + complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) from tvm.contrib import util temp = util.tempdir() @@ -291,8 +283,7 @@ def test_rpc_export(format=".so"): def test_previous_rpc_export(format=".so"): mod, params = get_workload() with relay.build_config(opt_level=3): - graph, lib, graph_params = relay.build_module.build( - mod, "llvm", params=params, export_graph_module=True) + graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params) from tvm.contrib import util temp = util.tempdir() @@ -328,8 +319,7 @@ def test_previous_rpc_export(format=".so"): def test_package_params(format=".so"): mod, params = get_workload() with relay.build_config(opt_level=3): - complied_graph_lib = relay.build_module.build( - mod, "llvm", params=params, export_graph_module=True) + complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) from tvm.contrib import util temp = util.tempdir() @@ -361,10 +351,10 @@ def test_multi_models_package_params(format=".so"): resnet50_mod, resnet50_params = get_workload(50) with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build( - resnet18_mod, "llvm", params=resnet18_params, mod_name='resnet18', export_graph_module=True) + resnet18_mod, "llvm", params=resnet18_params, mod_name='resnet18') with relay.build_config(opt_level=3): resnet50_gpu_lib = relay.build_module.build( - resnet50_mod, "cuda", params=resnet50_params, mod_name='resnet50', export_graph_module=True) + resnet50_mod, "cuda", params=resnet50_params, mod_name='resnet50') complied_graph_lib.import_module(resnet50_gpu_lib) from tvm.contrib import util From bae954ca9a586e132e4a5d01b7949b6d635f93ef Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Mon, 6 Jul 2020 17:06:12 +0800 Subject: [PATCH 13/29] address comments --- python/tvm/runtime/module.py | 12 ++++++++---- src/runtime/graph/graph_runtime.h | 3 +-- src/runtime/graph/graph_runtime_factory.cc | 8 ++++++-- .../python/unittest/test_module_runtime_interface.py | 6 +++--- 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index d3afa18041f5..62eb02d59202 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -20,6 +20,7 @@ import ctypes import struct import os +import warnings from collections import namedtuple import tvm._ffi @@ -224,7 +225,7 @@ def evaluator(*args): raise NameError("time_evaluate is only supported when RPC is enabled") def _collect_modules(self, module_type_keys): - """Helper function to collect specifit modules, then return it.""" + """Helper function to collect specific modules, then return it.""" visited, stack, modules = set(), [], [] type_keys = module_type_keys if isinstance(module_type_keys, (list, tuple)) else [module_type_keys] # append root module @@ -240,7 +241,7 @@ def _collect_modules(self, module_type_keys): stack.append(m) return modules - def _dso_exportable(self): + def _dso_exportable_types(self): return ["llvm", "c"] def export_library(self, @@ -296,7 +297,10 @@ def export_library(self, for index, module in enumerate(graph_runtime_factory_modules): if not package_params: module.get_function("diable_package_params")() - path_params = os.path.join(os.path.dirname(file_name), "deploy_" + str(index) + ".params") + params_file_name = "deploy_" + module.get_function("get_module_name")() + ".params" + warnings.warn("Disabled package params, we will generate file " + params_file_name, + stacklevel=2) + path_params = os.path.join(os.path.dirname(file_name), params_file_name) from tvm import relay with open(path_params, "wb") as fo: graph_params = {} @@ -304,7 +308,7 @@ def export_library(self, graph_params[k] = v fo.write(relay.save_param_dict(graph_params)) - modules = self._collect_modules(self._dso_exportable()) + modules = self._collect_modules(self._dso_exportable_types()) temp = _util.tempdir() files = addons if addons else [] is_system_lib = False diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index f9543c100cc3..0438c5bda0d8 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -68,7 +68,7 @@ struct TVMOpParam { * This runtime can be acccesibly in various language via * TVM runtime PackedFunc API. */ -class TVM_DLL GraphRuntime : public GraphRuntimeFactory { +class TVM_DLL GraphRuntime : public ModuleNode { struct OpArgs { std::vector args; std::vector arg_values; @@ -98,7 +98,6 @@ class TVM_DLL GraphRuntime : public GraphRuntimeFactory { * processor. * \param ctxs The context of the host and devices where graph nodes will be * executed on. - * \param params The params of graph. */ void Init(const std::string& graph_json, tvm::runtime::Module module, diff --git a/src/runtime/graph/graph_runtime_factory.cc b/src/runtime/graph/graph_runtime_factory.cc index ad79e23e39a7..8c709bd59b95 100644 --- a/src/runtime/graph/graph_runtime_factory.cc +++ b/src/runtime/graph/graph_runtime_factory.cc @@ -84,7 +84,6 @@ PackedFunc GraphRuntimeFactory::GetFunction( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetJson(); }); } else if (name == "get_lib") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_GT(this->imports().size(), 0); *rv = this->GetLib(); }); } else if (name == "get_params") { @@ -98,6 +97,11 @@ PackedFunc GraphRuntimeFactory::GetFunction( } else if (name == "diable_package_params") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->package_params_ = false; }); + } else if (name == "get_module_name") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetModuleName(); + }); } else { return PackedFunc(); } @@ -198,7 +202,7 @@ Module RuntimeCreate(Module module, const std::vector& ctxs) { TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.create").set_body([](TVMArgs args, TVMRetValue* rv) { CHECK_GE(args.num_args, 4) << "The expected number of arguments for " - "graph_runtime_factory.create needs at least 3, " + "graph_runtime_factory.create needs at least 4, " "but it has " << args.num_args; auto exec = make_object(); diff --git a/tests/python/unittest/test_module_runtime_interface.py b/tests/python/unittest/test_module_runtime_interface.py index d9331821bd8a..cbd7e11c8fb6 100644 --- a/tests/python/unittest/test_module_runtime_interface.py +++ b/tests/python/unittest/test_module_runtime_interface.py @@ -338,7 +338,7 @@ def test_package_params(format=".so"): get_output = gmod["get_output"] load_params = gmod["load_params"] data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") - loaded_params = bytearray(open(temp.relpath("deploy_0.params"), "rb").read()) + loaded_params = bytearray(open(temp.relpath("deploy_default.params"), "rb").read()) set_input("data", tvm.nd.array(data)) load_params(loaded_params) run() @@ -376,7 +376,7 @@ def test_multi_models_package_params(format=".so"): get_output = gmod["get_output"] load_params = gmod["load_params"] data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") - loaded_params = bytearray(open(temp.relpath("deploy_0.params"), "rb").read()) + loaded_params = bytearray(open(temp.relpath("deploy_resnet18.params"), "rb").read()) set_input("data", tvm.nd.array(data)) load_params(loaded_params) run() @@ -391,7 +391,7 @@ def test_multi_models_package_params(format=".so"): get_output = gmod["get_output"] load_params = gmod["load_params"] data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") - loaded_params = bytearray(open(temp.relpath("deploy_1.params"), "rb").read()) + loaded_params = bytearray(open(temp.relpath("deploy_resnet50.params"), "rb").read()) set_input("data", tvm.nd.array(data)) load_params(loaded_params) run() From 1e69f4bcb6f475999e779c66eae9e627c9693d8f Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Wed, 8 Jul 2020 21:01:21 +0800 Subject: [PATCH 14/29] refactor --- python/tvm/contrib/graph_runtime.py | 9 - python/tvm/relay/build_module.py | 2 +- python/tvm/rpc/client.py | 7 +- python/tvm/runtime/graph_runtime_factory.py | 50 +----- python/tvm/runtime/module.py | 51 ++---- src/runtime/graph/graph_runtime.h | 34 ---- src/runtime/graph/graph_runtime_factory.cc | 162 +++++++++--------- src/runtime/graph/graph_runtime_factory.h | 62 ++++--- .../unittest/test_module_runtime_interface.py | 113 ++++++------ 9 files changed, 196 insertions(+), 294 deletions(-) diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 3bf09b79306b..9b714a84b541 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -63,15 +63,6 @@ def create(graph_json_str, libmod, ctx): return GraphModule(fcreate(graph_json_str, libmod, *device_type_id)) -# TODO (FrozenGene): rename -def create4unified(libmod, ctx): - ctx, num_rpc_ctx, device_type_id = get_device_ctx(libmod, ctx) - if num_rpc_ctx == len(ctx): - fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime_factory.runtime_create") - else: - fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_factory.runtime_create") - - return GraphModule(fcreate(libmod, *device_type_id)) def get_device_ctx(libmod, ctx): """Parse and validate all the device context(s). diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index a0dfa0eb69de..e5bab4e18e62 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -250,7 +250,7 @@ def build(mod, target=None, target_host=None, params=None, mod_name='default'): with tophub_context: bld_mod = BuildModule() graph_json, mod, params = bld_mod.build(mod, target, target_host, params) - mod = _graph_runtime_factory.create("graph", graph_json, mod, params, mod_name) + mod = _graph_runtime_factory.create(graph_json, mod, mod_name, params) return mod diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index 971ce4d06019..2f96c9b62976 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -160,12 +160,7 @@ def load_module(self, path): m : Module The remote module containing remote function. """ - module = _ffi_api.LoadRemoteModule(self._sess, path) - type_key = self.get_function("runtime.ModuleGetTypeKey")(module) - if type_key == "GraphRuntimeFactory": - from tvm.runtime.graph_runtime_factory import GraphRuntimeFactoryModule - return GraphRuntimeFactoryModule(module) - return module + return _ffi_api.LoadRemoteModule(self._sess, path) def cpu(self, dev_id=0): """Construct CPU device.""" diff --git a/python/tvm/runtime/graph_runtime_factory.py b/python/tvm/runtime/graph_runtime_factory.py index 7925b2541c9b..ac5063fb0164 100644 --- a/python/tvm/runtime/graph_runtime_factory.py +++ b/python/tvm/runtime/graph_runtime_factory.py @@ -27,18 +27,21 @@ from . import ndarray -def create(graph_runtime_kind, graph_json_str, libmod, params, module_name='default'): +def create(graph_json_str, libmod, libmod_name, params): """Create a runtime executor module given a graph and module. Parameters ---------- - graph_runtime_kind: str - The kind of graph runtime. Like graphruntime, vm and so on. graph_json_str : str or graph class The graph to be deployed in json format output by nnvm graph. The graph can only contain one operator(tvm_op) that points to the name of PackedFunc in the libmod. libmod : tvm.Module The module of the corresponding function + libmod_name: str + The name of module + params : dict of str to NDArray + The parameters of module + Returns ------- graph_module : GraphModule @@ -54,7 +57,7 @@ def create(graph_runtime_kind, graph_json_str, libmod, params, module_name='defa for k, v in params.items(): args.append(k) args.append(ndarray.array(v)) - return GraphRuntimeFactoryModule(fcreate(graph_runtime_kind, graph_json_str, libmod, module_name, *args)) + return GraphRuntimeFactoryModule(fcreate(graph_json_str, libmod, libmod_name, *args)) class GraphRuntimeFactoryModule(Module): @@ -75,9 +78,6 @@ class GraphRuntimeFactoryModule(Module): def __init__(self, module): self.module = module - self._select_module = module["select_module"] - self._import_module = module["import_module"] - self.selected_module = None self.graph_json = None self.lib = None self.params = {} @@ -87,42 +87,6 @@ def __init__(self, module): def __del__(self): pass - def runtime_create(self, ctx): - """Create the runtime using ctx - - Parameters - ---------- - ctx : TVMContext or list of TVMContext - """ - ctx, num_rpc_ctx, device_type_id = get_device_ctx(self.selected_module, ctx) - if num_rpc_ctx == len(ctx): - fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime_factory.runtime_create") - else: - fcreate = get_global_func("tvm.graph_runtime_factory.runtime_create") - return fcreate(self.selected_module, *device_type_id) - - def import_module(self, mod): - """Create the runtime using ctx - - Parameters - ---------- - mod : GraphRuntimeFactoryModule - The graph runtime factory module we want to import - """ - return self._import_module(mod) - - def __getitem__(self, key='default'): - """Get specific module - - Parameters - ---------- - key : str - The key of module. - """ - self.selected_module = self._select_module(key) - self.selected_module._entry = self.runtime_create - return self.selected_module - def __iter__(self): warnings.warn( "legacy graph runtime behaviour of producing json / lib / params will be removed in the next release ", diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 62eb02d59202..3cdb28f8c496 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -19,8 +19,6 @@ """Runtime Module namespace.""" import ctypes import struct -import os -import warnings from collections import namedtuple import tvm._ffi @@ -224,31 +222,29 @@ def evaluator(*args): except NameError: raise NameError("time_evaluate is only supported when RPC is enabled") - def _collect_modules(self, module_type_keys): - """Helper function to collect specific modules, then return it.""" - visited, stack, modules = set(), [], [] - type_keys = module_type_keys if isinstance(module_type_keys, (list, tuple)) else [module_type_keys] + def _collect_dso_modules(self): + """Helper function to collect dso modules, then return it.""" + visited, stack, dso_modules = set(), [], [] # append root module visited.add(self) stack.append(self) while stack: module = stack.pop() - if module.type_key in type_keys: - modules.append(module) + if module._dso_exportable(): + dso_modules.append(module) for m in module.imported_modules: if m not in visited: visited.add(m) stack.append(m) - return modules + return dso_modules - def _dso_exportable_types(self): - return ["llvm", "c"] + def _dso_exportable(self): + return self.type_key == "llvm" or self.type_key == "c" def export_library(self, file_name, fcompile=None, addons=None, - package_params=True, **kwargs): """Export the module and its imported device code one library. @@ -265,13 +261,6 @@ def export_library(self, If fcompile has attribute object_format, will compile host library to that format. Otherwise, will use default format "o". - addons : str, optional - Extra files needed to be passed to compiler. - - package_params: bool, optional. - Whether we will package params into library. - The default value is True. - kwargs : dict, optional Additional arguments passed to fcompile """ @@ -293,22 +282,7 @@ def export_library(self, self.save(file_name) return - graph_runtime_factory_modules = self._collect_modules("GraphRuntimeFactory") - for index, module in enumerate(graph_runtime_factory_modules): - if not package_params: - module.get_function("diable_package_params")() - params_file_name = "deploy_" + module.get_function("get_module_name")() + ".params" - warnings.warn("Disabled package params, we will generate file " + params_file_name, - stacklevel=2) - path_params = os.path.join(os.path.dirname(file_name), params_file_name) - from tvm import relay - with open(path_params, "wb") as fo: - graph_params = {} - for k, v in module.get_function("get_params")().items(): - graph_params[k] = v - fo.write(relay.save_param_dict(graph_params)) - - modules = self._collect_modules(self._dso_exportable_types()) + modules = self._collect_dso_modules() temp = _util.tempdir() files = addons if addons else [] is_system_lib = False @@ -428,12 +402,7 @@ def load_module(path, fmt=""): elif path.endswith(".obj"): fmt = "micro_dev" # Redirect to the load API - module = _ffi_api.ModuleLoadFromFile(path, fmt) - if module.type_key == 'GraphRuntimeFactory': - from tvm.runtime.graph_runtime_factory import GraphRuntimeFactoryModule - return GraphRuntimeFactoryModule(module) - return module - + return _ffi_api.ModuleLoadFromFile(path, fmt) def enabled(target): diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index 0438c5bda0d8..d0c982281b34 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -31,16 +31,12 @@ #include #include -#include -#include #include -#include #include #include #include #include -#include "./graph_runtime_factory.h" namespace tvm { namespace runtime { @@ -175,36 +171,6 @@ class TVM_DLL GraphRuntime : public ModuleNode { std::string GetNodeName(uint32_t nid) const { return nodes_[nid].name; } - /*! - * \brief Set the graph params. - * \param params The graph params value we want to set. - */ - void SetParams(const std::unordered_map& params) { - std::unordered_map value = params; - // upload big arrays first to avoid memory issue in rpc mode - std::vector keys; - for (const auto& p : value) { - keys.emplace_back(p.first); - } - std::sort(std::begin(keys), std::end(keys), - [&](const std::string& lhs, const std::string& rhs) -> bool { - auto lhs_shape = value[lhs].Shape(); - auto rhs_shape = value[rhs].Shape(); - auto lhs_prod = std::accumulate(std::begin(lhs_shape), std::end(lhs_shape), 1, - std::multiplies()); - auto rhs_prod = std::accumulate(std::begin(rhs_shape), std::end(rhs_shape), 1, - std::multiplies()); - return lhs_prod > rhs_prod; - }); - - for (const auto& key : keys) { - int in_idx = this->GetInputIndex(key); - if (in_idx >= 0) { - this->SetInput(in_idx, const_cast(value[key].operator->())); - } - } - } - protected: // Memory pool entry. struct PoolEntry { diff --git a/src/runtime/graph/graph_runtime_factory.cc b/src/runtime/graph/graph_runtime_factory.cc index 8c709bd59b95..c00c7c0e5901 100644 --- a/src/runtime/graph/graph_runtime_factory.cc +++ b/src/runtime/graph/graph_runtime_factory.cc @@ -25,9 +25,9 @@ #include "./graph_runtime_factory.h" #include +#include #include -#include #include #include @@ -36,21 +36,12 @@ namespace tvm { namespace runtime { -void GraphRuntimeFactory::Init(const std::string& kind, const std::string& graph_json, +void GraphRuntimeFactory::Init(const std::string& graph_json, const std::unordered_map& params, const std::string& module_name) { - kind_ = kind; graph_json_ = graph_json; params_ = params; module_name_ = module_name; - graph_runtime_factory_module_list_.push_back(module_name_); -} - -void GraphRuntimeFactory::ImportModule(Module other) { - this->Import(other); - auto module = other.as(); - CHECK(module) << "should only import graph runtime factory module"; - graph_runtime_factory_module_list_.push_back(module->GetModuleName()); } PackedFunc GraphRuntimeFactory::GetFunction( @@ -69,11 +60,6 @@ PackedFunc GraphRuntimeFactory::GetFunction( } *rv = this->RuntimeCreate(args[0], contexts); }); - } else if (name == "import_module") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size(), 1); - this->ImportModule(args[0]); - }); } else if (name == "select_module") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.size(), 1); @@ -102,14 +88,32 @@ PackedFunc GraphRuntimeFactory::GetFunction( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetModuleName(); }); + } else if (name == module_name_) { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + auto module = this->SelectModule(module_name_); + std::vector contexts; + for (int i = 0; i < args.num_args; ++i) { + contexts.emplace_back(args[i].operator TVMContext()); + } + *rv = this->RuntimeCreate(module, contexts); + }); + } else if (name == "debug_create") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_GE(args.size(), 2); + std::string module_name = args[0].operator String(); + auto module = this->SelectModule(module_name); + std::vector contexts; + for (int i = 1; i < args.num_args; ++i) { + contexts.emplace_back(args[i].operator TVMContext()); + } + *rv = this->DebugRuntimeCreate(module, contexts); + }); } else { - return PackedFunc(); + return PackedFunc(); } } void GraphRuntimeFactory::SaveToBinary(dmlc::Stream* stream) { - stream->Write(graph_runtime_factory_module_list_); - stream->Write(kind_); stream->Write(graph_json_); stream->Write(package_params_); if (package_params_) { @@ -132,39 +136,70 @@ void GraphRuntimeFactory::SaveToBinary(dmlc::Stream* stream) { Module GraphRuntimeFactory::RuntimeCreate(Module module, const std::vector& ctxs) { auto factory_module = module.as(); CHECK(factory_module != nullptr); - if (factory_module->GetKind() == "graph") { - auto exec = make_object(); - exec->Init(factory_module->GetJson(), factory_module->GetLib(), ctxs); - exec->SetParams(factory_module->GetParams()); - return Module(exec); + auto exec = make_object(); + exec->Init(factory_module->GetJson(), factory_module->GetLib(), ctxs); + // set params + auto params = factory_module->GetParams(); + auto keys = factory_module->GetSorterParamKeys(params); + for (const auto& key : keys) { + int in_idx = exec->GetInputIndex(key); + if (in_idx >= 0) { + exec->SetInput(in_idx, const_cast(params[key].operator->())); + } } + return Module(exec); +} - return Module(); +Module GraphRuntimeFactory::DebugRuntimeCreate(Module module, const std::vector& ctxs) { + auto factory_module = module.as(); + CHECK(factory_module != nullptr); + const PackedFunc* pf = tvm::runtime::Registry::Get("tvm.graph_runtime_debug.create"); + CHECK(pf != nullptr) << "Cannot find function tvm.graph_runtime_debug.create in registry. " + "Do you enable debug graph runtime build?"; + // Debug runtime will call GetAllContexs, so we unpack the ctxs. + std::vector unpacked_ctxs; + for(const auto& ctx: ctxs) { + unpacked_ctxs.emplace_back(ctx.device_type); + unpacked_ctxs.emplace_back(ctx.device_id); + } + size_t args_size = unpacked_ctxs.size() + 2; + std::vector values(args_size); + std::vector codes(args_size); + runtime::TVMArgsSetter setter(values.data(), codes.data()); + setter(0, factory_module->GetJson()); + setter(1, factory_module->GetLib()); + for (int i = 0; i < unpacked_ctxs.size(); ++i) { + setter(i + 2, unpacked_ctxs[i]); + } + TVMRetValue rv; + pf->CallPacked(TVMArgs(values.data(), codes.data(), args_size), &rv); + Module mod = rv.operator Module(); + // debug graph runtime is one child class of graph runtime. + GraphRuntime* exec = const_cast(mod.as()); + auto params = factory_module->GetParams(); + auto keys = factory_module->GetSorterParamKeys(params); + for (const auto& key : keys) { + int in_idx = exec->GetInputIndex(key); + if (in_idx >= 0) { + exec->SetInput(in_idx, const_cast(params[key].operator->())); + } + } + return mod; } Module GraphRuntimeFactory::SelectModule(const std::string& name) { - auto iter = std::find(graph_runtime_factory_module_list_.begin(), - graph_runtime_factory_module_list_.end(), name); - CHECK(iter != graph_runtime_factory_module_list_.end()); - if (iter == graph_runtime_factory_module_list_.begin()) { - auto exec = make_object(); - exec->Init(this->GetKind(), this->GetJson(), this->GetParams()); - exec->Import(this->GetLib()); - return Module(exec); - } else { - return this->imports()[std::distance(graph_runtime_factory_module_list_.begin(), iter)]; - } + CHECK(name == module_name_) << "Currently we only support single model for now."; + auto exec = make_object(); + exec->Init(this->GetJson(), this->GetParams()); + exec->Import(this->GetLib()); + return Module(exec); } Module GraphRuntimeFactoryModuleLoadBinary(void* strm) { dmlc::Stream* stream = static_cast(strm); - std::vector graph_runtime_factory_module_list; - std::string kind; std::string graph_json; bool package_params; std::unordered_map params; - CHECK(stream->Read(&graph_runtime_factory_module_list)); - CHECK(stream->Read(&kind)); CHECK(stream->Read(&graph_json)); CHECK(stream->Read(&package_params)); if (package_params) { @@ -180,59 +215,28 @@ Module GraphRuntimeFactoryModuleLoadBinary(void* strm) { } } auto exec = make_object(); - exec->Init(kind, graph_json, params); - exec->SetGraphRuntimeFactoryModuleList(graph_runtime_factory_module_list); + exec->Init(graph_json, params); return Module(exec); } -Module RuntimeCreate(Module module, const std::vector& ctxs) { - auto mod = module.as(); - CHECK(mod != nullptr); - if (mod->GetKind() == "graph") { - auto exec = make_object(); - exec->Init(mod->GetJson(), mod->GetLib(), ctxs); - exec->SetParams(mod->GetParams()); - return Module(exec); - } else { - LOG(ERROR) << "Doesn't support graph kind of " << mod->GetKind(); - } - - return Module(); -} - TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.create").set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK_GE(args.num_args, 4) << "The expected number of arguments for " - "graph_runtime_factory.create needs at least 4, " + CHECK_GE(args.num_args, 3) << "The expected number of arguments for " + "graph_runtime_factory.create needs at least 3, " "but it has " << args.num_args; auto exec = make_object(); - // The argument order is graph_runtime_kind, graph_json, module, module_name, params. - CHECK_EQ((args.size() - 4) % 2, 0); + // The argument order is graph_json, module, module_name, params. + CHECK_EQ((args.size() - 3) % 2, 0); std::unordered_map params; - for (size_t i = 4; i < static_cast(args.size()); i += 2) { + for (size_t i = 3; i < static_cast(args.size()); i += 2) { std::string name = args[i].operator String(); params[name] = args[i + 1].operator tvm::runtime::NDArray(); } - exec->Init(args[0], args[1], params, args[3]); - exec->Import(args[2]); + exec->Init(args[0], params, args[2]); + exec->Import(args[1]); *rv = Module(exec); }); -TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.runtime_create") - .set_body([](TVMArgs args, TVMRetValue* rv) { - std::vector contexts; - TVMContext ctx; - // arg is: module, ctxs - CHECK_EQ((args.size() - 1) % 2, 0); - for (int i = 1; i < args.num_args; i += 2) { - int dev_type = args[i]; - ctx.device_type = static_cast(dev_type); - ctx.device_id = args[i + 1]; - contexts.push_back(ctx); - } - *rv = RuntimeCreate(args[0], contexts); - }); - TVM_REGISTER_GLOBAL("runtime.module.loadbinary_GraphRuntimeFactory") .set_body_typed(GraphRuntimeFactoryModuleLoadBinary); diff --git a/src/runtime/graph/graph_runtime_factory.h b/src/runtime/graph/graph_runtime_factory.h index 1f59eb55e282..474a24078365 100644 --- a/src/runtime/graph/graph_runtime_factory.h +++ b/src/runtime/graph/graph_runtime_factory.h @@ -30,6 +30,8 @@ #include #include +#include +#include #include #include #include @@ -43,18 +45,12 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { * \brief Initialize the GraphRuntimeFactory with graph and context. * \param graph_json The execution graph. * \param params The params of graph. - * \param kind The runtime kind to be created. + * \param module_name The module name of graph. */ - void Init(const std::string& kind, const std::string& graph_json, + void Init(const std::string& graph_json, const std::unordered_map& params, const std::string& module_name = "default"); - /*! - * \brief Import other GraphRuntimeFactory module. - * \param other The GraphRuntimeFactory module we want to import. - */ - void ImportModule(Module other); - /*! * \brief Get member function to front-end * \param name The name of the function. @@ -83,6 +79,15 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { */ Module RuntimeCreate(Module module, const std::vector& ctxs); + /*! + * \brief Create a specific debug runtime module + * \param module The module we will be used for creating runtime + * \param ctxs The context of the host and devices where graph nodes will be + * executed on. + * \return created debug runtime module + */ + Module DebugRuntimeCreate(Module module, const std::vector& ctxs); + /*! * \brief Select the specific module * \param name The name of the module @@ -94,37 +99,48 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { std::unordered_map GetParams() const { return params_; } + /*! + * \brief Get sorted keys of params. + * \param params The graph params value we want to sort. + * \return The sorted keys of params + */ + std::vector GetSorterParamKeys(const std::unordered_map& params) const { + std::unordered_map value = params; + // upload big arrays first to avoid memory issue in rpc mode + std::vector keys; + for (const auto& p : value) { + keys.emplace_back(p.first); + } + std::sort(std::begin(keys), std::end(keys), + [&](const std::string& lhs, const std::string& rhs) -> bool { + auto lhs_shape = value[lhs].Shape(); + auto rhs_shape = value[rhs].Shape(); + auto lhs_prod = std::accumulate(std::begin(lhs_shape), std::end(lhs_shape), 1, + std::multiplies()); + auto rhs_prod = std::accumulate(std::begin(rhs_shape), std::end(rhs_shape), 1, + std::multiplies()); + return lhs_prod > rhs_prod; + }); + + return keys; + } + Module GetLib() const { CHECK_GT(this->imports().size(), 0); return this->imports_[0]; } - const std::string& GetKind() const { return kind_; } - const std::string& GetModuleName() const { return module_name_; } - const std::vector& GetGraphRuntimeFactoryModuleList() const { - return graph_runtime_factory_module_list_; - } - - void SetGraphRuntimeFactoryModuleList( - const std::vector& graph_runtime_factory_module_list) { - graph_runtime_factory_module_list_ = graph_runtime_factory_module_list; - } - protected: /*! \brief The execution graph. */ std::string graph_json_; /*! \brief The params. */ std::unordered_map params_; - /*! \brief runtime kind */ - std::string kind_; /*! \brief module name */ std::string module_name_; /*! \brief whether to package params */ bool package_params_ = true; - /*! \brief graph runtime factory module lists */ - std::vector graph_runtime_factory_module_list_; }; } // namespace runtime diff --git a/tests/python/unittest/test_module_runtime_interface.py b/tests/python/unittest/test_module_runtime_interface.py index cbd7e11c8fb6..50b4f084f6ca 100644 --- a/tests/python/unittest/test_module_runtime_interface.py +++ b/tests/python/unittest/test_module_runtime_interface.py @@ -20,6 +20,7 @@ import tvm from tvm.contrib import graph_runtime from tvm.runtime import graph_runtime_factory +from tvm.contrib.debugger import debug_runtime def get_workload(num_layers=18): mod, params = relay.testing.resnet.get_workload(num_layers=num_layers) @@ -66,10 +67,10 @@ def test_cpu(): set_input("data", tvm.nd.array(data)) run() out = get_output(0).asnumpy() - - # graph runtime tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - gmod = graph_runtime.create4unified(complied_graph_lib['default'], ctx) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(complied_graph_lib['default'](ctx)) gmod.set_input("data", data) gmod.run() out = gmod.get_output(0).asnumpy() @@ -92,40 +93,6 @@ def test_gpu(): tvm.testing.assert_allclose(out, verify(data), atol=1e-5) -def test_multi_models(): - resnet18_mod, resnet18_params = get_workload() - resnet50_mod, resnet50_params = get_workload(50) - with relay.build_config(opt_level=3): - complied_graph_lib = relay.build_module.build( - resnet18_mod, "llvm", params=resnet18_params, mod_name='resnet18') - with relay.build_config(opt_level=3): - resnet50_gpu_lib = relay.build_module.build( - resnet50_mod, "cuda", params=resnet50_params, mod_name='resnet50') - complied_graph_lib.import_module(resnet50_gpu_lib) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") - # resnet18 - cpu_ctx = tvm.cpu() - gmod = complied_graph_lib['resnet18'](cpu_ctx) - set_input = gmod["set_input"] - get_input = gmod["get_input"] - run = gmod["run"] - get_output = gmod["get_output"] - set_input("data", tvm.nd.array(data)) - run() - out = get_output(0).asnumpy() - tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - - # resnet50 - gpu_ctx = tvm.gpu() - gmod = complied_graph_lib['resnet50'](gpu_ctx) - set_input = gmod["set_input"] - run = gmod["run"] - get_output = gmod["get_output"] - set_input("data", tvm.nd.array(data)) - run() - out = get_output(0).asnumpy() - tvm.testing.assert_allclose(out, verify(data, 50), atol=1e-5) - def test_cpu_export(format=".so"): mod, params = get_workload() with relay.build_config(opt_level=3): @@ -245,7 +212,7 @@ def test_previous_gpu_export(format=".so"): def test_rpc_export(format=".so"): mod, params = get_workload() with relay.build_config(opt_level=3): - complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) + complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) from tvm.contrib import util temp = util.tempdir() @@ -263,7 +230,7 @@ def test_rpc_export(format=".so"): remote.upload(path_lib) loaded_lib = remote.load_module(path_lib) data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") - ctx = remote.cpu() + ctx = remote.gpu() gmod = loaded_lib['default'](ctx) set_input = gmod["set_input"] run = gmod["run"] @@ -273,7 +240,7 @@ def test_rpc_export(format=".so"): out = get_output(0).asnumpy() tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - gmod = graph_runtime.create4unified(loaded_lib['default'], ctx) + gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx)) gmod.set_input("data", data) gmod.run() out = gmod.get_output(0).asnumpy() @@ -398,22 +365,52 @@ def test_multi_models_package_params(format=".so"): out = get_output(0).asnumpy() tvm.testing.assert_allclose(out, verify(data, num_layers=50), atol=1e-5) + +def test_debug_graph_runtime(): + mod, params = get_workload() + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + # raw api + ctx = tvm.cpu() + # gmod = complied_graph_lib['debug_create']('default', ctx) + # set_input = gmod["set_input"] + # run = gmod["run"] + # get_output = gmod["get_output"] + # set_input("data", tvm.nd.array(data)) + # run() + # out = get_output(0).asnumpy() + # tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + debug_g_mod = debug_runtime.GraphModuleDebug(complied_graph_lib['debug_create']('default', ctx), [ctx], complied_graph_lib['get_json'](), None) + # debug_g_mod = debug_runtime.create(complied_graph_lib['get_json'](), complied_graph_lib['get_lib'](), ctx) + debug_g_mod.set_input("data", data) + debug_g_mod.run() + out = debug_g_mod.get_output(0).asnumpy() + # # graph runtime wrapper + # tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + #gmod = graph_runtime.GraphModule(complied_graph_lib['default'](ctx)) + # gmod.set_input("data", data) + # gmod.run() + # out = gmod.get_output(0).asnumpy() + # + # tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + if __name__ == "__main__": - test_legacy_compatibility() - test_cpu() - test_gpu() - test_multi_models() - test_cpu_export(".so") - test_cpu_export(".tar") - test_gpu_export(".so") - test_gpu_export(".tar") - test_rpc_export(".so") - test_rpc_export(".tar") - test_previous_cpu_export(".so") - test_previous_cpu_export(".tar") - test_previous_gpu_export(".so") - test_previous_gpu_export(".tar") - test_previous_rpc_export(".so") - test_previous_rpc_export(".tar") - test_package_params(".so") - test_multi_models_package_params(".so") \ No newline at end of file + # test_legacy_compatibility() + # test_cpu() + # test_gpu() + # test_cpu_export(".so") + # test_cpu_export(".tar") + # test_gpu_export(".so") + # test_gpu_export(".tar") + # test_rpc_export(".so") + # test_rpc_export(".tar") + test_debug_graph_runtime() + # test_previous_cpu_export(".so") + # test_previous_cpu_export(".tar") + # test_previous_gpu_export(".so") + # test_previous_gpu_export(".tar") + # test_previous_rpc_export(".so") + # test_previous_rpc_export(".tar") + # test_package_params(".so") + # test_multi_models_package_params(".so") \ No newline at end of file From ef18312f7ca3803e719226867c571eb3be7f6e8c Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Thu, 9 Jul 2020 13:26:09 +0800 Subject: [PATCH 15/29] refactor --- python/tvm/relay/build_module.py | 3 + python/tvm/runtime/graph_runtime_factory.py | 10 +- src/runtime/graph/graph_runtime_factory.cc | 124 ++--- src/runtime/graph/graph_runtime_factory.h | 26 +- .../unittest/test_module_runtime_interface.py | 416 -------------- .../test_runtime_module_based_interface.py | 510 ++++++++++++++++++ 6 files changed, 573 insertions(+), 516 deletions(-) delete mode 100644 tests/python/unittest/test_module_runtime_interface.py create mode 100644 tests/python/unittest/test_runtime_module_based_interface.py diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index e5bab4e18e62..f0c1a2cbae2b 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -209,6 +209,9 @@ def build(mod, target=None, target_host=None, params=None, mod_name='default'): Input parameters to the graph that do not change during inference time. Used for constant folding. + mod_name: Optional[str] + The module name we will build + Returns ------- graph_json : str diff --git a/python/tvm/runtime/graph_runtime_factory.py b/python/tvm/runtime/graph_runtime_factory.py index ac5063fb0164..4aa3c7187eb1 100644 --- a/python/tvm/runtime/graph_runtime_factory.py +++ b/python/tvm/runtime/graph_runtime_factory.py @@ -15,14 +15,9 @@ # specific language governing permissions and limitations # under the License. """Graph runtime factory.""" -import numpy as np import warnings from tvm._ffi.base import string_types from tvm._ffi.registry import get_global_func -from tvm._ffi.runtime_ctypes import TVMContext -from tvm.contrib.graph_runtime import get_device_ctx -from .packed_func import _set_class_module -from tvm.rpc import base as rpc_base from .module import Module from . import ndarray @@ -89,7 +84,8 @@ def __del__(self): def __iter__(self): warnings.warn( - "legacy graph runtime behaviour of producing json / lib / params will be removed in the next release ", + "legacy graph runtime behaviour of producing json / lib / params will be " + "removed in the next release ", DeprecationWarning, 2) self.graph_json = self.module["get_json"]() self.lib = self.module["get_lib"]() @@ -105,4 +101,4 @@ def __next__(self): objs = [self.graph_json, self.lib, self.params] obj = objs[self.iter_cnt] self.iter_cnt += 1 - return obj \ No newline at end of file + return obj diff --git a/src/runtime/graph/graph_runtime_factory.cc b/src/runtime/graph/graph_runtime_factory.cc index c00c7c0e5901..16dc472b7bb8 100644 --- a/src/runtime/graph/graph_runtime_factory.cc +++ b/src/runtime/graph/graph_runtime_factory.cc @@ -31,8 +31,6 @@ #include #include -#include "./graph_runtime.h" - namespace tvm { namespace runtime { @@ -46,32 +44,12 @@ void GraphRuntimeFactory::Init(const std::string& graph_json, PackedFunc GraphRuntimeFactory::GetFunction( const std::string& name, const tvm::runtime::ObjectPtr& sptr_to_self) { - if (name == "runtime_create") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - std::vector contexts; - TVMContext ctx; - // arg is: module, ctxs - CHECK_EQ((args.size() - 1) % 2, 0); - for (int i = 1; i < args.num_args; i += 2) { - int dev_type = args[i]; - ctx.device_type = static_cast(dev_type); - ctx.device_id = args[i + 1]; - contexts.push_back(ctx); - } - *rv = this->RuntimeCreate(args[0], contexts); - }); - } else if (name == "select_module") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.size(), 1); - *rv = this->SelectModule(args[0]); - }); - } else if (name == "get_json") { + if (name == "get_json") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetJson(); }); } else if (name == "get_lib") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetLib(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLib(); }); } else if (name == "get_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Map ret; @@ -80,14 +58,6 @@ PackedFunc GraphRuntimeFactory::GetFunction( } *rv = ret; }); - } else if (name == "diable_package_params") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->package_params_ = false; }); - } else if (name == "get_module_name") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetModuleName(); - }); } else if (name == module_name_) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { auto module = this->SelectModule(module_name_); @@ -108,29 +78,34 @@ PackedFunc GraphRuntimeFactory::GetFunction( } *rv = this->DebugRuntimeCreate(module, contexts); }); + } else if (name == "remove_params") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + auto exec = make_object(); + exec->Init(this->GetJson(), {}, this->GetModuleName()); + exec->Import(this->GetLib()); + *rv = Module(exec); + }); } else { - return PackedFunc(); + return PackedFunc(); } } void GraphRuntimeFactory::SaveToBinary(dmlc::Stream* stream) { stream->Write(graph_json_); - stream->Write(package_params_); - if (package_params_) { - std::vector names; - std::vector arrays; - for (const auto& v : params_) { - names.emplace_back(v.first); - arrays.emplace_back(const_cast(v.second.operator->())); - } - uint64_t sz = arrays.size(); - CHECK(sz == names.size()); - stream->Write(sz); - stream->Write(names); - for (size_t i = 0; i < sz; ++i) { - tvm::runtime::SaveDLTensor(stream, arrays[i]); - } + std::vector names; + std::vector arrays; + for (const auto& v : params_) { + names.emplace_back(v.first); + arrays.emplace_back(const_cast(v.second.operator->())); + } + uint64_t sz = arrays.size(); + CHECK(sz == names.size()); + stream->Write(sz); + stream->Write(names); + for (size_t i = 0; i < sz; ++i) { + tvm::runtime::SaveDLTensor(stream, arrays[i]); } + stream->Write(module_name_); } Module GraphRuntimeFactory::RuntimeCreate(Module module, const std::vector& ctxs) { @@ -139,14 +114,7 @@ Module GraphRuntimeFactory::RuntimeCreate(Module module, const std::vector(); exec->Init(factory_module->GetJson(), factory_module->GetLib(), ctxs); // set params - auto params = factory_module->GetParams(); - auto keys = factory_module->GetSorterParamKeys(params); - for (const auto& key : keys) { - int in_idx = exec->GetInputIndex(key); - if (in_idx >= 0) { - exec->SetInput(in_idx, const_cast(params[key].operator->())); - } - } + SetParams(exec.get(), factory_module->GetParams()); return Module(exec); } @@ -156,9 +124,9 @@ Module GraphRuntimeFactory::DebugRuntimeCreate(Module module, const std::vector< const PackedFunc* pf = tvm::runtime::Registry::Get("tvm.graph_runtime_debug.create"); CHECK(pf != nullptr) << "Cannot find function tvm.graph_runtime_debug.create in registry. " "Do you enable debug graph runtime build?"; - // Debug runtime will call GetAllContexs, so we unpack the ctxs. + // Debug runtime create packed function will call GetAllContexs, so we unpack the ctxs. std::vector unpacked_ctxs; - for(const auto& ctx: ctxs) { + for (const auto& ctx : ctxs) { unpacked_ctxs.emplace_back(ctx.device_type); unpacked_ctxs.emplace_back(ctx.device_id); } @@ -168,22 +136,14 @@ Module GraphRuntimeFactory::DebugRuntimeCreate(Module module, const std::vector< runtime::TVMArgsSetter setter(values.data(), codes.data()); setter(0, factory_module->GetJson()); setter(1, factory_module->GetLib()); - for (int i = 0; i < unpacked_ctxs.size(); ++i) { + for (size_t i = 0; i < unpacked_ctxs.size(); ++i) { setter(i + 2, unpacked_ctxs[i]); } TVMRetValue rv; pf->CallPacked(TVMArgs(values.data(), codes.data(), args_size), &rv); Module mod = rv.operator Module(); // debug graph runtime is one child class of graph runtime. - GraphRuntime* exec = const_cast(mod.as()); - auto params = factory_module->GetParams(); - auto keys = factory_module->GetSorterParamKeys(params); - for (const auto& key : keys) { - int in_idx = exec->GetInputIndex(key); - if (in_idx >= 0) { - exec->SetInput(in_idx, const_cast(params[key].operator->())); - } - } + SetParams(const_cast(mod.as()), factory_module->GetParams()); return mod; } @@ -198,24 +158,22 @@ Module GraphRuntimeFactory::SelectModule(const std::string& name) { Module GraphRuntimeFactoryModuleLoadBinary(void* strm) { dmlc::Stream* stream = static_cast(strm); std::string graph_json; - bool package_params; std::unordered_map params; + std::string module_name; CHECK(stream->Read(&graph_json)); - CHECK(stream->Read(&package_params)); - if (package_params) { - uint64_t sz; - CHECK(stream->Read(&sz)); - std::vector names; - CHECK(stream->Read(&names)); - CHECK(sz == names.size()); - for (size_t i = 0; i < sz; ++i) { - tvm::runtime::NDArray temp; - temp.Load(stream); - params[names[i]] = temp; - } + uint64_t sz; + CHECK(stream->Read(&sz)); + std::vector names; + CHECK(stream->Read(&names)); + CHECK(sz == names.size()); + for (size_t i = 0; i < sz; ++i) { + tvm::runtime::NDArray temp; + temp.Load(stream); + params[names[i]] = temp; } + CHECK(stream->Read(&module_name)); auto exec = make_object(); - exec->Init(graph_json, params); + exec->Init(graph_json, params, module_name); return Module(exec); } diff --git a/src/runtime/graph/graph_runtime_factory.h b/src/runtime/graph/graph_runtime_factory.h index 474a24078365..42c617ef759b 100644 --- a/src/runtime/graph/graph_runtime_factory.h +++ b/src/runtime/graph/graph_runtime_factory.h @@ -25,6 +25,8 @@ #ifndef TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_FACTORY_H_ #define TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_FACTORY_H_ +#include "./graph_runtime.h" + #include #include #include @@ -34,6 +36,7 @@ #include #include #include +#include #include namespace tvm { @@ -57,7 +60,7 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; /*! * \return The type key of the executor. @@ -100,11 +103,12 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { std::unordered_map GetParams() const { return params_; } /*! - * \brief Get sorted keys of params. - * \param params The graph params value we want to sort. - * \return The sorted keys of params + * \brief Set params. + * \param graph_runtime The graph runtime we want to set the params into. + * \param params The graph params value we want to set. */ - std::vector GetSorterParamKeys(const std::unordered_map& params) const { + void SetParams(GraphRuntime* graph_runtime, + const std::unordered_map& params) const { std::unordered_map value = params; // upload big arrays first to avoid memory issue in rpc mode std::vector keys; @@ -121,12 +125,16 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { std::multiplies()); return lhs_prod > rhs_prod; }); - - return keys; + for (const auto& key : keys) { + int in_idx = graph_runtime->GetInputIndex(key); + if (in_idx >= 0) { + graph_runtime->SetInput(in_idx, const_cast(value[key].operator->())); + } + } } Module GetLib() const { - CHECK_GT(this->imports().size(), 0); + CHECK_EQ(this->imports().size(), 0); return this->imports_[0]; } @@ -139,8 +147,6 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { std::unordered_map params_; /*! \brief module name */ std::string module_name_; - /*! \brief whether to package params */ - bool package_params_ = true; }; } // namespace runtime diff --git a/tests/python/unittest/test_module_runtime_interface.py b/tests/python/unittest/test_module_runtime_interface.py deleted file mode 100644 index 50b4f084f6ca..000000000000 --- a/tests/python/unittest/test_module_runtime_interface.py +++ /dev/null @@ -1,416 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import numpy as np -from tvm import relay -from tvm.relay import testing -import tvm -from tvm.contrib import graph_runtime -from tvm.runtime import graph_runtime_factory -from tvm.contrib.debugger import debug_runtime - -def get_workload(num_layers=18): - mod, params = relay.testing.resnet.get_workload(num_layers=num_layers) - return mod, params - -def verify(data, num_layers=18): - mod, params = get_workload(num_layers) - with relay.build_config(opt_level=3): - graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params) - - ctx = tvm.cpu() - module = graph_runtime.create(graph, lib, ctx) - module.set_input("data", data) - module.set_input(**graph_params) - module.run() - out = module.get_output(0).asnumpy() - - return out - -def test_legacy_compatibility(): - mod, params = get_workload() - with relay.build_config(opt_level=3): - graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") - ctx = tvm.cpu() - module = graph_runtime.create(graph, lib, ctx) - module.set_input("data", data) - module.set_input(**graph_params) - module.run() - out = module.get_output(0).asnumpy() - tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - -def test_cpu(): - mod, params = get_workload() - with relay.build_config(opt_level=3): - complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") - # raw api - ctx = tvm.cpu() - gmod = complied_graph_lib['default'](ctx) - set_input = gmod["set_input"] - run = gmod["run"] - get_output = gmod["get_output"] - set_input("data", tvm.nd.array(data)) - run() - out = get_output(0).asnumpy() - tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - - # graph runtime wrapper - gmod = graph_runtime.GraphModule(complied_graph_lib['default'](ctx)) - gmod.set_input("data", data) - gmod.run() - out = gmod.get_output(0).asnumpy() - - tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - -def test_gpu(): - mod, params = get_workload() - with relay.build_config(opt_level=3): - complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") - ctx = tvm.gpu() - gmod = complied_graph_lib['default'](ctx) - set_input = gmod["set_input"] - run = gmod["run"] - get_output = gmod["get_output"] - set_input("data", tvm.nd.array(data)) - run() - out = get_output(0).asnumpy() - - tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - -def test_cpu_export(format=".so"): - mod, params = get_workload() - with relay.build_config(opt_level=3): - complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) - - from tvm.contrib import util - temp = util.tempdir() - if format == ".so": - file_name = "deploy_lib.so" - else: - assert format == ".tar" - file_name = "deploy_lib.tar" - path_lib = temp.relpath(file_name) - complied_graph_lib.export_library(path_lib) - loaded_lib = tvm.runtime.load_module(path_lib) - ctx = tvm.cpu(0) - gmod = loaded_lib['default'](ctx) - set_input = gmod["set_input"] - run = gmod["run"] - get_output = gmod["get_output"] - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") - set_input("data", tvm.nd.array(data)) - run() - out = get_output(0).asnumpy() - - tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - -def test_gpu_export(format=".so"): - mod, params = get_workload() - with relay.build_config(opt_level=3): - complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) - - from tvm.contrib import util - temp = util.tempdir() - if format == ".so": - file_name = "deploy_lib.so" - else: - assert format == ".tar" - file_name = "deploy_lib.tar" - path_lib = temp.relpath(file_name) - complied_graph_lib.export_library(path_lib) - loaded_lib = tvm.runtime.load_module(path_lib) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") - ctx = tvm.gpu() - gmod = loaded_lib['default'](ctx) - set_input = gmod["set_input"] - run = gmod["run"] - get_output = gmod["get_output"] - set_input("data", tvm.nd.array(data)) - run() - out = get_output(0).asnumpy() - - tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - -def test_previous_cpu_export(format=".so"): - mod, params = get_workload() - with relay.build_config(opt_level=3): - graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params) - - from tvm.contrib import util - temp = util.tempdir() - if format == ".so": - file_name = "deploy_lib.so" - else: - assert format == ".tar" - file_name = "deploy_lib.tar" - path_lib = temp.relpath(file_name) - lib.export_library(path_lib) - with open(temp.relpath("deploy_graph.json"), "w") as fo: - fo.write(graph) - with open(temp.relpath("deploy_param.params"), "wb") as fo: - fo.write(relay.save_param_dict(graph_params)) - loaded_json = open(temp.relpath("deploy_graph.json")).read() - loaded_lib = tvm.runtime.load_module(path_lib) - loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") - ctx = tvm.cpu() - module = graph_runtime.create(loaded_json, loaded_lib, ctx) - module.load_params(loaded_params) - module.set_input("data", data) - module.run() - out = module.get_output(0).asnumpy() - - tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - -def test_previous_gpu_export(format=".so"): - mod, params = get_workload() - with relay.build_config(opt_level=3): - graph, lib, graph_params = relay.build_module.build(mod, "cuda", params=params) - - from tvm.contrib import util - temp = util.tempdir() - if format == ".so": - file_name = "deploy_lib.so" - else: - assert format == ".tar" - file_name = "deploy_lib.tar" - path_lib = temp.relpath(file_name) - lib.export_library(path_lib) - with open(temp.relpath("deploy_graph.json"), "w") as fo: - fo.write(graph) - with open(temp.relpath("deploy_param.params"), "wb") as fo: - fo.write(relay.save_param_dict(graph_params)) - loaded_json = open(temp.relpath("deploy_graph.json")).read() - loaded_lib = tvm.runtime.load_module(path_lib) - loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") - ctx = tvm.gpu() - module = graph_runtime.create(loaded_json, loaded_lib, ctx) - module.load_params(loaded_params) - module.set_input("data", data) - module.run() - out = module.get_output(0).asnumpy() - - tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - -def test_rpc_export(format=".so"): - mod, params = get_workload() - with relay.build_config(opt_level=3): - complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) - - from tvm.contrib import util - temp = util.tempdir() - if format == ".so": - file_name = "deploy_lib.so" - else: - assert format == ".tar" - file_name = "deploy_lib.tar" - path_lib = temp.relpath(file_name) - complied_graph_lib.export_library(path_lib) - - from tvm import rpc - server = rpc.Server("localhost", use_popen=True) - remote = rpc.connect(server.host, server.port) - remote.upload(path_lib) - loaded_lib = remote.load_module(path_lib) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") - ctx = remote.gpu() - gmod = loaded_lib['default'](ctx) - set_input = gmod["set_input"] - run = gmod["run"] - get_output = gmod["get_output"] - set_input("data", tvm.nd.array(data, ctx=ctx)) - run() - out = get_output(0).asnumpy() - tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - - gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx)) - gmod.set_input("data", data) - gmod.run() - out = gmod.get_output(0).asnumpy() - - tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - -def test_previous_rpc_export(format=".so"): - mod, params = get_workload() - with relay.build_config(opt_level=3): - graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params) - - from tvm.contrib import util - temp = util.tempdir() - if format == ".so": - file_name = "deploy_lib.so" - else: - assert format == ".tar" - file_name = "deploy_lib.tar" - path_lib = temp.relpath(file_name) - lib.export_library(path_lib) - with open(temp.relpath("deploy_graph.json"), "w") as fo: - fo.write(graph) - with open(temp.relpath("deploy_param.params"), "wb") as fo: - fo.write(relay.save_param_dict(graph_params)) - - from tvm import rpc - server = rpc.Server("localhost", use_popen=True) - remote = rpc.connect(server.host, server.port) - remote.upload(path_lib) - loaded_json = open(temp.relpath("deploy_graph.json")).read() - loaded_lib = remote.load_module(path_lib) - loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") - ctx = remote.cpu() - module = graph_runtime.create(loaded_json, loaded_lib, ctx) - module.load_params(loaded_params) - module.set_input("data", data) - module.run() - out = module.get_output(0).asnumpy() - - tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - -def test_package_params(format=".so"): - mod, params = get_workload() - with relay.build_config(opt_level=3): - complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) - - from tvm.contrib import util - temp = util.tempdir() - if format == ".so": - file_name = "deploy_lib.so" - else: - assert format == ".tar" - file_name = "deploy_lib.tar" - path_lib = temp.relpath(file_name) - complied_graph_lib.export_library(path_lib, package_params=False) - loaded_lib = tvm.runtime.load_module(path_lib) - ctx = tvm.cpu(0) - gmod = loaded_lib['default'](ctx) - set_input = gmod["set_input"] - run = gmod["run"] - get_output = gmod["get_output"] - load_params = gmod["load_params"] - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") - loaded_params = bytearray(open(temp.relpath("deploy_default.params"), "rb").read()) - set_input("data", tvm.nd.array(data)) - load_params(loaded_params) - run() - out = get_output(0).asnumpy() - - tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - -def test_multi_models_package_params(format=".so"): - resnet18_mod, resnet18_params = get_workload() - resnet50_mod, resnet50_params = get_workload(50) - with relay.build_config(opt_level=3): - complied_graph_lib = relay.build_module.build( - resnet18_mod, "llvm", params=resnet18_params, mod_name='resnet18') - with relay.build_config(opt_level=3): - resnet50_gpu_lib = relay.build_module.build( - resnet50_mod, "cuda", params=resnet50_params, mod_name='resnet50') - complied_graph_lib.import_module(resnet50_gpu_lib) - - from tvm.contrib import util - temp = util.tempdir() - if format == ".so": - file_name = "deploy_lib.so" - else: - assert format == ".tar" - file_name = "deploy_lib.tar" - path_lib = temp.relpath(file_name) - complied_graph_lib.export_library(path_lib, package_params=False) - loaded_lib = tvm.runtime.load_module(path_lib) - - # resnet18 - ctx = tvm.cpu(0) - gmod = loaded_lib['resnet18'](ctx) - set_input = gmod["set_input"] - run = gmod["run"] - get_output = gmod["get_output"] - load_params = gmod["load_params"] - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") - loaded_params = bytearray(open(temp.relpath("deploy_resnet18.params"), "rb").read()) - set_input("data", tvm.nd.array(data)) - load_params(loaded_params) - run() - out = get_output(0).asnumpy() - tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - - # resnet50 - ctx = tvm.gpu() - gmod = loaded_lib['resnet50'](ctx) - set_input = gmod["set_input"] - run = gmod["run"] - get_output = gmod["get_output"] - load_params = gmod["load_params"] - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") - loaded_params = bytearray(open(temp.relpath("deploy_resnet50.params"), "rb").read()) - set_input("data", tvm.nd.array(data)) - load_params(loaded_params) - run() - out = get_output(0).asnumpy() - tvm.testing.assert_allclose(out, verify(data, num_layers=50), atol=1e-5) - - -def test_debug_graph_runtime(): - mod, params = get_workload() - with relay.build_config(opt_level=3): - complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) - data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") - # raw api - ctx = tvm.cpu() - # gmod = complied_graph_lib['debug_create']('default', ctx) - # set_input = gmod["set_input"] - # run = gmod["run"] - # get_output = gmod["get_output"] - # set_input("data", tvm.nd.array(data)) - # run() - # out = get_output(0).asnumpy() - # tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - debug_g_mod = debug_runtime.GraphModuleDebug(complied_graph_lib['debug_create']('default', ctx), [ctx], complied_graph_lib['get_json'](), None) - # debug_g_mod = debug_runtime.create(complied_graph_lib['get_json'](), complied_graph_lib['get_lib'](), ctx) - debug_g_mod.set_input("data", data) - debug_g_mod.run() - out = debug_g_mod.get_output(0).asnumpy() - # # graph runtime wrapper - # tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - #gmod = graph_runtime.GraphModule(complied_graph_lib['default'](ctx)) - # gmod.set_input("data", data) - # gmod.run() - # out = gmod.get_output(0).asnumpy() - # - # tvm.testing.assert_allclose(out, verify(data), atol=1e-5) - -if __name__ == "__main__": - # test_legacy_compatibility() - # test_cpu() - # test_gpu() - # test_cpu_export(".so") - # test_cpu_export(".tar") - # test_gpu_export(".so") - # test_gpu_export(".tar") - # test_rpc_export(".so") - # test_rpc_export(".tar") - test_debug_graph_runtime() - # test_previous_cpu_export(".so") - # test_previous_cpu_export(".tar") - # test_previous_gpu_export(".so") - # test_previous_gpu_export(".tar") - # test_previous_rpc_export(".so") - # test_previous_rpc_export(".tar") - # test_package_params(".so") - # test_multi_models_package_params(".so") \ No newline at end of file diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py new file mode 100644 index 000000000000..aa0e472815ac --- /dev/null +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -0,0 +1,510 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +from tvm import relay +from tvm.relay import testing +import tvm +from tvm.contrib import graph_runtime +from tvm.runtime import graph_runtime_factory +from tvm.contrib.debugger import debug_runtime + +def verify(data): + if not tvm.runtime.enabled("llvm"): + print("Skip because llvm is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params) + + ctx = tvm.cpu() + module = graph_runtime.create(graph, lib, ctx) + module.set_input("data", data) + module.set_input(**graph_params) + module.run() + out = module.get_output(0).asnumpy() + + return out + +def test_legacy_compatibility(): + if not tvm.runtime.enabled("llvm"): + print("Skip because llvm is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = tvm.cpu() + module = graph_runtime.create(graph, lib, ctx) + module.set_input("data", data) + module.set_input(**graph_params) + module.run() + out = module.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + +def test_cpu(): + if not tvm.runtime.enabled("llvm"): + print("Skip because llvm is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + # raw api + ctx = tvm.cpu() + gmod = complied_graph_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", tvm.nd.array(data)) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(complied_graph_lib['default'](ctx)) + gmod.set_input("data", data) + gmod.run() + out = gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + +def test_gpu(): + if not tvm.runtime.enabled("cuda"): + print("Skip because cuda is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = tvm.gpu() + + # raw api + gmod = complied_graph_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", tvm.nd.array(data)) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(complied_graph_lib['default'](ctx)) + gmod.set_input("data", data) + gmod.run() + out = gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + +def test_mod_export(): + def verify_cpu_export(obj_format): + if not tvm.runtime.enabled("llvm"): + print("Skip because llvm is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) + + from tvm.contrib import util + temp = util.tempdir() + if obj_format == ".so": + file_name = "deploy_lib.so" + else: + assert obj_format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib.export_library(path_lib) + loaded_lib = tvm.runtime.load_module(path_lib) + ctx = tvm.cpu(0) + gmod = loaded_lib['default'](ctx) + + # raw api + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + set_input("data", tvm.nd.array(data)) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx)) + gmod.set_input("data", data) + gmod.run() + out = gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + def verify_gpu_export(obj_format): + if not tvm.runtime.enabled("cuda"): + print("Skip because cuda is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) + + from tvm.contrib import util + temp = util.tempdir() + if obj_format == ".so": + file_name = "deploy_lib.so" + else: + assert obj_format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib.export_library(path_lib) + loaded_lib = tvm.runtime.load_module(path_lib) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = tvm.gpu() + + # raw api + gmod = loaded_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", tvm.nd.array(data)) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx)) + gmod.set_input("data", data) + gmod.run() + out = gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + def verify_rpc_cpu_export(obj_format): + if not tvm.runtime.enabled("llvm"): + print("Skip because llvm is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) + + from tvm.contrib import util + temp = util.tempdir() + if obj_format == ".so": + file_name = "deploy_lib.so" + else: + assert obj_format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib.export_library(path_lib) + + from tvm import rpc + server = rpc.Server("localhost", use_popen=True) + remote = rpc.connect(server.host, server.port) + remote.upload(path_lib) + loaded_lib = remote.load_module(path_lib) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = remote.cpu() + + # raw api + gmod = loaded_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", tvm.nd.array(data, ctx=ctx)) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx)) + gmod.set_input("data", data) + gmod.run() + out = gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + def verify_rpc_gpu_export(obj_format): + if not tvm.runtime.enabled("cuda"): + print("Skip because cuda is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) + + from tvm.contrib import util + temp = util.tempdir() + if obj_format == ".so": + file_name = "deploy_lib.so" + else: + assert obj_format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib.export_library(path_lib) + + from tvm import rpc + server = rpc.Server("localhost", use_popen=True) + remote = rpc.connect(server.host, server.port) + remote.upload(path_lib) + loaded_lib = remote.load_module(path_lib) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = remote.gpu() + + # raw api + gmod = loaded_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", tvm.nd.array(data, ctx=ctx)) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx)) + gmod.set_input("data", data) + gmod.run() + out = gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + for obj_format in [".so", ".tar"]: + verify_cpu_export(obj_format) + verify_gpu_export(obj_format) + verify_rpc_cpu_export(obj_format) + verify_rpc_gpu_export(obj_format) + +def test_remove_package_params(): + def verify_cpu_remove_package_params(obj_format): + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) + + from tvm.contrib import util + temp = util.tempdir() + if obj_format == ".so": + file_name = "deploy_lib.so" + else: + assert obj_format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib_no_params = complied_graph_lib["remove_params"]() + complied_graph_lib_no_params.export_library(path_lib) + with open(temp.relpath("deploy_param.params"), "wb") as fo: + fo.write(relay.save_param_dict(complied_graph_lib["get_params"]())) + loaded_lib = tvm.runtime.load_module(path_lib) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = tvm.cpu(0) + + # raw api + gmod = loaded_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + load_params = gmod["load_params"] + loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) + set_input("data", tvm.nd.array(data)) + load_params(loaded_params) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx)) + loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) + gmod.set_input("data", data) + gmod.load_params(loaded_params) + gmod.run() + out = gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + def verify_gpu_remove_package_params(obj_format): + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) + + from tvm.contrib import util + temp = util.tempdir() + if obj_format == ".so": + file_name = "deploy_lib.so" + else: + assert obj_format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib_no_params = complied_graph_lib["remove_params"]() + complied_graph_lib_no_params.export_library(path_lib) + with open(temp.relpath("deploy_param.params"), "wb") as fo: + fo.write(relay.save_param_dict(complied_graph_lib["get_params"]())) + loaded_lib = tvm.runtime.load_module(path_lib) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = tvm.gpu(0) + + # raw api + gmod = loaded_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + load_params = gmod["load_params"] + loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) + set_input("data", tvm.nd.array(data)) + load_params(loaded_params) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx)) + loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) + gmod.set_input("data", data) + gmod.load_params(loaded_params) + gmod.run() + out = gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + def verify_rpc_cpu_remove_package_params(obj_format): + if not tvm.runtime.enabled("llvm"): + print("Skip because llvm is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) + + from tvm.contrib import util + temp = util.tempdir() + if obj_format == ".so": + file_name = "deploy_lib.so" + else: + assert obj_format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib_no_params = complied_graph_lib["remove_params"]() + complied_graph_lib_no_params.export_library(path_lib) + path_params = temp.relpath("deploy_param.params") + with open(path_params, "wb") as fo: + fo.write(relay.save_param_dict(complied_graph_lib["get_params"]())) + + from tvm import rpc + server = rpc.Server("localhost", use_popen=True) + remote = rpc.connect(server.host, server.port) + remote.upload(path_lib) + loaded_lib = remote.load_module(path_lib) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = remote.cpu() + + # raw api + gmod = loaded_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + load_params = gmod["load_params"] + loaded_params = bytearray(open(path_params, "rb").read()) + set_input("data", tvm.nd.array(data, ctx=ctx)) + load_params(loaded_params) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx)) + loaded_params = bytearray(open(path_params, "rb").read()) + gmod.set_input("data", data) + gmod.load_params(loaded_params) + gmod.run() + out = gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + def verify_rpc_gpu_remove_package_params(obj_format): + if not tvm.runtime.enabled("cuda"): + print("Skip because cuda is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) + + from tvm.contrib import util + temp = util.tempdir() + if obj_format == ".so": + file_name = "deploy_lib.so" + else: + assert obj_format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + complied_graph_lib_no_params = complied_graph_lib["remove_params"]() + complied_graph_lib_no_params.export_library(path_lib) + path_params = temp.relpath("deploy_param.params") + with open(path_params, "wb") as fo: + fo.write(relay.save_param_dict(complied_graph_lib["get_params"]())) + + from tvm import rpc + server = rpc.Server("localhost", use_popen=True) + remote = rpc.connect(server.host, server.port) + remote.upload(path_lib) + loaded_lib = remote.load_module(path_lib) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = remote.gpu() + + # raw api + gmod = loaded_lib['default'](ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + load_params = gmod["load_params"] + loaded_params = bytearray(open(path_params, "rb").read()) + set_input("data", tvm.nd.array(data, ctx=ctx)) + load_params(loaded_params) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # graph runtime wrapper + gmod = graph_runtime.GraphModule(loaded_lib['default'](ctx)) + loaded_params = bytearray(open(path_params, "rb").read()) + gmod.set_input("data", data) + gmod.load_params(loaded_params) + gmod.run() + out = gmod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + for obj_format in [".so", ".tar"]: + verify_cpu_remove_package_params(obj_format) + verify_gpu_remove_package_params(obj_format) + verify_rpc_cpu_remove_package_params(obj_format) + verify_rpc_gpu_remove_package_params(obj_format) + +def test_debug_graph_runtime(): + if not tvm.runtime.enabled("llvm"): + print("Skip because llvm is not enabled") + return + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + + # raw api + ctx = tvm.cpu() + gmod = complied_graph_lib['debug_create']('default', ctx) + set_input = gmod["set_input"] + run = gmod["run"] + get_output = gmod["get_output"] + set_input("data", tvm.nd.array(data)) + run() + out = get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + # debug graph runtime wrapper + debug_g_mod = debug_runtime.GraphModuleDebug(complied_graph_lib['debug_create']('default', ctx), [ctx], complied_graph_lib['get_json'](), None) + debug_g_mod.set_input("data", data) + debug_g_mod.run() + out = debug_g_mod.get_output(0).asnumpy() + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + +if __name__ == "__main__": + test_legacy_compatibility() + test_cpu() + test_gpu() + test_mod_export() + test_remove_package_params() + test_debug_graph_runtime() \ No newline at end of file From 6f7ceee8b506a4d1047f81e7641e0155065aa48d Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Thu, 9 Jul 2020 14:00:15 +0800 Subject: [PATCH 16/29] clang-format --- src/runtime/graph/graph_runtime_factory.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/runtime/graph/graph_runtime_factory.h b/src/runtime/graph/graph_runtime_factory.h index 42c617ef759b..1074cc41bd03 100644 --- a/src/runtime/graph/graph_runtime_factory.h +++ b/src/runtime/graph/graph_runtime_factory.h @@ -25,20 +25,20 @@ #ifndef TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_FACTORY_H_ #define TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_FACTORY_H_ -#include "./graph_runtime.h" - #include #include #include #include #include +#include #include #include #include -#include #include +#include "./graph_runtime.h" + namespace tvm { namespace runtime { From 893c5311f8ec1b623788aa052b6ba445625cdca4 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Thu, 9 Jul 2020 14:09:25 +0800 Subject: [PATCH 17/29] comment --- python/tvm/runtime/graph_runtime_factory.py | 10 ++-------- .../unittest/test_runtime_module_based_interface.py | 2 +- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/python/tvm/runtime/graph_runtime_factory.py b/python/tvm/runtime/graph_runtime_factory.py index 4aa3c7187eb1..a03014a164db 100644 --- a/python/tvm/runtime/graph_runtime_factory.py +++ b/python/tvm/runtime/graph_runtime_factory.py @@ -39,8 +39,8 @@ def create(graph_json_str, libmod, libmod_name, params): Returns ------- - graph_module : GraphModule - Runtime graph module that can be used to execute the graph. + graph_module : GraphRuntimeFactoryModule + Runtime graph runtime factory module. """ if not isinstance(graph_json_str, string_types): try: @@ -57,16 +57,10 @@ def create(graph_json_str, libmod, libmod_name, params): class GraphRuntimeFactoryModule(Module): """Graph runtime factory module. - This is a module of graph runtime factory Parameters ---------- - module : Module - The interal tvm module that holds the actual graph functions. - - Attributes - ---------- module : Module The interal tvm module that holds the actual graph functions. """ diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index aa0e472815ac..1924d116fe0a 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -507,4 +507,4 @@ def test_debug_graph_runtime(): test_gpu() test_mod_export() test_remove_package_params() - test_debug_graph_runtime() \ No newline at end of file + test_debug_graph_runtime() From 6c515ef00fbb409086fcc8124fa06eea09f24164 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Thu, 9 Jul 2020 14:23:30 +0800 Subject: [PATCH 18/29] fix GetLib() CHECK issue --- src/runtime/graph/graph_runtime_factory.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/runtime/graph/graph_runtime_factory.h b/src/runtime/graph/graph_runtime_factory.h index 1074cc41bd03..2a57f7a03c19 100644 --- a/src/runtime/graph/graph_runtime_factory.h +++ b/src/runtime/graph/graph_runtime_factory.h @@ -134,7 +134,7 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { } Module GetLib() const { - CHECK_EQ(this->imports().size(), 0); + CHECK_EQ(this->imports().size(), 1); return this->imports_[0]; } From 926acec5824862114e83431bd387825dbd545da0 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Thu, 9 Jul 2020 15:05:12 +0800 Subject: [PATCH 19/29] refactor --- .../tvm/{runtime => relay/backend}/graph_runtime_factory.py | 4 ++-- python/tvm/relay/build_module.py | 2 +- tests/python/unittest/test_runtime_module_based_interface.py | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) rename python/tvm/{runtime => relay/backend}/graph_runtime_factory.py (97%) diff --git a/python/tvm/runtime/graph_runtime_factory.py b/python/tvm/relay/backend/graph_runtime_factory.py similarity index 97% rename from python/tvm/runtime/graph_runtime_factory.py rename to python/tvm/relay/backend/graph_runtime_factory.py index a03014a164db..e87304fa8e3f 100644 --- a/python/tvm/runtime/graph_runtime_factory.py +++ b/python/tvm/relay/backend/graph_runtime_factory.py @@ -18,8 +18,8 @@ import warnings from tvm._ffi.base import string_types from tvm._ffi.registry import get_global_func -from .module import Module -from . import ndarray +from tvm.runtime.module import Module +from tvm.runtime import ndarray def create(graph_json_str, libmod, libmod_name, params): diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index f0c1a2cbae2b..d696f452a43b 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -26,11 +26,11 @@ from tvm.tir import expr as tvm_expr from .. import nd as _nd, target as _target, autotvm from ..contrib import graph_runtime as _graph_rt -from ..runtime import graph_runtime_factory as _graph_runtime_factory from . import _build_module from . import ty as _ty from . import expr as _expr from . import function as _function +from .backend import graph_runtime_factory as _graph_runtime_factory from .backend import interpreter as _interpreter from .backend.vm import VMExecutor diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index 1924d116fe0a..e55ca3749c9a 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -19,7 +19,6 @@ from tvm.relay import testing import tvm from tvm.contrib import graph_runtime -from tvm.runtime import graph_runtime_factory from tvm.contrib.debugger import debug_runtime def verify(data): From edb60a3b4ef05f00f7886ed1aa3c44f72daae270 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Thu, 9 Jul 2020 15:17:44 +0800 Subject: [PATCH 20/29] add missing device api check --- .../python/unittest/test_runtime_module_based_interface.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index e55ca3749c9a..2563d203dadc 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -278,6 +278,9 @@ def verify_rpc_gpu_export(obj_format): def test_remove_package_params(): def verify_cpu_remove_package_params(obj_format): + if not tvm.runtime.enabled("llvm"): + print("Skip because llvm is not enabled") + return mod, params = relay.testing.resnet.get_workload(num_layers=18) with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build(mod, "llvm", params=params) @@ -321,6 +324,9 @@ def verify_cpu_remove_package_params(obj_format): tvm.testing.assert_allclose(out, verify(data), atol=1e-5) def verify_gpu_remove_package_params(obj_format): + if not tvm.runtime.enabled("cuda"): + print("Skip because cuda is not enabled") + return mod, params = relay.testing.resnet.get_workload(num_layers=18) with relay.build_config(opt_level=3): complied_graph_lib = relay.build_module.build(mod, "cuda", params=params) From 5e08ea9e03e98678fd7908ad0e1ee9afde8994c9 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Thu, 9 Jul 2020 18:30:03 +0800 Subject: [PATCH 21/29] Solve tvm::Map odr --- src/node/container.cc | 3 --- src/runtime/container.cc | 6 ++++++ src/runtime/graph/graph_runtime_factory.cc | 20 +++++++------------- src/runtime/graph/graph_runtime_factory.h | 18 +++++++++--------- 4 files changed, 22 insertions(+), 25 deletions(-) diff --git a/src/node/container.cc b/src/node/container.cc index 60b5f40b98f1..eeca51574223 100644 --- a/src/node/container.cc +++ b/src/node/container.cc @@ -357,7 +357,4 @@ TVM_REGISTER_GLOBAL("node.MapItems").set_body([](TVMArgs args, TVMRetValue* ret) *ret = std::move(rkvs); }); -#if (USE_FALLBACK_STL_MAP == 0) -TVM_DLL constexpr uint64_t DenseMapNode::kNextProbeLocation[]; -#endif } // namespace tvm diff --git a/src/runtime/container.cc b/src/runtime/container.cc index 62220a885208..bcdcb56b8af9 100644 --- a/src/runtime/container.cc +++ b/src/runtime/container.cc @@ -21,6 +21,7 @@ * \file src/runtime/container.cc * \brief Implementations of common containers. */ +#include #include #include #include @@ -28,6 +29,11 @@ #include namespace tvm { + +#if (USE_FALLBACK_STL_MAP == 0) +TVM_DLL constexpr uint64_t tvm::DenseMapNode::kNextProbeLocation[]; +#endif + namespace runtime { using namespace vm; diff --git a/src/runtime/graph/graph_runtime_factory.cc b/src/runtime/graph/graph_runtime_factory.cc index 16dc472b7bb8..2507b1fd3694 100644 --- a/src/runtime/graph/graph_runtime_factory.cc +++ b/src/runtime/graph/graph_runtime_factory.cc @@ -35,7 +35,7 @@ namespace tvm { namespace runtime { void GraphRuntimeFactory::Init(const std::string& graph_json, - const std::unordered_map& params, + const tvm::Map& params, const std::string& module_name) { graph_json_ = graph_json; params_ = params; @@ -51,13 +51,8 @@ PackedFunc GraphRuntimeFactory::GetFunction( return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLib(); }); } else if (name == "get_params") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - Map ret; - for (const auto& kv : this->GetParams()) { - ret.Set(kv.first, kv.second); - } - *rv = ret; - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetParams(); }); } else if (name == module_name_) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { auto module = this->SelectModule(module_name_); @@ -158,7 +153,7 @@ Module GraphRuntimeFactory::SelectModule(const std::string& name) { Module GraphRuntimeFactoryModuleLoadBinary(void* strm) { dmlc::Stream* stream = static_cast(strm); std::string graph_json; - std::unordered_map params; + tvm::Map params; std::string module_name; CHECK(stream->Read(&graph_json)); uint64_t sz; @@ -169,7 +164,7 @@ Module GraphRuntimeFactoryModuleLoadBinary(void* strm) { for (size_t i = 0; i < sz; ++i) { tvm::runtime::NDArray temp; temp.Load(stream); - params[names[i]] = temp; + params.Set(names[i], temp); } CHECK(stream->Read(&module_name)); auto exec = make_object(); @@ -185,10 +180,9 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.create").set_body([](TVMArgs args auto exec = make_object(); // The argument order is graph_json, module, module_name, params. CHECK_EQ((args.size() - 3) % 2, 0); - std::unordered_map params; + tvm::Map params; for (size_t i = 3; i < static_cast(args.size()); i += 2) { - std::string name = args[i].operator String(); - params[name] = args[i + 1].operator tvm::runtime::NDArray(); + params.Set(args[i].operator String(), args[i + 1].operator tvm::runtime::NDArray()); } exec->Init(args[0], params, args[2]); exec->Import(args[1]); diff --git a/src/runtime/graph/graph_runtime_factory.h b/src/runtime/graph/graph_runtime_factory.h index 2a57f7a03c19..da64121e8b9b 100644 --- a/src/runtime/graph/graph_runtime_factory.h +++ b/src/runtime/graph/graph_runtime_factory.h @@ -25,6 +25,7 @@ #ifndef TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_FACTORY_H_ #define TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_FACTORY_H_ +#include #include #include #include @@ -50,8 +51,7 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { * \param params The params of graph. * \param module_name The module name of graph. */ - void Init(const std::string& graph_json, - const std::unordered_map& params, + void Init(const std::string& graph_json, const tvm::Map& params, const std::string& module_name = "default"); /*! @@ -98,18 +98,14 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { */ Module SelectModule(const std::string& name); - const std::string& GetJson() const { return graph_json_; } - - std::unordered_map GetParams() const { return params_; } - /*! * \brief Set params. * \param graph_runtime The graph runtime we want to set the params into. * \param params The graph params value we want to set. */ void SetParams(GraphRuntime* graph_runtime, - const std::unordered_map& params) const { - std::unordered_map value = params; + const tvm::Map& params) const { + tvm::Map value = params; // upload big arrays first to avoid memory issue in rpc mode std::vector keys; for (const auto& p : value) { @@ -133,6 +129,10 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { } } + const std::string& GetJson() const { return graph_json_; } + + tvm::Map GetParams() const { return params_; } + Module GetLib() const { CHECK_EQ(this->imports().size(), 1); return this->imports_[0]; @@ -144,7 +144,7 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { /*! \brief The execution graph. */ std::string graph_json_; /*! \brief The params. */ - std::unordered_map params_; + tvm::Map params_; /*! \brief module name */ std::string module_name_; }; From c8f505c922a10d3e71881966bc95f99de07c8df1 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Thu, 9 Jul 2020 18:41:52 +0800 Subject: [PATCH 22/29] skip debug graph runtime test if not enable --- .../python/unittest/test_runtime_module_based_interface.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index 2563d203dadc..7602073f5aee 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -490,7 +490,11 @@ def test_debug_graph_runtime(): # raw api ctx = tvm.cpu() - gmod = complied_graph_lib['debug_create']('default', ctx) + try: + gmod = complied_graph_lib['debug_create']('default', ctx) + except: + print("Skip because debug graph_runtime not enabled") + return set_input = gmod["set_input"] run = gmod["run"] get_output = gmod["get_output"] From da6b1d9516d99d76571745ccaa6fcf0c8e49c9f5 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Thu, 9 Jul 2020 22:36:56 +0800 Subject: [PATCH 23/29] Trigger notification From e796d4bcfeaf191ec71110438acd241a9dc90c27 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Mon, 13 Jul 2020 14:52:32 +0800 Subject: [PATCH 24/29] address comments --- .../relay/backend/graph_runtime_factory.py | 50 ++++++++----- python/tvm/relay/build_module.py | 2 +- src/node/container.cc | 3 + src/runtime/container.cc | 6 -- src/runtime/graph/graph_runtime_factory.cc | 73 +++++++------------ src/runtime/graph/graph_runtime_factory.h | 48 +++--------- .../test_runtime_module_based_interface.py | 8 +- 7 files changed, 79 insertions(+), 111 deletions(-) diff --git a/python/tvm/relay/backend/graph_runtime_factory.py b/python/tvm/relay/backend/graph_runtime_factory.py index e87304fa8e3f..129f7a14e688 100644 --- a/python/tvm/relay/backend/graph_runtime_factory.py +++ b/python/tvm/relay/backend/graph_runtime_factory.py @@ -18,7 +18,6 @@ import warnings from tvm._ffi.base import string_types from tvm._ffi.registry import get_global_func -from tvm.runtime.module import Module from tvm.runtime import ndarray @@ -52,39 +51,56 @@ def create(graph_json_str, libmod, libmod_name, params): for k, v in params.items(): args.append(k) args.append(ndarray.array(v)) - return GraphRuntimeFactoryModule(fcreate(graph_json_str, libmod, libmod_name, *args)) + return fcreate(graph_json_str, libmod, libmod_name, *args) -class GraphRuntimeFactoryModule(Module): +class GraphRuntimeFactoryModule(object): """Graph runtime factory module. This is a module of graph runtime factory Parameters ---------- - module : Module - The interal tvm module that holds the actual graph functions. + graph_json_str : str or graph class + The graph to be deployed in json format output by nnvm graph. + The graph can only contain one operator(tvm_op) that + points to the name of PackedFunc in the libmod. + libmod : tvm.Module + The module of the corresponding function + libmod_name: str + The name of module + params : dict of str to NDArray + The parameters of module """ - def __init__(self, module): - self.module = module - self.graph_json = None - self.lib = None - self.params = {} + def __init__(self, graph_json_str, libmod, libmod_name, params): + self.graph_json = graph_json_str + self.lib = libmod + self.libmod_name = libmod_name + self.params = params self.iter_cnt = 0 - super(GraphRuntimeFactoryModule, self).__init__(self.module.handle) + self.module = create(graph_json_str, libmod, libmod_name, params) + + def export_library(self, + file_name, + fcompile=None, + addons=None, + **kwargs): + return self.module.export_library(file_name, fcompile, addons, **kwargs) + + # Sometimes we want to get params explicitly. + # For example, we want to save its params value to + # an independentfile. + def get_params(self): + return self.params - def __del__(self): - pass + def __getitem__(self, item): + return self.module.__getitem__(item) def __iter__(self): warnings.warn( "legacy graph runtime behaviour of producing json / lib / params will be " "removed in the next release ", DeprecationWarning, 2) - self.graph_json = self.module["get_json"]() - self.lib = self.module["get_lib"]() - for k, v in self.module["get_params"]().items(): - self.params[k] = v return self diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index d696f452a43b..896f33403491 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -253,7 +253,7 @@ def build(mod, target=None, target_host=None, params=None, mod_name='default'): with tophub_context: bld_mod = BuildModule() graph_json, mod, params = bld_mod.build(mod, target, target_host, params) - mod = _graph_runtime_factory.create(graph_json, mod, mod_name, params) + mod = _graph_runtime_factory.GraphRuntimeFactoryModule(graph_json, mod, mod_name, params) return mod diff --git a/src/node/container.cc b/src/node/container.cc index eeca51574223..60b5f40b98f1 100644 --- a/src/node/container.cc +++ b/src/node/container.cc @@ -357,4 +357,7 @@ TVM_REGISTER_GLOBAL("node.MapItems").set_body([](TVMArgs args, TVMRetValue* ret) *ret = std::move(rkvs); }); +#if (USE_FALLBACK_STL_MAP == 0) +TVM_DLL constexpr uint64_t DenseMapNode::kNextProbeLocation[]; +#endif } // namespace tvm diff --git a/src/runtime/container.cc b/src/runtime/container.cc index bcdcb56b8af9..62220a885208 100644 --- a/src/runtime/container.cc +++ b/src/runtime/container.cc @@ -21,7 +21,6 @@ * \file src/runtime/container.cc * \brief Implementations of common containers. */ -#include #include #include #include @@ -29,11 +28,6 @@ #include namespace tvm { - -#if (USE_FALLBACK_STL_MAP == 0) -TVM_DLL constexpr uint64_t tvm::DenseMapNode::kNextProbeLocation[]; -#endif - namespace runtime { using namespace vm; diff --git a/src/runtime/graph/graph_runtime_factory.cc b/src/runtime/graph/graph_runtime_factory.cc index 2507b1fd3694..aa35afaf70f8 100644 --- a/src/runtime/graph/graph_runtime_factory.cc +++ b/src/runtime/graph/graph_runtime_factory.cc @@ -34,9 +34,10 @@ namespace tvm { namespace runtime { -void GraphRuntimeFactory::Init(const std::string& graph_json, - const tvm::Map& params, - const std::string& module_name) { +GraphRuntimeFactory::GraphRuntimeFactory( + const std::string& graph_json, + const std::unordered_map& params, + const std::string& module_name) { graph_json_ = graph_json; params_ = params; module_name_ = module_name; @@ -44,40 +45,31 @@ void GraphRuntimeFactory::Init(const std::string& graph_json, PackedFunc GraphRuntimeFactory::GetFunction( const std::string& name, const tvm::runtime::ObjectPtr& sptr_to_self) { - if (name == "get_json") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetJson(); }); - } else if (name == "get_lib") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLib(); }); - } else if (name == "get_params") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetParams(); }); - } else if (name == module_name_) { + if (name == module_name_) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - auto module = this->SelectModule(module_name_); std::vector contexts; for (int i = 0; i < args.num_args; ++i) { contexts.emplace_back(args[i].operator TVMContext()); } - *rv = this->RuntimeCreate(module, contexts); + *rv = this->RuntimeCreate(contexts); }); } else if (name == "debug_create") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_GE(args.size(), 2); std::string module_name = args[0].operator String(); - auto module = this->SelectModule(module_name); + CHECK(module_name == module_name_) << "Currently we only support single model for now."; std::vector contexts; for (int i = 1; i < args.num_args; ++i) { contexts.emplace_back(args[i].operator TVMContext()); } - *rv = this->DebugRuntimeCreate(module, contexts); + *rv = this->DebugRuntimeCreate(contexts); }); } else if (name == "remove_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - auto exec = make_object(); - exec->Init(this->GetJson(), {}, this->GetModuleName()); - exec->Import(this->GetLib()); + std::unordered_map empty_params{}; + auto exec = + make_object(this->graph_json_, empty_params, this->module_name_); + exec->Import(this->imports_[0]); *rv = Module(exec); }); } else { @@ -103,19 +95,15 @@ void GraphRuntimeFactory::SaveToBinary(dmlc::Stream* stream) { stream->Write(module_name_); } -Module GraphRuntimeFactory::RuntimeCreate(Module module, const std::vector& ctxs) { - auto factory_module = module.as(); - CHECK(factory_module != nullptr); +Module GraphRuntimeFactory::RuntimeCreate(const std::vector& ctxs) { auto exec = make_object(); - exec->Init(factory_module->GetJson(), factory_module->GetLib(), ctxs); + exec->Init(this->graph_json_, this->imports_[0], ctxs); // set params - SetParams(exec.get(), factory_module->GetParams()); + SetParams(exec.get(), this->params_); return Module(exec); } -Module GraphRuntimeFactory::DebugRuntimeCreate(Module module, const std::vector& ctxs) { - auto factory_module = module.as(); - CHECK(factory_module != nullptr); +Module GraphRuntimeFactory::DebugRuntimeCreate(const std::vector& ctxs) { const PackedFunc* pf = tvm::runtime::Registry::Get("tvm.graph_runtime_debug.create"); CHECK(pf != nullptr) << "Cannot find function tvm.graph_runtime_debug.create in registry. " "Do you enable debug graph runtime build?"; @@ -129,8 +117,8 @@ Module GraphRuntimeFactory::DebugRuntimeCreate(Module module, const std::vector< std::vector values(args_size); std::vector codes(args_size); runtime::TVMArgsSetter setter(values.data(), codes.data()); - setter(0, factory_module->GetJson()); - setter(1, factory_module->GetLib()); + setter(0, this->graph_json_); + setter(1, this->imports_[0]); for (size_t i = 0; i < unpacked_ctxs.size(); ++i) { setter(i + 2, unpacked_ctxs[i]); } @@ -138,22 +126,14 @@ Module GraphRuntimeFactory::DebugRuntimeCreate(Module module, const std::vector< pf->CallPacked(TVMArgs(values.data(), codes.data(), args_size), &rv); Module mod = rv.operator Module(); // debug graph runtime is one child class of graph runtime. - SetParams(const_cast(mod.as()), factory_module->GetParams()); + SetParams(const_cast(mod.as()), this->params_); return mod; } -Module GraphRuntimeFactory::SelectModule(const std::string& name) { - CHECK(name == module_name_) << "Currently we only support single model for now."; - auto exec = make_object(); - exec->Init(this->GetJson(), this->GetParams()); - exec->Import(this->GetLib()); - return Module(exec); -} - Module GraphRuntimeFactoryModuleLoadBinary(void* strm) { dmlc::Stream* stream = static_cast(strm); std::string graph_json; - tvm::Map params; + std::unordered_map params; std::string module_name; CHECK(stream->Read(&graph_json)); uint64_t sz; @@ -164,11 +144,10 @@ Module GraphRuntimeFactoryModuleLoadBinary(void* strm) { for (size_t i = 0; i < sz; ++i) { tvm::runtime::NDArray temp; temp.Load(stream); - params.Set(names[i], temp); + params[names[i]] = temp; } CHECK(stream->Read(&module_name)); - auto exec = make_object(); - exec->Init(graph_json, params, module_name); + auto exec = make_object(graph_json, params, module_name); return Module(exec); } @@ -177,14 +156,14 @@ TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.create").set_body([](TVMArgs args "graph_runtime_factory.create needs at least 3, " "but it has " << args.num_args; - auto exec = make_object(); // The argument order is graph_json, module, module_name, params. CHECK_EQ((args.size() - 3) % 2, 0); - tvm::Map params; + std::unordered_map params; for (size_t i = 3; i < static_cast(args.size()); i += 2) { - params.Set(args[i].operator String(), args[i + 1].operator tvm::runtime::NDArray()); + std::string name = args[i].operator String(); + params[name] = args[i + 1].operator tvm::runtime::NDArray(); } - exec->Init(args[0], params, args[2]); + auto exec = make_object(args[0], params, args[2]); exec->Import(args[1]); *rv = Module(exec); }); diff --git a/src/runtime/graph/graph_runtime_factory.h b/src/runtime/graph/graph_runtime_factory.h index da64121e8b9b..98fb27c43ea2 100644 --- a/src/runtime/graph/graph_runtime_factory.h +++ b/src/runtime/graph/graph_runtime_factory.h @@ -25,7 +25,6 @@ #ifndef TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_FACTORY_H_ #define TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_FACTORY_H_ -#include #include #include #include @@ -46,13 +45,14 @@ namespace runtime { class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { public: /*! - * \brief Initialize the GraphRuntimeFactory with graph and context. + * \brief Construct the GraphRuntimeFactory. * \param graph_json The execution graph. * \param params The params of graph. * \param module_name The module name of graph. */ - void Init(const std::string& graph_json, const tvm::Map& params, - const std::string& module_name = "default"); + GraphRuntimeFactory(const std::string& graph_json, + const std::unordered_map& params, + const std::string& module_name = "default"); /*! * \brief Get member function to front-end @@ -75,28 +75,19 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { /*! * \brief Create a specific runtime module - * \param module The module we will be used for creating runtime * \param ctxs The context of the host and devices where graph nodes will be * executed on. * \return created runtime module */ - Module RuntimeCreate(Module module, const std::vector& ctxs); + Module RuntimeCreate(const std::vector& ctxs); /*! * \brief Create a specific debug runtime module - * \param module The module we will be used for creating runtime * \param ctxs The context of the host and devices where graph nodes will be * executed on. * \return created debug runtime module */ - Module DebugRuntimeCreate(Module module, const std::vector& ctxs); - - /*! - * \brief Select the specific module - * \param name The name of the module - * \return selected module - */ - Module SelectModule(const std::string& name); + Module DebugRuntimeCreate(const std::vector& ctxs); /*! * \brief Set params. @@ -104,8 +95,8 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { * \param params The graph params value we want to set. */ void SetParams(GraphRuntime* graph_runtime, - const tvm::Map& params) const { - tvm::Map value = params; + const std::unordered_map& params) const { + std::unordered_map value = params; // upload big arrays first to avoid memory issue in rpc mode std::vector keys; for (const auto& p : value) { @@ -113,13 +104,9 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { } std::sort(std::begin(keys), std::end(keys), [&](const std::string& lhs, const std::string& rhs) -> bool { - auto lhs_shape = value[lhs].Shape(); - auto rhs_shape = value[rhs].Shape(); - auto lhs_prod = std::accumulate(std::begin(lhs_shape), std::end(lhs_shape), 1, - std::multiplies()); - auto rhs_prod = std::accumulate(std::begin(rhs_shape), std::end(rhs_shape), 1, - std::multiplies()); - return lhs_prod > rhs_prod; + auto lhs_size = GetDataSize(value[lhs].ToDLPack()->dl_tensor); + auto rhs_size = GetDataSize(value[rhs].ToDLPack()->dl_tensor); + return lhs_size > rhs_size; }); for (const auto& key : keys) { int in_idx = graph_runtime->GetInputIndex(key); @@ -129,22 +116,11 @@ class TVM_DLL GraphRuntimeFactory : public runtime::ModuleNode { } } - const std::string& GetJson() const { return graph_json_; } - - tvm::Map GetParams() const { return params_; } - - Module GetLib() const { - CHECK_EQ(this->imports().size(), 1); - return this->imports_[0]; - } - - const std::string& GetModuleName() const { return module_name_; } - protected: /*! \brief The execution graph. */ std::string graph_json_; /*! \brief The params. */ - tvm::Map params_; + std::unordered_map params_; /*! \brief module name */ std::string module_name_; }; diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index 7602073f5aee..7f38b0a32741 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -296,7 +296,7 @@ def verify_cpu_remove_package_params(obj_format): complied_graph_lib_no_params = complied_graph_lib["remove_params"]() complied_graph_lib_no_params.export_library(path_lib) with open(temp.relpath("deploy_param.params"), "wb") as fo: - fo.write(relay.save_param_dict(complied_graph_lib["get_params"]())) + fo.write(relay.save_param_dict(complied_graph_lib.get_params())) loaded_lib = tvm.runtime.load_module(path_lib) data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") ctx = tvm.cpu(0) @@ -342,7 +342,7 @@ def verify_gpu_remove_package_params(obj_format): complied_graph_lib_no_params = complied_graph_lib["remove_params"]() complied_graph_lib_no_params.export_library(path_lib) with open(temp.relpath("deploy_param.params"), "wb") as fo: - fo.write(relay.save_param_dict(complied_graph_lib["get_params"]())) + fo.write(relay.save_param_dict(complied_graph_lib.get_params())) loaded_lib = tvm.runtime.load_module(path_lib) data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") ctx = tvm.gpu(0) @@ -389,7 +389,7 @@ def verify_rpc_cpu_remove_package_params(obj_format): complied_graph_lib_no_params.export_library(path_lib) path_params = temp.relpath("deploy_param.params") with open(path_params, "wb") as fo: - fo.write(relay.save_param_dict(complied_graph_lib["get_params"]())) + fo.write(relay.save_param_dict(complied_graph_lib.get_params())) from tvm import rpc server = rpc.Server("localhost", use_popen=True) @@ -441,7 +441,7 @@ def verify_rpc_gpu_remove_package_params(obj_format): complied_graph_lib_no_params.export_library(path_lib) path_params = temp.relpath("deploy_param.params") with open(path_params, "wb") as fo: - fo.write(relay.save_param_dict(complied_graph_lib["get_params"]())) + fo.write(relay.save_param_dict(complied_graph_lib.get_params())) from tvm import rpc server = rpc.Server("localhost", use_popen=True) From 03793b3bbce57cc54564ff175bc79caab6d69de6 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Mon, 13 Jul 2020 15:03:55 +0800 Subject: [PATCH 25/29] update doc comments --- python/tvm/relay/backend/graph_runtime_factory.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/backend/graph_runtime_factory.py b/python/tvm/relay/backend/graph_runtime_factory.py index 129f7a14e688..3813a5795713 100644 --- a/python/tvm/relay/backend/graph_runtime_factory.py +++ b/python/tvm/relay/backend/graph_runtime_factory.py @@ -22,7 +22,7 @@ def create(graph_json_str, libmod, libmod_name, params): - """Create a runtime executor module given a graph and module. + """Create a runtime executor module. Parameters ---------- graph_json_str : str or graph class @@ -38,7 +38,7 @@ def create(graph_json_str, libmod, libmod_name, params): Returns ------- - graph_module : GraphRuntimeFactoryModule + graph_module : Module Runtime graph runtime factory module. """ if not isinstance(graph_json_str, string_types): @@ -80,11 +80,7 @@ def __init__(self, graph_json_str, libmod, libmod_name, params): self.iter_cnt = 0 self.module = create(graph_json_str, libmod, libmod_name, params) - def export_library(self, - file_name, - fcompile=None, - addons=None, - **kwargs): + def export_library(self, file_name, fcompile=None, addons=None, **kwargs): return self.module.export_library(file_name, fcompile, addons, **kwargs) # Sometimes we want to get params explicitly. From d7f44a9b5bcc38482737725c4e99e5cebda1191e Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Mon, 13 Jul 2020 16:20:49 +0800 Subject: [PATCH 26/29] add get_json for debug graph runtime --- python/tvm/relay/backend/graph_runtime_factory.py | 3 +++ tests/python/unittest/test_runtime_module_based_interface.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/backend/graph_runtime_factory.py b/python/tvm/relay/backend/graph_runtime_factory.py index 3813a5795713..80a64869f217 100644 --- a/python/tvm/relay/backend/graph_runtime_factory.py +++ b/python/tvm/relay/backend/graph_runtime_factory.py @@ -89,6 +89,9 @@ def export_library(self, file_name, fcompile=None, addons=None, **kwargs): def get_params(self): return self.params + def get_json(self): + return self.graph_json + def __getitem__(self, item): return self.module.__getitem__(item) diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index 7f38b0a32741..5ab4e829f2ed 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -504,7 +504,8 @@ def test_debug_graph_runtime(): tvm.testing.assert_allclose(out, verify(data), atol=1e-5) # debug graph runtime wrapper - debug_g_mod = debug_runtime.GraphModuleDebug(complied_graph_lib['debug_create']('default', ctx), [ctx], complied_graph_lib['get_json'](), None) + debug_g_mod = debug_runtime.GraphModuleDebug(complied_graph_lib['debug_create']('default', ctx), [ctx], + complied_graph_lib.get_json(), None) debug_g_mod.set_input("data", data) debug_g_mod.run() out = debug_g_mod.get_output(0).asnumpy() From 5c23a07801c203e5de682c40fce31d965f415863 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Mon, 13 Jul 2020 19:23:24 +0800 Subject: [PATCH 27/29] comment fix --- python/tvm/relay/backend/graph_runtime_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/backend/graph_runtime_factory.py b/python/tvm/relay/backend/graph_runtime_factory.py index 80a64869f217..ef635e5f8a73 100644 --- a/python/tvm/relay/backend/graph_runtime_factory.py +++ b/python/tvm/relay/backend/graph_runtime_factory.py @@ -85,7 +85,7 @@ def export_library(self, file_name, fcompile=None, addons=None, **kwargs): # Sometimes we want to get params explicitly. # For example, we want to save its params value to - # an independentfile. + # an independent file. def get_params(self): return self.params From 54bbdc8fc430976e730bd379e7c714968eb1adbd Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Mon, 13 Jul 2020 22:37:17 +0800 Subject: [PATCH 28/29] Trigger CI From a83333ae0748bf86f47b550cb2be7bbf33dc0761 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Tue, 14 Jul 2020 11:51:27 +0800 Subject: [PATCH 29/29] update --- python/tvm/contrib/debugger/debug_runtime.py | 13 ++--- python/tvm/contrib/graph_runtime.py | 12 ++--- .../relay/backend/graph_runtime_factory.py | 51 ++++--------------- src/runtime/module.cc | 15 +----- 4 files changed, 22 insertions(+), 69 deletions(-) diff --git a/python/tvm/contrib/debugger/debug_runtime.py b/python/tvm/contrib/debugger/debug_runtime.py index 848d7f57d1de..1f96a86c851a 100644 --- a/python/tvm/contrib/debugger/debug_runtime.py +++ b/python/tvm/contrib/debugger/debug_runtime.py @@ -35,10 +35,10 @@ def create(graph_json_str, libmod, ctx, dump_root=None): Parameters ---------- - graph_json_str : str or graph class + graph_json_str : str The graph to be deployed in json format output by graph compiler. - The graph can only contain one operator(tvm_op) that - points to the name of PackedFunc in the libmod. + The graph can contain operator(tvm_op) that points to the name + of PackedFunc in the libmod. libmod : tvm.Module The module of the corresponding function. @@ -54,11 +54,8 @@ def create(graph_json_str, libmod, ctx, dump_root=None): graph_module : GraphModuleDebug Debug Runtime graph module that can be used to execute the graph. """ - if not isinstance(graph_json_str, string_types): - try: - graph_json_str = graph_json_str._tvm_graph_json() - except AttributeError: - raise ValueError("Type %s is not supported" % type(graph_json_str)) + assert isinstance(graph_json_str, string_types) + try: ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx) if num_rpc_ctx == len(ctx): diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 9b714a84b541..ec102f5b4796 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -29,10 +29,10 @@ def create(graph_json_str, libmod, ctx): Parameters ---------- - graph_json_str : str or graph class + graph_json_str : str The graph to be deployed in json format output by json graph. - The graph can only contain one operator(tvm_op) that - points to the name of PackedFunc in the libmod. + The graph can contain operator(tvm_op) that points to the name + of PackedFunc in the libmod. libmod : tvm.runtime.Module The module of the corresponding function @@ -48,11 +48,7 @@ def create(graph_json_str, libmod, ctx): graph_module : GraphModule Runtime graph module that can be used to execute the graph. """ - if not isinstance(graph_json_str, string_types): - try: - graph_json_str = graph_json_str._tvm_graph_json() - except AttributeError: - raise ValueError("Type %s is not supported" % type(graph_json_str)) + assert isinstance(graph_json_str, string_types) ctx, num_rpc_ctx, device_type_id = get_device_ctx(libmod, ctx) diff --git a/python/tvm/relay/backend/graph_runtime_factory.py b/python/tvm/relay/backend/graph_runtime_factory.py index ef635e5f8a73..f7ed122128f7 100644 --- a/python/tvm/relay/backend/graph_runtime_factory.py +++ b/python/tvm/relay/backend/graph_runtime_factory.py @@ -20,50 +20,16 @@ from tvm._ffi.registry import get_global_func from tvm.runtime import ndarray - -def create(graph_json_str, libmod, libmod_name, params): - """Create a runtime executor module. - Parameters - ---------- - graph_json_str : str or graph class - The graph to be deployed in json format output by nnvm graph. - The graph can only contain one operator(tvm_op) that - points to the name of PackedFunc in the libmod. - libmod : tvm.Module - The module of the corresponding function - libmod_name: str - The name of module - params : dict of str to NDArray - The parameters of module - - Returns - ------- - graph_module : Module - Runtime graph runtime factory module. - """ - if not isinstance(graph_json_str, string_types): - try: - graph_json_str = graph_json_str._tvm_graph_json() - except AttributeError: - raise ValueError("Type %s is not supported" % type(graph_json_str)) - fcreate = get_global_func("tvm.graph_runtime_factory.create") - args = [] - for k, v in params.items(): - args.append(k) - args.append(ndarray.array(v)) - return fcreate(graph_json_str, libmod, libmod_name, *args) - - class GraphRuntimeFactoryModule(object): """Graph runtime factory module. This is a module of graph runtime factory Parameters ---------- - graph_json_str : str or graph class - The graph to be deployed in json format output by nnvm graph. - The graph can only contain one operator(tvm_op) that - points to the name of PackedFunc in the libmod. + graph_json_str : str + The graph to be deployed in json format output by graph compiler. + The graph can contain operator(tvm_op) that points to the name of + PackedFunc in the libmod. libmod : tvm.Module The module of the corresponding function libmod_name: str @@ -73,12 +39,18 @@ class GraphRuntimeFactoryModule(object): """ def __init__(self, graph_json_str, libmod, libmod_name, params): + assert isinstance(graph_json_str, string_types) + fcreate = get_global_func("tvm.graph_runtime_factory.create") + args = [] + for k, v in params.items(): + args.append(k) + args.append(ndarray.array(v)) + self.module = fcreate(graph_json_str, libmod, libmod_name, *args) self.graph_json = graph_json_str self.lib = libmod self.libmod_name = libmod_name self.params = params self.iter_cnt = 0 - self.module = create(graph_json_str, libmod, libmod_name, params) def export_library(self, file_name, fcompile=None, addons=None, **kwargs): return self.module.export_library(file_name, fcompile, addons, **kwargs) @@ -102,7 +74,6 @@ def __iter__(self): DeprecationWarning, 2) return self - def __next__(self): if self.iter_cnt > 2: raise StopIteration diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 769e9d9719f6..8052467e2dde 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -66,19 +66,8 @@ PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) PackedFunc pf = self->GetFunction(name, GetObjectPtr(this)); if (pf != nullptr) return pf; if (query_imports) { - std::unordered_set visited{self}; - std::vector stack{self}; - while (!stack.empty()) { - ModuleNode* n = stack.back(); - stack.pop_back(); - for (Module& m : n->imports_) { - ModuleNode* next = m.operator->(); - if (visited.count(next)) continue; - pf = m->GetFunction(name, m.data_); - if (pf != nullptr) return pf; - visited.insert(next); - stack.push_back(next); - } + for (Module& m : self->imports_) { + pf = m.operator->()->GetFunction(name, query_imports); } } return pf;