diff --git a/apps/android_deploy/app/src/main/jni/tvm_runtime.h b/apps/android_deploy/app/src/main/jni/tvm_runtime.h index 573612b93bc25..0d038fb1060cd 100644 --- a/apps/android_deploy/app/src/main/jni/tvm_runtime.h +++ b/apps/android_deploy/app/src/main/jni/tvm_runtime.h @@ -27,12 +27,12 @@ #include "../src/runtime/c_runtime_api.cc" #include "../src/runtime/cpu_device_api.cc" #include "../src/runtime/workspace_pool.cc" -#include "../src/runtime/module_util.cc" -#include "../src/runtime/system_lib_module.cc" +#include "../src/runtime/library_module.cc" +#include "../src/runtime/system_library.cc" #include "../src/runtime/module.cc" #include "../src/runtime/registry.cc" #include "../src/runtime/file_util.cc" -#include "../src/runtime/dso_module.cc" +#include "../src/runtime/dso_library.cc" #include "../src/runtime/thread_pool.cc" #include "../src/runtime/object.cc" #include "../src/runtime/threading_backend.cc" diff --git a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index 8da257176ca3c..5d2bca2e216dd 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -39,12 +39,12 @@ #include "../src/runtime/c_runtime_api.cc" #include "../src/runtime/cpu_device_api.cc" #include "../src/runtime/workspace_pool.cc" -#include "../src/runtime/module_util.cc" -#include "../src/runtime/system_lib_module.cc" +#include "../src/runtime/library_module.cc" +#include "../src/runtime/system_library.cc" #include "../src/runtime/module.cc" #include "../src/runtime/registry.cc" #include "../src/runtime/file_util.cc" -#include "../src/runtime/dso_module.cc" +#include "../src/runtime/dso_library.cc" #include "../src/runtime/rpc/rpc_session.cc" #include "../src/runtime/rpc/rpc_event_impl.cc" #include "../src/runtime/rpc/rpc_server_env.cc" diff --git a/apps/howto_deploy/run_example.sh b/apps/howto_deploy/run_example.sh index 6bbdc1f2cb4ad..ab95f157d7dc4 100755 --- a/apps/howto_deploy/run_example.sh +++ b/apps/howto_deploy/run_example.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,4 +29,4 @@ echo "Run the cpp deployment with all in normal library..." lib/cpp_deploy_normal echo "Run the python deployment with all in normal library..." -python python_deploy.py +python3 python_deploy.py diff --git a/apps/howto_deploy/tvm_runtime_pack.cc b/apps/howto_deploy/tvm_runtime_pack.cc index 67c9a9d647165..d166eaf756a56 100644 --- a/apps/howto_deploy/tvm_runtime_pack.cc +++ b/apps/howto_deploy/tvm_runtime_pack.cc @@ -40,7 +40,7 @@ #include "../../src/runtime/c_runtime_api.cc" #include "../../src/runtime/cpu_device_api.cc" #include "../../src/runtime/workspace_pool.cc" -#include "../../src/runtime/module_util.cc" +#include "../../src/runtime/library_module.cc" #include "../../src/runtime/module.cc" #include "../../src/runtime/registry.cc" #include "../../src/runtime/file_util.cc" @@ -55,8 +55,8 @@ // Likely we only need to enable one of the following // If you use Module::Load, use dso_module // For system packed library, use system_lib_module -#include "../../src/runtime/dso_module.cc" -#include "../../src/runtime/system_lib_module.cc" +#include "../../src/runtime/dso_library.cc" +#include "../../src/runtime/system_library.cc" // Graph runtime #include "../../src/runtime/graph/graph_runtime.cc" diff --git a/apps/ios_rpc/tvmrpc/TVMRuntime.mm b/apps/ios_rpc/tvmrpc/TVMRuntime.mm index a98862abd94b4..d593eef922b86 100644 --- a/apps/ios_rpc/tvmrpc/TVMRuntime.mm +++ b/apps/ios_rpc/tvmrpc/TVMRuntime.mm @@ -27,12 +27,12 @@ #include "../../../src/runtime/workspace_pool.cc" #include "../../../src/runtime/thread_pool.cc" #include "../../../src/runtime/threading_backend.cc" -#include "../../../src/runtime/module_util.cc" -#include "../../../src/runtime/system_lib_module.cc" +#include "../../../src/runtime/library_module.cc" +#include "../../../src/runtime/system_library.cc" #include "../../../src/runtime/module.cc" #include "../../../src/runtime/registry.cc" #include "../../../src/runtime/file_util.cc" -#include "../../../src/runtime/dso_module.cc" +#include "../../../src/runtime/dso_library.cc" #include "../../../src/runtime/ndarray.cc" #include "../../../src/runtime/object.cc" diff --git a/golang/src/tvm_runtime_pack.cc b/golang/src/tvm_runtime_pack.cc index c8be428c2fcb7..416067dcdca1a 100644 --- a/golang/src/tvm_runtime_pack.cc +++ b/golang/src/tvm_runtime_pack.cc @@ -24,7 +24,7 @@ #include "src/runtime/c_runtime_api.cc" #include "src/runtime/cpu_device_api.cc" #include "src/runtime/workspace_pool.cc" -#include "src/runtime/module_util.cc" +#include "src/runtime/library_module.cc" #include "src/runtime/module.cc" #include "src/runtime/registry.cc" #include "src/runtime/file_util.cc" @@ -39,8 +39,8 @@ // Likely we only need to enable one of the following // If you use Module::Load, use dso_module // For system packed library, use system_lib_module -#include "src/runtime/dso_module.cc" -#include "src/runtime/system_lib_module.cc" +#include "src/runtime/dso_library.cc" +#include "src/runtime/system_library.cc" // Graph runtime #include "src/runtime/graph/graph_runtime.cc" diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index 5d4c9439215a2..ef35c71699d4e 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -201,6 +201,7 @@ class ModuleNode : public Object { protected: friend class Module; + friend class ModuleInternal; /*! \brief The modules this module depend on */ std::vector imports_; diff --git a/src/codegen/llvm/llvm_module.cc b/src/codegen/llvm/llvm_module.cc index b2ad4c2a4990a..c3401c9f1f34e 100644 --- a/src/codegen/llvm/llvm_module.cc +++ b/src/codegen/llvm/llvm_module.cc @@ -28,7 +28,7 @@ #include "llvm_common.h" #include "codegen_llvm.h" #include "../../runtime/file_util.h" -#include "../../runtime/module_util.h" +#include "../../runtime/library_module.h" namespace tvm { namespace codegen { @@ -286,7 +286,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { *ctx_addr = this; } runtime::InitContextFunctions([this](const char *name) { - return GetGlobalAddr(name); + return reinterpret_cast(GetGlobalAddr(name)); }); } // Get global address from execution engine. diff --git a/src/runtime/dso_module.cc b/src/runtime/dso_library.cc similarity index 59% rename from src/runtime/dso_module.cc rename to src/runtime/dso_library.cc index 4e189573ffa56..4df5c0552a3d6 100644 --- a/src/runtime/dso_module.cc +++ b/src/runtime/dso_library.cc @@ -18,14 +18,14 @@ */ /*! - * \file dso_module.cc - * \brief Module to load from dynamic shared library. + * \file dso_libary.cc + * \brief Create library module to load from dynamic shared library. */ #include #include #include #include -#include "module_util.h" +#include "library_module.h" #if defined(_WIN32) #include @@ -36,51 +36,19 @@ namespace tvm { namespace runtime { -// Module to load from dynamic shared libary. +// Dynamic shared libary. // This is the default module TVM used for host-side AOT -class DSOModuleNode final : public ModuleNode { +class DSOLibrary final : public Library { public: - ~DSOModuleNode() { + ~DSOLibrary() { if (lib_handle_) Unload(); } - - const char* type_key() const final { - return "dso"; - } - - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { - BackendPackedCFunc faddr; - if (name == runtime::symbol::tvm_module_main) { - const char* entry_name = reinterpret_cast( - GetSymbol(runtime::symbol::tvm_module_main)); - CHECK(entry_name!= nullptr) - << "Symbol " << runtime::symbol::tvm_module_main << " is not presented"; - faddr = reinterpret_cast(GetSymbol(entry_name)); - } else { - faddr = reinterpret_cast(GetSymbol(name.c_str())); - } - if (faddr == nullptr) return PackedFunc(); - return WrapPackedFunc(faddr, sptr_to_self); - } - void Init(const std::string& name) { Load(name); - if (auto *ctx_addr = - reinterpret_cast(GetSymbol(runtime::symbol::tvm_module_ctx))) { - *ctx_addr = this; - } - InitContextFunctions([this](const char* fname) { - return GetSymbol(fname); - }); - // Load the imported modules - const char* dev_mblob = - reinterpret_cast( - GetSymbol(runtime::symbol::tvm_dev_mblob)); - if (dev_mblob != nullptr) { - ImportModuleBlob(dev_mblob, &imports_); - } + } + + void* GetSymbol(const char* name) final { + return GetSymbol_(name); } private: @@ -88,6 +56,12 @@ class DSOModuleNode final : public ModuleNode { #if defined(_WIN32) // library handle HMODULE lib_handle_{nullptr}; + + void* GetSymbol_(const char* name) { + return reinterpret_cast( + GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) + } + // Load the library void Load(const std::string& name) { // use wstring version that is needed by LLVM. @@ -96,12 +70,10 @@ class DSOModuleNode final : public ModuleNode { CHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name; } - void* GetSymbol(const char* name) { - return reinterpret_cast( - GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) - } + void Unload() { FreeLibrary(lib_handle_); + lib_handle_ = nullptr; } #else // Library handle @@ -113,20 +85,23 @@ class DSOModuleNode final : public ModuleNode { << "Failed to load dynamic shared library " << name << " " << dlerror(); } - void* GetSymbol(const char* name) { + + void* GetSymbol_(const char* name) { return dlsym(lib_handle_, name); } + void Unload() { dlclose(lib_handle_); + lib_handle_ = nullptr; } #endif }; TVM_REGISTER_GLOBAL("module.loadfile_so") .set_body([](TVMArgs args, TVMRetValue* rv) { - auto n = make_object(); + auto n = make_object(); n->Init(args[0]); - *rv = runtime::Module(n); + *rv = CreateModuleFromLibrary(n); }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc new file mode 100644 index 0000000000000..e770a0a0a3611 --- /dev/null +++ b/src/runtime/library_module.cc @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file module_util.cc + * \brief Utilities for module. + */ +#ifndef _LIBCPP_SGX_CONFIG +#include +#endif +#include +#include +#include +#include "library_module.h" + +namespace tvm { +namespace runtime { + +// Library module that exposes symbols from a library. +class LibraryModuleNode final : public ModuleNode { + public: + explicit LibraryModuleNode(ObjectPtr lib) + : lib_(lib) { + } + + const char* type_key() const final { + return "library"; + } + + PackedFunc GetFunction( + const std::string& name, + const ObjectPtr& sptr_to_self) final { + BackendPackedCFunc faddr; + if (name == runtime::symbol::tvm_module_main) { + const char* entry_name = reinterpret_cast( + lib_->GetSymbol(runtime::symbol::tvm_module_main)); + CHECK(entry_name!= nullptr) + << "Symbol " << runtime::symbol::tvm_module_main << " is not presented"; + faddr = reinterpret_cast(lib_->GetSymbol(entry_name)); + } else { + faddr = reinterpret_cast(lib_->GetSymbol(name.c_str())); + } + if (faddr == nullptr) return PackedFunc(); + return WrapPackedFunc(faddr, sptr_to_self); + } + + private: + ObjectPtr lib_; +}; + +/*! + * \brief Helper classes to get into internal of a module. + */ +class ModuleInternal { + public: + // Get mutable reference of imports. + static std::vector* GetImportsAddr(ModuleNode* node) { + return &(node->imports_); + } +}; + +PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, + const ObjectPtr& sptr_to_self) { + return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { + int ret = (*faddr)( + const_cast(args.values), + const_cast(args.type_codes), + args.num_args); + CHECK_EQ(ret, 0) << TVMGetLastError(); + }); +} + +void InitContextFunctions(std::function fgetsymbol) { + #define TVM_INIT_CONTEXT_FUNC(FuncName) \ + if (auto *fp = reinterpret_cast \ + (fgetsymbol("__" #FuncName))) { \ + *fp = FuncName; \ + } + // Initialize the functions + TVM_INIT_CONTEXT_FUNC(TVMFuncCall); + TVM_INIT_CONTEXT_FUNC(TVMAPISetLastError); + TVM_INIT_CONTEXT_FUNC(TVMBackendGetFuncFromEnv); + TVM_INIT_CONTEXT_FUNC(TVMBackendAllocWorkspace); + TVM_INIT_CONTEXT_FUNC(TVMBackendFreeWorkspace); + TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch); + TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier); + + #undef TVM_INIT_CONTEXT_FUNC +} + +/*! + * \brief Load and append module blob to module list + * \param mblob The module blob. + * \param module_list The module list to append to + */ +void ImportModuleBlob(const char* mblob, std::vector* mlist) { +#ifndef _LIBCPP_SGX_CONFIG + CHECK(mblob != nullptr); + uint64_t nbytes = 0; + for (size_t i = 0; i < sizeof(nbytes); ++i) { + uint64_t c = mblob[i]; + nbytes |= (c & 0xffUL) << (i * 8); + } + dmlc::MemoryFixedSizeStream fs( + const_cast(mblob + sizeof(nbytes)), static_cast(nbytes)); + dmlc::Stream* stream = &fs; + uint64_t size; + CHECK(stream->Read(&size)); + for (uint64_t i = 0; i < size; ++i) { + std::string tkey; + CHECK(stream->Read(&tkey)); + std::string fkey = "module.loadbinary_" + tkey; + const PackedFunc* f = Registry::Get(fkey); + CHECK(f != nullptr) + << "Loader of " << tkey << "(" + << fkey << ") is not presented."; + Module m = (*f)(static_cast(stream)); + mlist->push_back(m); + } +#else + LOG(FATAL) << "SGX does not support ImportModuleBlob"; +#endif +} + +Module CreateModuleFromLibrary(ObjectPtr lib) { + InitContextFunctions([lib](const char* fname) { + return lib->GetSymbol(fname); + }); + auto n = make_object(lib); + // Load the imported modules + const char* dev_mblob = + reinterpret_cast( + lib->GetSymbol(runtime::symbol::tvm_dev_mblob)); + if (dev_mblob != nullptr) { + ImportModuleBlob( + dev_mblob, ModuleInternal::GetImportsAddr(n.operator->())); + } + + Module root_mod = Module(n); + // allow lookup of symbol from root(so all symbols are visible). + if (auto *ctx_addr = + reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) { + *ctx_addr = root_mod.operator->(); + } + return root_mod; +} +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/module_util.h b/src/runtime/library_module.h similarity index 55% rename from src/runtime/module_util.h rename to src/runtime/library_module.h index 5f56c150588a1..e5f5ad94893bb 100644 --- a/src/runtime/module_util.h +++ b/src/runtime/library_module.h @@ -18,17 +18,16 @@ */ /*! - * \file module_util.h - * \brief Helper utilities for module building + * \file library_module.h + * \brief Module that builds from a libary of symbols. */ -#ifndef TVM_RUNTIME_MODULE_UTIL_H_ -#define TVM_RUNTIME_MODULE_UTIL_H_ +#ifndef TVM_RUNTIME_LIBRARY_MODULE_H_ +#define TVM_RUNTIME_LIBRARY_MODULE_H_ #include #include #include -#include -#include +#include extern "C" { // Function signature for generated packed function in shared library @@ -39,42 +38,48 @@ typedef int (*BackendPackedCFunc)(void* args, namespace tvm { namespace runtime { +/*! + * \brief Library is the common interface + * for storing data in the form of shared libaries. + * + * \sa dso_library.cc + * \sa system_library.cc + */ +class Library : public Object { + public: + /*! + * \brief Get the symbol address for a given name. + * \param name The name of the symbol. + * \return The symbol. + */ + virtual void *GetSymbol(const char* name) = 0; + // NOTE: we do not explicitly create an type index and type_key here for libary. + // This is because we do not need dynamic type downcasting. +}; + /*! * \brief Wrap a BackendPackedCFunc to packed function. * \param faddr The function address * \param mptr The module pointer node. */ PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, const ObjectPtr& mptr); -/*! - * \brief Load and append module blob to module list - * \param mblob The module blob. - * \param module_list The module list to append to - */ -void ImportModuleBlob(const char* mblob, std::vector* module_list); /*! * \brief Utility to initialize conext function symbols during startup - * \param flookup A symbol lookup function. - * \tparam FLookup a function of signature string->void* + * \param fgetsymbol A symbol lookup function. */ -template -void InitContextFunctions(FLookup flookup) { - #define TVM_INIT_CONTEXT_FUNC(FuncName) \ - if (auto *fp = reinterpret_cast \ - (flookup("__" #FuncName))) { \ - *fp = FuncName; \ - } - // Initialize the functions - TVM_INIT_CONTEXT_FUNC(TVMFuncCall); - TVM_INIT_CONTEXT_FUNC(TVMAPISetLastError); - TVM_INIT_CONTEXT_FUNC(TVMBackendGetFuncFromEnv); - TVM_INIT_CONTEXT_FUNC(TVMBackendAllocWorkspace); - TVM_INIT_CONTEXT_FUNC(TVMBackendFreeWorkspace); - TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch); - TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier); +void InitContextFunctions(std::function fgetsymbol); - #undef TVM_INIT_CONTEXT_FUNC -} +/*! + * \brief Create a module from a library. + * + * \param lib The library. + * \return The corresponding loaded module. + * + * \note This function can create multiple linked modules + * by parsing the binary blob section of the library. + */ +Module CreateModuleFromLibrary(ObjectPtr lib); } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_MODULE_UTIL_H_ +#endif // TVM_RUNTIME_LIBRARY_MODULE_H_ diff --git a/src/runtime/module_util.cc b/src/runtime/module_util.cc deleted file mode 100644 index 445bfd343653f..0000000000000 --- a/src/runtime/module_util.cc +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file module_util.cc - * \brief Utilities for module. - */ -#ifndef _LIBCPP_SGX_CONFIG -#include -#endif -#include -#include -#include -#include -#include "module_util.h" - -namespace tvm { -namespace runtime { - -void ImportModuleBlob(const char* mblob, std::vector* mlist) { -#ifndef _LIBCPP_SGX_CONFIG - CHECK(mblob != nullptr); - uint64_t nbytes = 0; - for (size_t i = 0; i < sizeof(nbytes); ++i) { - uint64_t c = mblob[i]; - nbytes |= (c & 0xffUL) << (i * 8); - } - dmlc::MemoryFixedSizeStream fs( - const_cast(mblob + sizeof(nbytes)), static_cast(nbytes)); - dmlc::Stream* stream = &fs; - uint64_t size; - CHECK(stream->Read(&size)); - for (uint64_t i = 0; i < size; ++i) { - std::string tkey; - CHECK(stream->Read(&tkey)); - std::string fkey = "module.loadbinary_" + tkey; - const PackedFunc* f = Registry::Get(fkey); - CHECK(f != nullptr) - << "Loader of " << tkey << "(" - << fkey << ") is not presented."; - Module m = (*f)(static_cast(stream)); - mlist->push_back(m); - } -#else - LOG(FATAL) << "SGX does not support ImportModuleBlob"; -#endif -} - -PackedFunc WrapPackedFunc(BackendPackedCFunc faddr, - const ObjectPtr& sptr_to_self) { - return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - int ret = (*faddr)( - const_cast(args.values), - const_cast(args.type_codes), - args.num_args); - CHECK_EQ(ret, 0) << TVMGetLastError(); - }); -} - -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/stackvm/stackvm_module.cc b/src/runtime/stackvm/stackvm_module.cc index 4f86d0764ebf9..b73c7ceaa85af 100644 --- a/src/runtime/stackvm/stackvm_module.cc +++ b/src/runtime/stackvm/stackvm_module.cc @@ -28,7 +28,6 @@ #include #include "stackvm_module.h" #include "../file_util.h" -#include "../module_util.h" namespace tvm { namespace runtime { diff --git a/src/runtime/system_lib_module.cc b/src/runtime/system_lib_module.cc deleted file mode 100644 index 8a75a36ca49f4..0000000000000 --- a/src/runtime/system_lib_module.cc +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file system_lib_module.cc - * \brief SystemLib module. - */ -#include -#include -#include -#include -#include "module_util.h" - -namespace tvm { -namespace runtime { - -class SystemLibModuleNode : public ModuleNode { - public: - SystemLibModuleNode() = default; - - const char* type_key() const final { - return "system_lib"; - } - - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { - std::lock_guard lock(mutex_); - - if (module_blob_ != nullptr) { - // If we previously recorded submodules, load them now. - ImportModuleBlob(reinterpret_cast(module_blob_), &imports_); - module_blob_ = nullptr; - } - - auto it = tbl_.find(name); - if (it != tbl_.end()) { - return WrapPackedFunc( - reinterpret_cast(it->second), sptr_to_self); - } else { - return PackedFunc(); - } - } - - void RegisterSymbol(const std::string& name, void* ptr) { - std::lock_guard lock(mutex_); - if (name == symbol::tvm_module_ctx) { - void** ctx_addr = reinterpret_cast(ptr); - *ctx_addr = this; - } else if (name == symbol::tvm_dev_mblob) { - // Record pointer to content of submodules to be loaded. - // We defer loading submodules to the first call to GetFunction(). - // The reason is that RegisterSymbol() gets called when initializing the - // syslib (i.e. library loading time), and the registeries aren't ready - // yet. Therefore, we might not have the functionality to load submodules - // now. - CHECK(module_blob_ == nullptr) << "Resetting mobule blob?"; - module_blob_ = ptr; - } else { - auto it = tbl_.find(name); - if (it != tbl_.end() && ptr != it->second) { - LOG(WARNING) << "SystemLib symbol " << name - << " get overriden to a different address " - << ptr << "->" << it->second; - } - tbl_[name] = ptr; - } - } - - static const ObjectPtr& Global() { - static auto inst = make_object(); - return inst; - } - - private: - // Internal mutex - std::mutex mutex_; - // Internal symbol table - std::unordered_map tbl_; - // Module blob to be imported - void* module_blob_{nullptr}; -}; - -TVM_REGISTER_GLOBAL("module._GetSystemLib") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = runtime::Module(SystemLibModuleNode::Global()); - }); -} // namespace runtime -} // namespace tvm - -int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr) { - tvm::runtime::SystemLibModuleNode::Global()->RegisterSymbol(name, ptr); - return 0; -} diff --git a/src/runtime/system_library.cc b/src/runtime/system_library.cc new file mode 100644 index 0000000000000..b9d751d09dc42 --- /dev/null +++ b/src/runtime/system_library.cc @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file system_library.cc + * \brief Create library module that directly get symbol from the system lib. + */ +#include +#include +#include +#include +#include "library_module.h" + +namespace tvm { +namespace runtime { + +class SystemLibrary : public Library { + public: + SystemLibrary() = default; + + void* GetSymbol(const char* name) final { + std::lock_guard lock(mutex_); + auto it = tbl_.find(name); + if (it != tbl_.end()) { + return it->second; + } else { + return nullptr; + } + } + + void RegisterSymbol(const std::string& name, void* ptr) { + std::lock_guard lock(mutex_); + auto it = tbl_.find(name); + if (it != tbl_.end() && ptr != it->second) { + LOG(WARNING) + << "SystemLib symbol " << name + << " get overriden to a different address " + << ptr << "->" << it->second; + } + tbl_[name] = ptr; + } + + static const ObjectPtr& Global() { + static auto inst = make_object(); + return inst; + } + + private: + // Internal mutex + std::mutex mutex_; + // Internal symbol table + std::unordered_map tbl_; +}; + +TVM_REGISTER_GLOBAL("module._GetSystemLib") +.set_body([](TVMArgs args, TVMRetValue* rv) { + static auto mod = CreateModuleFromLibrary( + SystemLibrary::Global()); + *rv = mod; + }); +} // namespace runtime +} // namespace tvm + +int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr) { + tvm::runtime::SystemLibrary::Global()->RegisterSymbol(name, ptr); + return 0; +} diff --git a/web/web_runtime.cc b/web/web_runtime.cc index 63284bd8c2cc0..d5b40889472cb 100644 --- a/web/web_runtime.cc +++ b/web/web_runtime.cc @@ -26,14 +26,14 @@ #include "../src/runtime/c_runtime_api.cc" #include "../src/runtime/cpu_device_api.cc" #include "../src/runtime/workspace_pool.cc" -#include "../src/runtime/module_util.cc" -#include "../src/runtime/system_lib_module.cc" +#include "../src/runtime/library_module.cc" +#include "../src/runtime/system_library.cc" #include "../src/runtime/module.cc" #include "../src/runtime/ndarray.cc" #include "../src/runtime/object.cc" #include "../src/runtime/registry.cc" #include "../src/runtime/file_util.cc" -#include "../src/runtime/dso_module.cc" +#include "../src/runtime/dso_library.cc" #include "../src/runtime/rpc/rpc_session.cc" #include "../src/runtime/rpc/rpc_event_impl.cc" #include "../src/runtime/rpc/rpc_server_env.cc"