Skip to content

Commit

Permalink
[CodeGen] Generate blob use LLVM directly
Browse files Browse the repository at this point in the history
  • Loading branch information
FrozenGene committed Jan 9, 2020
1 parent baae28b commit 8fed25b
Show file tree
Hide file tree
Showing 11 changed files with 478 additions and 19 deletions.
3 changes: 3 additions & 0 deletions cmake/util/FindLLVM.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,8 @@ macro(find_llvm use_llvm)
message(STATUS "Found LLVM_INCLUDE_DIRS=" ${LLVM_INCLUDE_DIRS})
message(STATUS "Found LLVM_DEFINITIONS=" ${LLVM_DEFINITIONS})
message(STATUS "Found TVM_LLVM_VERSION=" ${TVM_LLVM_VERSION})
if (${TVM_LLVM_VERSION} LESS 40)
message(FATAL_ERROR "TVM requires LLVM 4.0 or higher.")
endif()
endif()
endmacro(find_llvm)
15 changes: 15 additions & 0 deletions include/tvm/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,21 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
* \return cstr The C string representation of the file.
*/
std::string PackImportsToC(const runtime::Module& m, bool system_lib);

/*!
* \brief Pack imported device library to a LLVM module.
* Compile the LLVM module and link with the host library
* will allow the DSO loader to automatically discover and import
* the dependency from the shared library.
*
* \param m The host module with the imports.
* \param system_lib Whether expose as system library.
* \param target LLVM target
* \return runtime::Module The generated LLVM module.
*/
runtime::Module PackImportsToLLVM(const runtime::Module& m,
bool system_lib,
const std::string& target);
} // namespace codegen
} // namespace tvm

