From 628b54ba9e5261eabad6495b32554406d9bfb697 Mon Sep 17 00:00:00 2001 From: Ruizhe Zhao Date: Mon, 16 Mar 2020 19:52:45 +0000 Subject: [PATCH] Return empty CSourceModule when no lowered_funcs exists in Relay mod (#4847) * Use dummy func when no lowered_funcs exists in Relay mod * Dummy func -> CSourceModule with empty code str * Added comments describing the empty CSouceModule * Always import external modules w/o assertions * Use CSourceModule as a fallback for LLVMModule * Changed cond for target == llvm * Create an empty LLVM module w/o using dummy func * Avoid using IR str concat to create LLVM module * Improved comments for codegen.LLVMModuleCreate * Satisfy the linter for LLVMModuleCreate --- src/relay/backend/build_module.cc | 47 +++++++++++++++++++++++-------- src/target/llvm/llvm_module.cc | 22 +++++++++++++++ 2 files changed, 58 insertions(+), 11 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 41833c4d4aff..d42cc27f77d0 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -28,8 +28,10 @@ #include #include #include +#include #include +#include "../../target/source/codegen_source_base.h" #include "utils.h" namespace tvm { @@ -451,28 +453,51 @@ class RelayBuildModule : public runtime::ModuleNode { ret_.params = graph_codegen_->GetParams(); auto lowered_funcs = graph_codegen_->GetLoweredFunc(); + + // When there is no lowered_funcs due to reasons such as optimization. if (lowered_funcs.size() == 0) { - LOG(WARNING) << "no lowered funcs exist in the compiled module"; + Target target_host = GetTargetHost(); + + // If no target_host has been set, we choose a default one, which is + // llvm if "codegen.LLVMModuleCreate" is accessible. + const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate"); + if (!target_host.defined()) + target_host = (pf != nullptr) ? target::llvm() : target::stackvm(); + + if (target_host.defined() && target_host->target_name == "llvm") { + // If we can decide the target is LLVM, we then create an empty LLVM module. + ret_.mod = (*pf)(target_host->str(), "empty_module"); + } else { + // If we cannot decide the target is LLVM, we create an empty CSourceModule. + // The code content is initialized with ";" to prevent complaining + // from CSourceModuleNode::SaveToFile. + ret_.mod = tvm::codegen::CSourceModuleCreate(";", ""); + } } else { ret_.mod = tvm::build( lowered_funcs, target_host_, BuildConfig::Current()); } + Array ext_mods = graph_codegen_->GetExternalModules(); - if (!ext_mods.empty()) { - CHECK(lowered_funcs.size() > 0 || ext_mods.size() == 1) - << "Expect to have a TVM DSOModule when multiple external runtime modules exist"; - if (lowered_funcs.size() == 0) { - // Execute the whole module using external runtime. - ret_.mod = ext_mods[0]; - } else { - // Import all external runtime modules. - for (const auto& it : ext_mods) { - ret_.mod.Import(it); + // Import all external runtime modules. + for (const auto& it : ext_mods) + ret_.mod.Import(it); + } + + private: + Target GetTargetHost() { + Target target_host = target_host_; + if (!target_host_.defined()) { + for (const auto &it : targets_) { + if (it.second->device_type == kDLCPU) { + target_host = it.second; + break; } } } + return target_host; } protected: diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 7f4680972299..c04b25727972 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -356,6 +356,28 @@ TVM_REGISTER_GLOBAL("codegen.build_llvm") *rv = runtime::Module(n); }); +TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") +.set_body([](TVMArgs args, TVMRetValue *rv) { + auto n = make_object(); + auto target = args[0].operator std::string(); + auto module_name = args[1].operator std::string(); + + // Generate a LLVM module from an input target string + InitializeLLVM(); + auto tm = GetLLVMTargetMachine(target); + auto ctx = std::make_shared(); + std::unique_ptr module(new llvm::Module(module_name, *ctx)); + + // Use a default data layout and target triple + auto triple = tm->getTargetTriple(); + module->setTargetTriple(triple.str()); + module->setDataLayout(tm->createDataLayout()); + + n->Init(std::move(module), ctx); + + *rv = runtime::Module(n); +}); + TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = static_cast(LookupLLVMIntrinsic(args[0]));