Skip to content

Commit

Permalink
Handle empty LLVMModule in GetFunction
Browse files Browse the repository at this point in the history
  • Loading branch information
kumasento committed Mar 25, 2020
1 parent 3aabbd9 commit c93294b
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 13 deletions.
10 changes: 8 additions & 2 deletions cmake/modules/contrib/DNNL.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,18 @@
# specific language governing permissions and limitations
# under the License.

if(USE_DNNL_CODEGEN STREQUAL "ON")
if(NOT USE_DNNL_CODEGEN STREQUAL "OFF")
file(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/codegen.cc)
list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC})

find_library(EXTERN_LIBRARY_DNNL dnnl)
if(USE_DNNL_CODEGEN STREQUAL "ON")
find_library(EXTERN_LIBRARY_DNNL dnnl)
else()
set(DNNL_INSTALL_PATH "${USE_DNNL_CODEGEN}")
find_library(EXTERN_LIBRARY_DNNL dnnl ${DNNL_INSTALL_PATH} ${DNNL_PATH}/lib)
endif()
list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL})

file(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/*)
list(APPEND RUNTIME_SRCS ${DNNL_CONTRIB_SRC})
message(STATUS "Build with DNNL codegen: " ${EXTERN_LIBRARY_DNNL})
Expand Down
13 changes: 11 additions & 2 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
PackedFunc GetFunction(
const std::string& name,
const ObjectPtr<Object>& sptr_to_self) final {
std::cout << name << std::endl;
if (name == "__tvm_is_system_module") {
bool flag =
(mptr_->getFunction("__tvm_module_startup") != nullptr);
Expand All @@ -71,6 +72,10 @@ class LLVMModuleNode final : public runtime::ModuleNode {
});
}
if (ee_ == nullptr) LazyInitJIT();

// This LLVMModule is empty and no function can be retrieved.
if (entry_func_.empty()) return nullptr;

std::lock_guard<std::mutex> lock(mutex_);
const std::string& fname = (name == runtime::symbol::tvm_module_main ?
entry_func_ : name);
Expand Down Expand Up @@ -318,6 +323,10 @@ class LLVMModuleNode final : public runtime::ModuleNode {
<< "Failed to initialize jit engine for " << mptr_->getTargetTriple();
ee_->runStaticConstructorsDestructors(false);
// setup context address.
// we will skip context setup if this LLVMModule is empty.
if (GetGlobalAddr(runtime::symbol::tvm_module_main) == 0)
return;

entry_func_ =
reinterpret_cast<const char*>(GetGlobalAddr(runtime::symbol::tvm_module_main));
if (void** ctx_addr = reinterpret_cast<void**>(
Expand All @@ -329,15 +338,15 @@ class LLVMModuleNode final : public runtime::ModuleNode {
});
}
// Get global address from execution engine.
uint64_t GetGlobalAddr(const std::string& name) {
uint64_t GetGlobalAddr(const std::string& name) const {
// first verifies if GV exists.
if (mptr_->getGlobalVariable(name) != nullptr) {
return ee_->getGlobalValueAddress(name);
} else {
return 0;
}
}
uint64_t GetFunctionAddr(const std::string& name) {
uint64_t GetFunctionAddr(const std::string& name) const {
// first verifies if GV exists.
if (mptr_->getFunction(name) != nullptr) {
return ee_->getFunctionAddress(name);
Expand Down
19 changes: 10 additions & 9 deletions tests/python/relay/test_pass_partition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def check_vm_result():
def check_graph_runtime_result():
with relay.build_config(opt_level=3):
json, lib, param = relay.build(mod, target=target, params=params)
lib = update_lib(lib)
# lib = update_lib(lib)
rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)

for name, data in map_inputs.items():
Expand Down Expand Up @@ -365,6 +365,7 @@ def test_extern_ccompiler():


def test_extern_dnnl():
print("text_extern_dnnl")
if not tvm.get_global_func("relay.ext.dnnl", True):
print("skip because DNNL codegen is not available")
return
Expand Down Expand Up @@ -679,12 +680,12 @@ def expected():


if __name__ == "__main__":
test_multi_node_compiler()
test_extern_ccompiler_single_op()
test_extern_ccompiler_default_ops()
test_extern_ccompiler()
# test_multi_node_compiler()
# test_extern_ccompiler_single_op()
# test_extern_ccompiler_default_ops()
# test_extern_ccompiler()
test_extern_dnnl()
test_extern_dnnl_mobilenet()
test_function_lifting()
test_function_lifting_inline()
test_constant_propagation()
# test_extern_dnnl_mobilenet()
# test_function_lifting()
# test_function_lifting_inline()
# test_constant_propagation()

0 comments on commit c93294b

Please sign in to comment.