Expand Down
23 changes: 23 additions & 0 deletions python/tvm/contrib/cc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,32 @@ def create_shared(output,
else:
raise ValueError("Unsupported platform")

def get_target_triple():
""" Get the target triple using compiler.
Returns
-------
out: str (Linux / Mac) or None (Win32)
"""
if sys.platform == "darwin" or sys.platform.startswith("linux"):
cmd = ["g++", "-dumpmachine"]
proc = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "dumpmachine error:\n"
msg += py_str(out)
raise RuntimeError(msg)
return py_str(out)
elif sys.platform == "win32":
return None
else:
raise ValueError("Unsupported platform")


# assign so as default output format
create_shared.output_format = "so" if sys.platform != "win32" else "dll"
create_shared.get_target_triple = get_target_triple()


def build_create_shared_func(options=None, compile_cmd="g++"):
Expand Down
23 changes: 23 additions & 0 deletions python/tvm/contrib/ndk.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,28 @@ def create_shared(output,
msg += py_str(out)
raise RuntimeError(msg)

def get_target_triple():
""" Get the target triple using compiler.
Returns
-------
out: str
"""
if "TVM_NDK_CC" not in os.environ:
raise RuntimeError("Require environment variable TVM_NDK_CC"
" to be the NDK standalone compiler")
compiler = os.environ["TVM_NDK_CC"]
cmd = [compiler, "-dumpmachine"]
proc = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "dumpmachine error:\n"
msg += py_str(out)
raise RuntimeError(msg)
return py_str(out)


# assign output format
create_shared.output_format = "so"
create_shared.get_target_triple = get_target_triple() if "TVM_NDK_CC" in os.environ else None
24 changes: 18 additions & 6 deletions python/tvm/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def export_library(self,
files = []
is_system_lib = False
has_c_module = False
llvm_target_triple = None
for index, module in enumerate(modules):
if fcompile is not None and hasattr(fcompile, "object_format"):
object_format = fcompile.object_format
Expand All @@ -138,18 +139,29 @@ def export_library(self,
files.append(path_obj)
is_system_lib = (module.type_key == "llvm" and
module.get_function("__tvm_is_system_module")())

if self.imported_modules:
path_cc = temp.relpath("devc.cc")
with open(path_cc, "w") as f:
f.write(_PackImportsToC(self, is_system_lib))
files.append(path_cc)
llvm_target_triple = (module.type_key == "llvm" and
module.get_function("get_target_triple")())
if not fcompile:
if file_name.endswith(".tar"):
fcompile = _tar.tar
else:
fcompile = _cc.create_shared

if llvm_target_triple is None and hasattr(fcompile, "get_target_triple"):
llvm_target_triple = fcompile.get_target_triple

if self.imported_modules:
if enabled("llvm") and llvm_target_triple:
path_obj = temp.relpath("devc.o")
m = _PackImportsToLLVM(self, is_system_lib, llvm_target_triple)
m.save(path_obj)
files.append(path_obj)
else:
path_cc = temp.relpath("devc.cc")
with open(path_cc, "w") as f:
f.write(_PackImportsToC(self, is_system_lib))
files.append(path_cc)

if has_c_module:
options = []
if "options" in kwargs:
Expand Down
3 changes: 3 additions & 0 deletions src/api/api_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,8 @@ TVM_REGISTER_GLOBAL("codegen._Build")

TVM_REGISTER_GLOBAL("module._PackImportsToC")
.set_body_typed(PackImportsToC);

TVM_REGISTER_GLOBAL("module._PackImportsToLLVM")
.set_body_typed(PackImportsToLLVM);
} // namespace codegen
} // namespace tvm
46 changes: 40 additions & 6 deletions src/codegen/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
#include <tvm/ir_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/build_module.h>
#include <dmlc/memory_io.h>
#include <sstream>
#include <vector>
#include <cstdint>
#include <unordered_set>
#include <cstring>
#include <iomanip>

namespace tvm {
namespace codegen {
Expand Down Expand Up @@ -158,13 +160,21 @@ class ModuleSerializer {
std::vector<uint64_t> import_tree_child_indices_;
};

std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
std::string bin;
dmlc::MemoryStringStream ms(&bin);
dmlc::Stream* stream = &ms;
namespace {
std::string SerializeModule(const runtime::Module& mod) {
std::string bin;
dmlc::MemoryStringStream ms(&bin);
dmlc::Stream* stream = &ms;

ModuleSerializer module_serializer(mod);
module_serializer.SerializeModule(stream);

return bin;
}
} // namespace

ModuleSerializer module_serializer(mod);
module_serializer.SerializeModule(stream);
std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
std::string bin = SerializeModule(mod);

// translate to C program
std::ostringstream os;
Expand Down Expand Up @@ -211,5 +221,29 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
<< "#endif\n";
return os.str();
}

runtime::Module PackImportsToLLVM(const runtime::Module& mod,
bool system_lib,
const std::string& target) {
std::string bin = SerializeModule(mod);

uint64_t nbytes = bin.length();
std::string header;
for (size_t i = 0; i < sizeof(nbytes); ++i) {
header.push_back(((nbytes >> (i * 8)) & 0xffUL));
}
std::string blob = header + bin;
TVMByteArray blob_byte_array;
blob_byte_array.size = blob.length();
blob_byte_array.data = blob.data();

// Call codegen_blob to generate LLVM module
std::string codegen_f_name = "codegen.codegen_blob";
// the codegen function.
const PackedFunc* codegen_f = runtime::Registry::Get(codegen_f_name);
CHECK(codegen_f != nullptr) << "codegen.codegen_blob is not presented.";
return (*codegen_f)(blob_byte_array, system_lib, target);
}

} // namespace codegen
} // namespace tvm
163 changes: 163 additions & 0 deletions src/codegen/llvm/codegen_blob.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file codegen_blob.cc
*/
#ifdef TVM_LLVM_VERSION
#include <tvm/runtime/module.h>
#include <cstring>
#include "codegen_blob.h"

namespace tvm {
namespace codegen {

std::pair<std::unique_ptr<llvm::Module>,
std::shared_ptr<llvm::LLVMContext>> CodeGenBlob(const std::string& data,
bool system_lib,
const std::string& target) {
InitializeLLVM();
auto tm = GetLLVMTargetMachine(std::string("-target ") + target);
auto target_triple = tm->getTargetTriple();
auto ctx = std::make_shared<llvm::LLVMContext>();
std::string module_name = "devc";
std::unique_ptr<llvm::Module> module(new llvm::Module(module_name, *ctx));
module->setTargetTriple(target_triple.str());
module->setDataLayout(tm->createDataLayout());
auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false);
auto* tvm_dev_mblob = new llvm::GlobalVariable(*module, blob_value->getType(), true,
llvm::GlobalValue::ExternalLinkage, blob_value,
runtime::symbol::tvm_dev_mblob, nullptr,
llvm::GlobalVariable::NotThreadLocal, 0);

#if TVM_LLVM_VERSION >= 100
tvm_dev_mblob->setAlignment(llvm::Align(1));
#else
tvm_dev_mblob->setAlignment(1);
#endif

if (target_triple.isOSWindows()) {
tvm_dev_mblob->setDLLStorageClass(llvm::GlobalVariable::DLLExportStorageClass);
}

if (system_lib) {
// LLVM type helper
auto void_ty = llvm::Type::getVoidTy(*ctx);
auto int32_ty = llvm::Type::getInt32Ty(*ctx);
auto int8_ty = llvm::Type::getInt8Ty(*ctx);
auto int8_ptr_ty = int8_ty->getPointerTo(0);

llvm::Constant* constant_zero = llvm::Constant::getNullValue(int32_ty);
auto* tvm_dev_mblob_reg =
new llvm::GlobalVariable(*module, int32_ty,
false, llvm::GlobalValue::InternalLinkage,
constant_zero,
std::string(runtime::symbol::tvm_dev_mblob) + "_reg_");
auto tvm_dev_mblob_reg_alignment = module->getDataLayout().getABITypeAlignment(int32_ty);
#if TVM_LLVM_VERSION >= 100
tvm_dev_mblob_reg->setAlignment(llvm::Align(tvm_dev_mblob_reg_alignment));
#else
tvm_dev_mblob_reg->setAlignment(tvm_dev_mblob_reg_alignment);
#endif

auto* tvm_dev_mblob_string_ty =
llvm::ArrayType::get(int8_ty, std::strlen(runtime::symbol::tvm_dev_mblob) + 1);
auto* tvm_dev_mblob_string_value =
llvm::ConstantDataArray::getString(*ctx, runtime::symbol::tvm_dev_mblob, true);
auto* tvm_dev_mblob_string =
new llvm::GlobalVariable(*module, tvm_dev_mblob_string_ty,
true, llvm::GlobalValue::PrivateLinkage,
tvm_dev_mblob_string_value,
std::string(runtime::symbol::tvm_dev_mblob) + ".str");
#if TVM_LLVM_VERSION >= 100
tvm_dev_mblob_string->setAlignment(llvm::Align(1));
#else
tvm_dev_mblob_string->setAlignment(1);
#endif

// Global init function
llvm::Function* init_fn = llvm::Function::Create(llvm::FunctionType::get(void_ty, false),
llvm::GlobalValue::InternalLinkage,
llvm::Twine("_GLOBAL__sub_I_", module_name),
module.get());

// Create variable initialization function.
llvm::Function* var_init_fn = llvm::Function::Create(llvm::FunctionType::get(void_ty, false),
llvm::GlobalValue::InternalLinkage,
llvm::Twine("__cxx_global_var_init"),
module.get());

// Create TVMBackendRegisterSystemLibSymbol function
llvm::Function* tvm_backend_fn =
llvm::Function::Create(llvm::FunctionType::get(int32_ty, {int8_ptr_ty, int8_ptr_ty}, false),
llvm::GlobalValue::ExternalLinkage,
llvm::Twine("TVMBackendRegisterSystemLibSymbol"),
module.get());

// Set necessary fn sections
auto get_static_init_section_specifier = [&target_triple]() -> std::string {
if (target_triple.isOSLinux()) {
return ".text.startup";
} else if (target_triple.isOSDarwin()) {
return "__TEXT,__StaticInit,regular,pure_instructions";
} else {
return "";
}
};

auto static_init_section_specifier = get_static_init_section_specifier();

if (!static_init_section_specifier.empty()) {
init_fn->setSection(static_init_section_specifier);
var_init_fn->setSection(static_init_section_specifier);
}


// The priority is 65535 for all platforms as clang do.
llvm::appendToGlobalCtors(*module, init_fn, 65535);

// Define init_fn body
llvm::IRBuilder<> ir_builder(*ctx);
llvm::BasicBlock* init_fn_bb = llvm::BasicBlock::Create(*ctx, "entry", init_fn);
ir_builder.SetInsertPoint(init_fn_bb);
ir_builder.CreateCall(var_init_fn);
ir_builder.CreateRetVoid();

// Define var_init_fn body
llvm::BasicBlock* var_init_fn_bb = llvm::BasicBlock::Create(*ctx, "entry", var_init_fn);
ir_builder.SetInsertPoint(var_init_fn_bb);
llvm::Constant* indices[] = {constant_zero, constant_zero};
llvm::SmallVector<llvm::Value*, 2> args;
args.push_back(llvm::ConstantExpr::getGetElementPtr(tvm_dev_mblob_string_ty,
tvm_dev_mblob_string,
indices));
args.push_back(llvm::ConstantExpr::getGetElementPtr(blob_value->getType(),
tvm_dev_mblob,
indices));
auto* tvm_backend_fn_ret_value = ir_builder.CreateCall(tvm_backend_fn, args);
ir_builder.CreateStore(tvm_backend_fn_ret_value, tvm_dev_mblob_reg);
ir_builder.CreateRetVoid();
}

return std::make_pair(std::move(module), ctx);
}

} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
Loading

0 comments on commit 8fed25b

Please sign in to comment.