Skip to content

Commit

Permalink
[CodeGen] Generate blob use LLVM directly (#4657)
Browse files Browse the repository at this point in the history
  • Loading branch information
FrozenGene authored and zhiics committed Jan 10, 2020
1 parent f16f245 commit 54fb55a
Show file tree
Hide file tree
Showing 11 changed files with 472 additions and 21 deletions.
3 changes: 3 additions & 0 deletions cmake/util/FindLLVM.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,8 @@ macro(find_llvm use_llvm)
message(STATUS "Found LLVM_INCLUDE_DIRS=" ${LLVM_INCLUDE_DIRS})
message(STATUS "Found LLVM_DEFINITIONS=" ${LLVM_DEFINITIONS})
message(STATUS "Found TVM_LLVM_VERSION=" ${TVM_LLVM_VERSION})
if (${TVM_LLVM_VERSION} LESS 40)
message(FATAL_ERROR "TVM requires LLVM 4.0 or higher.")
endif()
endif()
endmacro(find_llvm)
15 changes: 15 additions & 0 deletions include/tvm/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,21 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
* \return cstr The C string representation of the file.
*/
std::string PackImportsToC(const runtime::Module& m, bool system_lib);

/*!
* \brief Pack imported device library to a LLVM module.
* Compile the LLVM module and link with the host library
* will allow the DSO loader to automatically discover and import
* the dependency from the shared library.
*
* \param m The host module with the imports.
* \param system_lib Whether expose as system library.
* \param target_triple LLVM target triple
* \return runtime::Module The generated LLVM module.
*/
runtime::Module PackImportsToLLVM(const runtime::Module& m,
bool system_lib,
const std::string& target_triple);
} // namespace codegen
} // namespace tvm

