Skip to content

Commit

Permalink
Create an empty LLVM module w/o using dummy func
Browse files Browse the repository at this point in the history
  • Loading branch information
kumasento committed Feb 28, 2020
1 parent d6656e2 commit 7a79f6b
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 32 deletions.
63 changes: 31 additions & 32 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ namespace relay {
namespace backend {

using tir::LoweredFunc;
using tir::MakeAPI;
using tir::EvaluateNode;

using TargetsMap = Map<tvm::Integer, tvm::Target>;
using namespace tvm::relay::transform;
Expand Down Expand Up @@ -442,39 +440,26 @@ class RelayBuildModule : public runtime::ModuleNode {

auto lowered_funcs = graph_codegen_->GetLoweredFunc();

// When there is no lowered_funcs due to reasons such as optimization,
// we first try to generate a dummy one if the target host is "llvm".
// When there is no lowered_funcs due to reasons such as optimization.
if (lowered_funcs.size() == 0) {
// Decide first the target host
Target target_host_val = target_host_;
if (!target_host_.defined()) {
for (const auto &it : targets_) {
if (it.second->device_type == kDLCPU) {
target_host_val = it.second;
break;
}
}
}
Target target_host = GetTargetHost();

// If no target_host has been set, we choose a default one, which is
// llvm if "codegen.build_llvm" is accessible.
const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.build_llvm");
if (!target_host_val.defined())
target_host_val = (pf != nullptr) ? target::llvm() : target::stackvm();

if (target_host_val.defined() && target_host_val->target_name == "llvm")
lowered_funcs.Set(
target_host_val->str(),
Array<LoweredFunc>({
MakeAPI(EvaluateNode::make(0), "__dummy__", Array<ObjectRef>(), 0, false) }));
}

if (lowered_funcs.size() == 0) {
// If there is still no lowered_funcs, a fallback solution is to create a module
// with empty code content.
// The code content is initialized with ";" to prevent complaining
// from CSourceModuleNode::SaveToFile.
ret_.mod = tvm::codegen::CSourceModuleCreate(";", "");
// llvm if "codegen.LLVMModuleCreate" is accessible.
const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate");
if (!target_host.defined())
target_host = (pf != nullptr) ? target::llvm() : target::stackvm();

if (target_host.defined() && target_host->target_name == "llvm") {
// If we can decide the target is LLVM, we then create an empty LLVM module.
ret_.mod = (*pf)(target_host->str());
} else {
// If there is still no lowered_funcs, a fallback solution is to create a module
// with empty code content.
// The code content is initialized with ";" to prevent complaining
// from CSourceModuleNode::SaveToFile.
ret_.mod = tvm::codegen::CSourceModuleCreate(";", "");
}
} else {
ret_.mod = tvm::build(
lowered_funcs,
Expand All @@ -488,6 +473,20 @@ class RelayBuildModule : public runtime::ModuleNode {
ret_.mod.Import(it);
}

private:
Target GetTargetHost() {
Target target_host = target_host_;
if (!target_host_.defined()) {
for (const auto &it : targets_) {
if (it.second->device_type == kDLCPU) {
target_host = it.second;
break;
}
}
}
return target_host;
}

protected:
std::unique_ptr<GraphCodegen> graph_codegen_;
/*! \brief target device */
Expand Down
34 changes: 34 additions & 0 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,40 @@ TVM_REGISTER_GLOBAL("codegen.build_llvm")
*rv = runtime::Module(n);
});

TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate")
.set_body([](TVMArgs args, TVMRetValue *rv) {
auto n = make_object<LLVMModuleNode>();

// parse target triple from the first argument
auto target = args[0].operator std::string();
std::string triple, mcpu, mattr;
llvm::TargetOptions opt;
ParseLLVMTargetOptions(target, &triple, &mcpu, &mattr, &opt);

// create a default data layout
auto tm = GetLLVMTargetMachine(target);
llvm::DataLayout layout(tm->createDataLayout());

// initialize an IR code snippet from a simple template
std::string ir_str;
ir_str += "target triple = \"" + triple + "\"\n";
ir_str += "target datalayout = \"" + layout.getStringRepresentation() + "\"";

// use parseIR to create a LLVM Module.
auto ctx = std::make_shared<llvm::LLVMContext>();
llvm::SMDiagnostic err;
auto mem_buf = llvm::MemoryBuffer::getMemBuffer(ir_str);
auto module = llvm::parseIR(mem_buf->getMemBufferRef(), err, *ctx);
if (module == nullptr) {
std::string msg = std::string(err.getMessage());
LOG(FATAL) << "Failed to create a LLVM module from the generated IR code:"
<< std::endl << ir_str << std::endl << "Error message: " << msg;
}
n->Init(std::move(module), ctx);

*rv = runtime::Module(n);
});

TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = static_cast<int64_t>(LookupLLVMIntrinsic(args[0]));
Expand Down

0 comments on commit 7a79f6b

Please sign in to comment.