From 350d8b9dd471226ae3b74b7a6116f63cbf3f0513 Mon Sep 17 00:00:00 2001 From: FrozenGene Date: Tue, 17 Dec 2019 17:15:15 +0800 Subject: [PATCH] Support standardize runtime module --- python/tvm/_ffi/function.py | 3 + python/tvm/module.py | 64 ++++-- src/codegen/codegen.cc | 116 +++++++++- src/runtime/library_module.cc | 66 +++++- src/runtime/module.cc | 2 +- .../unittest/test_runtime_module_export.py | 208 ++++++++++++++++++ 6 files changed, 412 insertions(+), 47 deletions(-) create mode 100644 tests/python/unittest/test_runtime_module_export.py diff --git a/python/tvm/_ffi/function.py b/python/tvm/_ffi/function.py index 60e7aeb9aec5..ed2f7e1f62d6 100644 --- a/python/tvm/_ffi/function.py +++ b/python/tvm/_ffi/function.py @@ -82,6 +82,9 @@ def __init__(self, handle): def __del__(self): check_call(_LIB.TVMModFree(self.handle)) + def __hash__(self): + return ctypes.cast(self.handle, ctypes.c_void_p).value + @property def entry_func(self): """Get the entry function diff --git a/python/tvm/module.py b/python/tvm/module.py index 976fb2d81cc7..e9e229469831 100644 --- a/python/tvm/module.py +++ b/python/tvm/module.py @@ -118,31 +118,28 @@ def export_library(self, self.save(file_name) return - if not (self.type_key == "llvm" or self.type_key == "c"): - raise ValueError("Module[%s]: Only llvm and c support export shared" % self.type_key) + modules = self._collect_dso_modules() temp = _util.tempdir() - if fcompile is not None and hasattr(fcompile, "object_format"): - object_format = fcompile.object_format - else: - if self.type_key == "llvm": - object_format = "o" + files = [] + is_system_lib = False + has_c_module = False + for index, module in enumerate(modules): + if fcompile is not None and hasattr(fcompile, "object_format"): + object_format = fcompile.object_format else: - assert self.type_key == "c" - object_format = "cc" - path_obj = temp.relpath("lib." + object_format) - self.save(path_obj) - files = [path_obj] - is_system_lib = self.type_key == "llvm" and self.get_function("__tvm_is_system_module")() - has_imported_c_file = False + if module.type_key == "llvm": + object_format = "o" + else: + assert module.type_key == "c" + object_format = "cc" + has_c_module = True + path_obj = temp.relpath("lib" + str(index) + "." + object_format) + module.save(path_obj) + files.append(path_obj) + is_system_lib = (module.type_key == "llvm" and + module.get_function("__tvm_is_system_module")()) + if self.imported_modules: - for i, m in enumerate(self.imported_modules): - if m.type_key == "c": - has_imported_c_file = True - c_file_name = "tmp_" + str(i) + ".cc" - path_cc = temp.relpath(c_file_name) - with open(path_cc, "w") as f: - f.write(m.get_source()) - files.append(path_cc) path_cc = temp.relpath("devc.cc") with open(path_cc, "w") as f: f.write(_PackImportsToC(self, is_system_lib)) @@ -152,13 +149,15 @@ def export_library(self, fcompile = _tar.tar else: fcompile = _cc.create_shared - if self.type_key == "c" or has_imported_c_file: + + if has_c_module: options = [] if "options" in kwargs: opts = kwargs["options"] options = opts if isinstance(opts, (list, tuple)) else [opts] opts = options + ["-I" + path for path in find_include_path()] kwargs.update({'options': opts}) + fcompile(file_name, files, **kwargs) def time_evaluator(self, func_name, ctx, number=10, repeat=1, min_repeat_ms=0): @@ -219,6 +218,25 @@ def evaluator(*args): except NameError: raise NameError("time_evaluate is only supported when RPC is enabled") + def _collect_dso_modules(self): + """Helper function to collect dso modules, then return it.""" + visited, stack, dso_modules = set(), [], [] + # append root module + visited.add(self) + stack.append(self) + while stack: + module = stack.pop() + if module._dso_exportable(): + dso_modules.append(module) + for m in module.imported_modules: + if m not in visited: + visited.add(m) + stack.append(m) + return dso_modules + + def _dso_exportable(self): + return self.type_key == "llvm" or self.type_key == "c" + def system_lib(): """Get system-wide library module singleton. diff --git a/src/codegen/codegen.cc b/src/codegen/codegen.cc index 6ce76f60e0e3..60b12dc6e553 100644 --- a/src/codegen/codegen.cc +++ b/src/codegen/codegen.cc @@ -28,7 +28,10 @@ #include #include #include -#include +#include +#include +#include +#include namespace tvm { namespace codegen { @@ -58,20 +61,111 @@ runtime::Module Build(const Array& funcs, return m; } +/*! \brief Helper class to serialize module */ +class ModuleSerializer { + public: + explicit ModuleSerializer(runtime::Module mod) : mod_(mod) { + Init(); + } + + void SerializeModule(dmlc::Stream* stream) { + // Only have one DSO module and it is in the root, then + // we will not produce import_tree_. + bool has_import_tree = true; + if (DSOExportable(mod_.operator->()) && mod_->imports().empty()) { + has_import_tree = false; + } + uint64_t sz = 0; + if (has_import_tree) { + // we will append one key for _import_tree + // The layout is the same as before: binary_size, key, logic, key, logic... + sz = mod_vec_.size() + 1; + } else { + // Keep the old behaviour + sz = mod_->imports().size(); + } + stream->Write(sz); + + for (auto m : mod_vec_) { + std::string mod_type_key = m->type_key(); + if (!DSOExportable(m)) { + stream->Write(mod_type_key); + m->SaveToBinary(stream); + } else if (has_import_tree) { + mod_type_key = "_lib"; + stream->Write(mod_type_key); + } + } + + // Write _import_tree key if we have + if (has_import_tree) { + std::string import_key = "_import_tree"; + stream->Write(import_key); + stream->Write(import_tree_row_ptr_); + stream->Write(import_tree_child_indices_); + } + } + + private: + void Init() { + CreateModuleIndex(); + CreateImportTree(); + } + + // invariance: root module is always at location 0. + // The module order is collected via DFS + void CreateModuleIndex() { + std::unordered_set visited {mod_.operator->()}; + std::vector stack {mod_.operator->()}; + uint64_t module_index = 0; + + while (!stack.empty()) { + runtime::ModuleNode* n = stack.back(); + stack.pop_back(); + mod2index_[n] = module_index++; + mod_vec_.emplace_back(n); + for (runtime::Module m : n->imports()) { + runtime::ModuleNode* next = m.operator->(); + if (visited.count(next) == 0) { + visited.insert(next); + stack.push_back(next); + } + } + } + } + + void CreateImportTree() { + for (auto m : mod_vec_) { + for (runtime::Module im : m->imports()) { + uint64_t mod_index = mod2index_[im.operator->()]; + import_tree_child_indices_.push_back(mod_index); + } + import_tree_row_ptr_.push_back(import_tree_child_indices_.size()); + } + } + + bool DSOExportable(const runtime::ModuleNode* mod) { + return !std::strcmp(mod->type_key(), "llvm") || + !std::strcmp(mod->type_key(), "c"); + } + + runtime::Module mod_; + // construct module to index + std::unordered_map mod2index_; + // index -> module + std::vector mod_vec_; + std::vector import_tree_row_ptr_ {0}; + std::vector import_tree_child_indices_; +}; + std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { std::string bin; dmlc::MemoryStringStream ms(&bin); dmlc::Stream* stream = &ms; - uint64_t sz = static_cast(mod->imports().size()); - stream->Write(sz); - for (runtime::Module im : mod->imports()) { - CHECK_EQ(im->imports().size(), 0U) - << "Only support simply one-level hierarchy"; - std::string tkey = im->type_key(); - stream->Write(tkey); - if (tkey == "c") continue; - im->SaveToBinary(stream); - } + + ModuleSerializer module_serializer(mod); + module_serializer.SerializeModule(stream); + // translate to C program std::ostringstream os; os << "#ifdef _WIN32\n" diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index d3283bc19767..9aaf5b9ad390 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include "library_module.h" namespace tvm { @@ -108,9 +109,11 @@ void InitContextFunctions(std::function fgetsymbol) { /*! * \brief Load and append module blob to module list * \param mblob The module blob. - * \param module_list The module list to append to + * \param lib The library. + * + * \return Root Module. */ -void ImportModuleBlob(const char* mblob, std::vector* mlist) { +runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr lib) { #ifndef _LIBCPP_SGX_CONFIG CHECK(mblob != nullptr); uint64_t nbytes = 0; @@ -123,20 +126,56 @@ void ImportModuleBlob(const char* mblob, std::vector* mlist) { dmlc::Stream* stream = &fs; uint64_t size; CHECK(stream->Read(&size)); + std::vector modules; + std::vector import_tree_row_ptr; + std::vector import_tree_child_indices; for (uint64_t i = 0; i < size; ++i) { std::string tkey; CHECK(stream->Read(&tkey)); - if (tkey == "c") continue; - std::string fkey = "module.loadbinary_" + tkey; - const PackedFunc* f = Registry::Get(fkey); - CHECK(f != nullptr) + // Currently, _lib is for DSOModule, but we + // don't have loadbinary function for it currently + if (tkey == "_lib") { + auto dso_module = Module(make_object(lib)); + modules.emplace_back(dso_module); + } else if (tkey == "_import_tree") { + CHECK(stream->Read(&import_tree_row_ptr)); + CHECK(stream->Read(&import_tree_child_indices)); + } else { + std::string fkey = "module.loadbinary_" + tkey; + const PackedFunc* f = Registry::Get(fkey); + CHECK(f != nullptr) << "Loader of " << tkey << "(" << fkey << ") is not presented."; - Module m = (*f)(static_cast(stream)); - mlist->push_back(m); + Module m = (*f)(static_cast(stream)); + modules.emplace_back(m); + } } + // if we are using old dll, we don't have import tree + // so that we can't reconstruct module relationship using import tree + if (import_tree_row_ptr.empty()) { + auto n = make_object(lib); + auto module_import_addr = ModuleInternal::GetImportsAddr(n.operator->()); + for (const auto& m : modules) { + module_import_addr->emplace_back(m); + } + return Module(n); + } else { + for (size_t i = 0; i < modules.size(); ++i) { + for (size_t j = import_tree_row_ptr[i]; j < import_tree_row_ptr[i + 1]; ++j) { + auto module_import_addr = ModuleInternal::GetImportsAddr(modules[i].operator->()); + auto child_index = import_tree_child_indices[j]; + CHECK(child_index < modules.size()); + module_import_addr->emplace_back(modules[child_index]); + } + } + } + CHECK(!modules.empty()); + // invariance: root module is always at location 0. + // The module order is collected via DFS + return modules[0]; #else LOG(FATAL) << "SGX does not support ImportModuleBlob"; + return Module(); #endif } @@ -149,17 +188,20 @@ Module CreateModuleFromLibrary(ObjectPtr lib) { const char* dev_mblob = reinterpret_cast( lib->GetSymbol(runtime::symbol::tvm_dev_mblob)); + Module root_mod; if (dev_mblob != nullptr) { - ImportModuleBlob( - dev_mblob, ModuleInternal::GetImportsAddr(n.operator->())); + root_mod = ProcessModuleBlob(dev_mblob, lib); + } else { + // Only have one single DSO Module + root_mod = Module(n); } - Module root_mod = Module(n); - // allow lookup of symbol from root(so all symbols are visible). + // allow lookup of symbol from root (so all symbols are visible). if (auto *ctx_addr = reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) { *ctx_addr = root_mod.operator->(); } + return root_mod; } } // namespace runtime diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 161675c7ca0c..2f3e337b4ce8 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -115,7 +115,7 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) { if (it != import_cache_.end()) return it->second.get(); PackedFunc pf; for (Module& m : this->imports_) { - pf = m.GetFunction(name, false); + pf = m.GetFunction(name, true); if (pf != nullptr) break; } if (pf == nullptr) { diff --git a/tests/python/unittest/test_runtime_module_export.py b/tests/python/unittest/test_runtime_module_export.py new file mode 100644 index 000000000000..b676cf2d5244 --- /dev/null +++ b/tests/python/unittest/test_runtime_module_export.py @@ -0,0 +1,208 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from tvm import relay +from tvm.relay import testing +import tvm + +from tvm.contrib import util +header_file_dir_path = util.tempdir() + + +def gen_engine_header(): + code = r''' + #ifndef _ENGINE_H_ + #define _ENGINE_H_ + #include + #include + #include + #include + class Engine { + }; + + #endif + ''' + header_file = header_file_dir_path.relpath("gcc_engine.h") + with open(header_file, 'w') as f: + f.write(code) + + +def generate_engine_module(): + code = r''' + #include + #include + #include "gcc_engine.h" + + extern "C" void gcc_1_(float* gcc_input4, float* gcc_input5, + float* gcc_input6, float* gcc_input7, float* out) { + Engine engine; + } + ''' + gen_engine_header() + csource_module = tvm.module.csource_module_create(code, "cc") + return csource_module + + +def test_mod_export(): + def verify_gpu_mod_export(obj_format): + for device in ["llvm", "cuda"]: + if not tvm.module.enabled(device): + print("skip because %s is not enabled..." % device) + return + + resnet18_mod, resnet18_params = relay.testing.resnet.get_workload(num_layers=18) + resnet50_mod, resnet50_params = relay.testing.resnet.get_workload(num_layers=50) + with relay.build_config(opt_level=3): + _, resnet18_gpu_lib, _ = relay.build_module.build(resnet18_mod, "cuda", params=resnet18_params) + _, resnet50_cpu_lib, _ = relay.build_module.build(resnet50_mod, "llvm", params=resnet50_params) + + from tvm.contrib import util + temp = util.tempdir() + if obj_format == ".so": + file_name = "deploy_lib.so" + else: + assert obj_format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + resnet18_gpu_lib.imported_modules[0].import_module(resnet50_cpu_lib) + resnet18_gpu_lib.export_library(path_lib) + loaded_lib = tvm.module.load(path_lib) + assert loaded_lib.type_key == "library" + assert loaded_lib.imported_modules[0].type_key == "cuda" + assert loaded_lib.imported_modules[0].imported_modules[0].type_key == "library" + + def verify_multi_dso_mod_export(obj_format): + for device in ["llvm"]: + if not tvm.module.enabled(device): + print("skip because %s is not enabled..." % device) + return + + resnet18_mod, resnet18_params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + _, resnet18_cpu_lib, _ = relay.build_module.build(resnet18_mod, "llvm", params=resnet18_params) + + A = tvm.placeholder((1024,), name='A') + B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') + s = tvm.create_schedule(B.op) + f = tvm.build(s, [A, B], "llvm", name="myadd") + from tvm.contrib import util + temp = util.tempdir() + if obj_format == ".so": + file_name = "deploy_lib.so" + else: + assert obj_format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + resnet18_cpu_lib.import_module(f) + resnet18_cpu_lib.export_library(path_lib) + loaded_lib = tvm.module.load(path_lib) + assert loaded_lib.type_key == "library" + assert loaded_lib.imported_modules[0].type_key == "library" + + def verify_json_import_dso(obj_format): + for device in ["llvm"]: + if not tvm.module.enabled(device): + print("skip because %s is not enabled..." % device) + return + + # Get subgraph Json. + subgraph_json = ("json_rt_0\n" + + "input 0 10 10\n" + + "input 1 10 10\n" + + "input 2 10 10\n" + + "input 3 10 10\n" + + "add 4 inputs: 0 1 shape: 10 10\n" + + "sub 5 inputs: 4 2 shape: 10 10\n" + + "mul 6 inputs: 5 3 shape: 10 10\n" + + "json_rt_1\n" + + "input 0 10 10\n" + + "input 1 10 10\n" + + "input 2 10 10\n" + + "input 3 10 10\n" + + "add 4 inputs: 0 1 shape: 10 10\n" + + "sub 5 inputs: 4 2 shape: 10 10\n" + + "mul 6 inputs: 5 3 shape: 10 10") + + from tvm.contrib import util + temp = util.tempdir() + subgraph_path = temp.relpath('subgraph.examplejson') + with open(subgraph_path, 'w') as f: + f.write(subgraph_json) + + # Get Json and module. + A = tvm.placeholder((1024,), name='A') + B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') + s = tvm.create_schedule(B.op) + f = tvm.build(s, [A, B], "llvm", name="myadd") + try: + ext_lib = tvm.module.load(subgraph_path, "examplejson") + except: + print("skip because Loader of examplejson is not presented") + return + ext_lib.import_module(f) + if obj_format == ".so": + file_name = "deploy_lib.so" + else: + assert obj_format == ".tar" + file_name = "deploy_lib.tar" + path_lib = temp.relpath(file_name) + ext_lib.export_library(path_lib) + lib = tvm.module.load(path_lib) + assert lib.type_key == "examplejson" + assert lib.imported_modules[0].type_key == "library" + + def verify_multi_c_mod_export(): + from shutil import which + if which("gcc") is None: + print("Skip test because gcc is not available.") + + for device in ["llvm"]: + if not tvm.module.enabled(device): + print("skip because %s is not enabled..." % device) + return + + resnet18_mod, resnet18_params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + _, resnet18_cpu_lib, _ = relay.build_module.build(resnet18_mod, "llvm", params=resnet18_params) + + A = tvm.placeholder((1024,), name='A') + B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') + s = tvm.create_schedule(B.op) + f = tvm.build(s, [A, B], "c", name="myadd") + engine_module = generate_engine_module() + from tvm.contrib import util + temp = util.tempdir() + file_name = "deploy_lib.so" + path_lib = temp.relpath(file_name) + resnet18_cpu_lib.import_module(f) + resnet18_cpu_lib.import_module(engine_module) + kwargs = {"options": ["-O2", "-std=c++11", "-I" + header_file_dir_path.relpath("")]} + resnet18_cpu_lib.export_library(path_lib, fcompile=False, **kwargs) + loaded_lib = tvm.module.load(path_lib) + assert loaded_lib.type_key == "library" + assert loaded_lib.imported_modules[0].type_key == "library" + assert loaded_lib.imported_modules[1].type_key == "library" + + for obj_format in [".so", ".tar"]: + verify_gpu_mod_export(obj_format) + verify_multi_dso_mod_export(obj_format) + verify_json_import_dso(obj_format) + + verify_multi_c_mod_export() + + +if __name__ == "__main__": + test_mod_export()