Expand Down
40 changes: 38 additions & 2 deletions python/tvm/contrib/cc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,41 @@ def create_shared(output,
else:
raise ValueError("Unsupported platform")

def get_target_by_dump_machine(compiler):
""" Functor of get_target_triple that can get the target triple using compiler.
Parameters
----------
compiler : Optional[str]
The compiler.
Returns
-------
out: Callable
A function that can get target triple according to dumpmachine option of compiler.
"""
def get_target_triple():
""" Get target triple according to dumpmachine option of compiler."""
if compiler:
cmd = [compiler, "-dumpmachine"]
proc = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
msg = "dumpmachine error:\n"
msg += py_str(out)
return None
return py_str(out)
else:
return None

return get_target_triple


# assign so as default output format
create_shared.output_format = "so" if sys.platform != "win32" else "dll"

create_shared.get_target_triple = get_target_by_dump_machine(
"g++" if sys.platform == "darwin" or sys.platform.startswith("linux") else None)

def build_create_shared_func(options=None, compile_cmd="g++"):
"""Build create_shared function with particular default options and compile_cmd.
Expand All @@ -75,10 +106,11 @@ def build_create_shared_func(options=None, compile_cmd="g++"):
def create_shared_wrapper(output, objects, options=options, compile_cmd=compile_cmd):
create_shared(output, objects, options, compile_cmd)
create_shared_wrapper.output_format = create_shared.output_format
create_shared_wrapper.get_target_triple = get_target_by_dump_machine(compile_cmd)
return create_shared_wrapper


def cross_compiler(compile_func, base_options=None, output_format="so"):
def cross_compiler(compile_func, base_options=None, output_format="so", get_target_triple=None):
"""Create a cross compiler function.
Parameters
Expand All @@ -92,6 +124,9 @@ def cross_compiler(compile_func, base_options=None, output_format="so"):
output_format : Optional[str]
Library output format.
get_target_triple: Optional[Callable]
Function that can target triple according to dumpmachine option of compiler.
Returns
-------
fcompile : Callable[[str, str, Optional[str]], None]
Expand All @@ -105,6 +140,7 @@ def _fcompile(outputs, objects, options=None):
all_options += options
compile_func(outputs, objects, options=all_options)
_fcompile.output_format = output_format
_fcompile.get_target_triple = get_target_triple
return _fcompile


Expand Down
4 changes: 4 additions & 0 deletions python/tvm/contrib/ndk.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import subprocess
import os
from .._ffi.base import py_str
from .cc import get_target_by_dump_machine

def create_shared(output,
objects,
Expand Down Expand Up @@ -64,5 +65,8 @@ def create_shared(output,
msg += py_str(out)
raise RuntimeError(msg)


# assign output format
create_shared.output_format = "so"
create_shared.get_target_triple = get_target_by_dump_machine(
os.environ["TVM_NDK_CC"]) if "TVM_NDK_CC" in os.environ else None
24 changes: 18 additions & 6 deletions python/tvm/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def export_library(self,
files = []
is_system_lib = False
has_c_module = False
llvm_target_triple = None
for index, module in enumerate(modules):
if fcompile is not None and hasattr(fcompile, "object_format"):
object_format = fcompile.object_format
Expand All @@ -138,18 +139,29 @@ def export_library(self,
files.append(path_obj)
is_system_lib = (module.type_key == "llvm" and
module.get_function("__tvm_is_system_module")())

if self.imported_modules:
path_cc = temp.relpath("devc.cc")
with open(path_cc, "w") as f:
f.write(_PackImportsToC(self, is_system_lib))
files.append(path_cc)
llvm_target_triple = (module.type_key == "llvm" and
module.get_function("_get_target_triple")())
if not fcompile:
if file_name.endswith(".tar"):
fcompile = _tar.tar
else:
fcompile = _cc.create_shared

if llvm_target_triple is None and hasattr(fcompile, "get_target_triple"):
llvm_target_triple = fcompile.get_target_triple()

if self.imported_modules:
if enabled("llvm") and llvm_target_triple:
path_obj = temp.relpath("devc.o")
m = _PackImportsToLLVM(self, is_system_lib, llvm_target_triple)
m.save(path_obj)
files.append(path_obj)
else:
path_cc = temp.relpath("devc.cc")
with open(path_cc, "w") as f:
f.write(_PackImportsToC(self, is_system_lib))
files.append(path_cc)

if has_c_module:
options = []
if "options" in kwargs:
Expand Down
3 changes: 3 additions & 0 deletions src/api/api_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,8 @@ TVM_REGISTER_GLOBAL("codegen._Build")

TVM_REGISTER_GLOBAL("module._PackImportsToC")
.set_body_typed(PackImportsToC);

TVM_REGISTER_GLOBAL("module._PackImportsToLLVM")
.set_body_typed(PackImportsToLLVM);
} // namespace codegen
} // namespace tvm
45 changes: 39 additions & 6 deletions src/codegen/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <tvm/ir_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/build_module.h>
#include <dmlc/memory_io.h>
#include <sstream>
Expand Down Expand Up @@ -158,13 +159,21 @@ class ModuleSerializer {
std::vector<uint64_t> import_tree_child_indices_;
};

std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
std::string bin;
dmlc::MemoryStringStream ms(&bin);
dmlc::Stream* stream = &ms;
namespace {
std::string SerializeModule(const runtime::Module& mod) {
std::string bin;
dmlc::MemoryStringStream ms(&bin);
dmlc::Stream* stream = &ms;

ModuleSerializer module_serializer(mod);
module_serializer.SerializeModule(stream);

return bin;
}
} // namespace

ModuleSerializer module_serializer(mod);
module_serializer.SerializeModule(stream);
std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
std::string bin = SerializeModule(mod);

// translate to C program
std::ostringstream os;
Expand Down Expand Up @@ -211,5 +220,29 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
<< "#endif\n";
return os.str();
}

runtime::Module PackImportsToLLVM(const runtime::Module& mod,
bool system_lib,
const std::string& target_triple) {
std::string bin = SerializeModule(mod);

uint64_t nbytes = bin.length();
std::string header;
for (size_t i = 0; i < sizeof(nbytes); ++i) {
header.push_back(((nbytes >> (i * 8)) & 0xffUL));
}
std::string blob = header + bin;
TVMByteArray blob_byte_array;
blob_byte_array.size = blob.length();
blob_byte_array.data = blob.data();

// Call codegen_blob to generate LLVM module
std::string codegen_f_name = "codegen.codegen_blob";
// the codegen function.
const PackedFunc* codegen_f = runtime::Registry::Get(codegen_f_name);
CHECK(codegen_f != nullptr) << "codegen.codegen_blob is not presented.";
return (*codegen_f)(blob_byte_array, system_lib, target_triple);
}

} // namespace codegen
} // namespace tvm
Loading

0 comments on commit 54fb55a

Please sign in to comment.