Skip to content

Commit

Permalink
Support standardize runtime module
Browse files Browse the repository at this point in the history
  • Loading branch information
FrozenGene committed Dec 18, 2019
1 parent 9384353 commit d35d4c5
Show file tree
Hide file tree
Showing 4 changed files with 315 additions and 67 deletions.
80 changes: 51 additions & 29 deletions python/tvm/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import absolute_import as _abs

import struct
import ctypes
from collections import namedtuple

from ._ffi.function import ModuleBase, _set_class_module
Expand All @@ -34,6 +35,9 @@ class Module(ModuleBase):
def __repr__(self):
return "Module(%s, %x)" % (self.type_key, self.handle.value)

def __hash__(self):
return ctypes.cast(self.handle, ctypes.c_void_p).value

@property
def type_key(self):
"""Get type key of the module."""
Expand Down Expand Up @@ -118,31 +122,39 @@ def export_library(self,
self.save(file_name)
return

if not (self.type_key == "llvm" or self.type_key == "c"):
raise ValueError("Module[%s]: Only llvm and c support export shared" % self.type_key)
modules = self._collect_dso_modules()
temp = _util.tempdir()
if fcompile is not None and hasattr(fcompile, "object_format"):
object_format = fcompile.object_format
else:
if self.type_key == "llvm":
object_format = "o"
files = []
is_system_lib = None
for module in modules:
if fcompile is not None and hasattr(fcompile, "object_format"):
object_format = fcompile.object_format
else:
if module.type_key == "llvm":
object_format = "o"
else:
assert module.type_key == "c"
object_format = "cc"
path_obj = temp.relpath("lib" + str(hash(module)) + "." + object_format)
module.save(path_obj)
files.append(path_obj)
if is_system_lib is None:
is_system_lib = (module.type_key == "llvm" and
module.get_function("__tvm_is_system_module")())
else:
assert self.type_key == "c"
object_format = "cc"
path_obj = temp.relpath("lib." + object_format)
self.save(path_obj)
files = [path_obj]
is_system_lib = self.type_key == "llvm" and self.get_function("__tvm_is_system_module")()
has_imported_c_file = False
# Requires all dso modules should have the same system_lib setting
assert is_system_lib == (module.type_key == "llvm" and
module.get_function("__tvm_is_system_module")())
if module.type_key == "c":
options = []
if "options" in kwargs:
opts = kwargs["options"]
options = opts if isinstance(opts, (list, tuple)) else [opts]
opts = options + ["-I" + path for path in find_include_path()]
kwargs.update({'options': opts})

assert is_system_lib is not None
if self.imported_modules:
for i, m in enumerate(self.imported_modules):
if m.type_key == "c":
has_imported_c_file = True
c_file_name = "tmp_" + str(i) + ".cc"
path_cc = temp.relpath(c_file_name)
with open(path_cc, "w") as f:
f.write(m.get_source())
files.append(path_cc)
path_cc = temp.relpath("devc.cc")
with open(path_cc, "w") as f:
f.write(_PackImportsToC(self, is_system_lib))
Expand All @@ -152,13 +164,7 @@ def export_library(self,
fcompile = _tar.tar
else:
fcompile = _cc.create_shared
if self.type_key == "c" or has_imported_c_file:
options = []
if "options" in kwargs:
opts = kwargs["options"]
options = opts if isinstance(opts, (list, tuple)) else [opts]
opts = options + ["-I" + path for path in find_include_path()]
kwargs.update({'options': opts})

fcompile(file_name, files, **kwargs)

def time_evaluator(self, func_name, ctx, number=10, repeat=1, min_repeat_ms=0):
Expand Down Expand Up @@ -219,6 +225,22 @@ def evaluator(*args):
except NameError:
raise NameError("time_evaluate is only supported when RPC is enabled")

def _collect_dso_modules(self):
"""Helper function to collect dso modules, then return it."""
visited, stack, dso_modules = set(), [], []
# append root module
visited.add(self)
stack.append(self)
while stack:
module = stack.pop()
if module.type_key == "llvm" or module.type_key == "c":
dso_modules.append(module)
for m in module.imported_modules:
if m not in visited:
visited.add(m)
stack.append(m)
return dso_modules


def system_lib():
"""Get system-wide library module singleton.
Expand Down
126 changes: 115 additions & 11 deletions src/codegen/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
#include <tvm/build_module.h>
#include <dmlc/memory_io.h>
#include <sstream>
#include <iostream>
#include <vector>
#include <cstdint>
#include <unordered_set>

namespace tvm {
namespace codegen {
Expand Down Expand Up @@ -58,20 +60,122 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
return m;
}

/*! \brief Helper class to serialize module */
class ModuleSerializer {
public:
explicit ModuleSerializer(runtime::Module mod) : mod_(mod) {
Init();
}

void SerializeModule(dmlc::Stream* stream) {
// Only have one DSO module and it is in the root, then
// we will not produce import_tree_.
bool has_import_tree = true;
if (IsDSOModule(mod_->type_key()) && mod_->imports().empty()) {
has_import_tree = false;
}
uint64_t sz = 0;
if (has_import_tree) {
// we will append one key for _import_tree
// The layout is the same as before: binary_size, key, logic, key, logic...
sz = mod_vec_.size() + 1;
} else {
// Keep the old behaviour
sz = mod_->imports().size();
}
stream->Write(sz);

for (auto m : mod_vec_) {
std::string mod_type_key = m->type_key();
if (!IsDSOModule(mod_type_key)) {
stream->Write(mod_type_key);
m->SaveToBinary(stream);
} else {
if (has_import_tree) {
mod_type_key = "_lib";
stream->Write(mod_type_key);
}
}
}

// Write _import_tree key if we have
if (has_import_tree) {
std::string import_key = "_import_tree";
stream->Write(import_key);
stream->Write(import_tree_);
}
}

private:
void Init() {
CreateModuleIndex();
CreateImportTree();
}

void CreateModuleIndex() {
std::unordered_set<const runtime::ModuleNode*> visited {mod_.operator->()};
std::vector<runtime::ModuleNode*> stack {mod_.operator->()};
uint64_t module_index = 0;

while (!stack.empty()) {
runtime::ModuleNode* n = stack.back();
stack.pop_back();
mod2index_[n] = module_index++;
mod_vec_.emplace_back(n);
for (runtime::Module m : n->imports()) {
runtime::ModuleNode* next = m.operator->();
if (visited.count(next)) {
continue;
}
visited.insert(next);
stack.push_back(next);
}
}
import_tree_.resize(mod_vec_.size());
}

void CreateImportTree() {
std::vector<uint64_t> csr_row_ptr {0};
std::vector<uint64_t> csr_col_indices;
std::vector<uint64_t> csr_values;
for (auto m : mod_vec_) {
for (size_t i = 0; i < m->imports().size(); i++) {
runtime::Module module = m->imports()[i];
uint64_t mod_index = mod2index_[module.operator->()];
csr_values.push_back(mod_index);
csr_col_indices.push_back(i);
}
csr_row_ptr.push_back(csr_values.size());
import_tree_[mod2index_[m]].resize(m->imports().size());
}

for (size_t i = 0; i < mod_vec_.size(); i++) {
for (size_t j = csr_row_ptr[i]; j < csr_row_ptr[i + 1]; j++) {
import_tree_[i][csr_col_indices[j]] = csr_values[j];
}
}
}

bool IsDSOModule(const std::string& key) {
return key == "llvm" || key == "c";
}

runtime::Module mod_;
// construct module to index
std::unordered_map<runtime::ModuleNode*, size_t> mod2index_;
// index -> module
std::vector<runtime::ModuleNode*> mod_vec_;
std::vector<std::vector<uint64_t>> import_tree_;
};

std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
std::string bin;
dmlc::MemoryStringStream ms(&bin);
dmlc::Stream* stream = &ms;
uint64_t sz = static_cast<uint64_t>(mod->imports().size());
stream->Write(sz);
for (runtime::Module im : mod->imports()) {
CHECK_EQ(im->imports().size(), 0U)
<< "Only support simply one-level hierarchy";
std::string tkey = im->type_key();
stream->Write(tkey);
if (tkey == "c") continue;
im->SaveToBinary(stream);
}

ModuleSerializer module_graph(mod);
module_graph.SerializeModule(stream);

// translate to C program
std::ostringstream os;
os << "#ifdef _WIN32\n"
Expand Down
87 changes: 60 additions & 27 deletions src/runtime/library_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/runtime/registry.h>
#include <string>
#include <vector>
#include <cstdint>
#include "library_module.h"

namespace tvm {
Expand Down Expand Up @@ -65,17 +66,6 @@ class LibraryModuleNode final : public ModuleNode {
ObjectPtr<Library> lib_;
};

/*!
* \brief Helper classes to get into internal of a module.
*/
class ModuleInternal {
public:
// Get mutable reference of imports.
static std::vector<Module>* GetImportsAddr(ModuleNode* node) {
return &(node->imports_);
}
};

PackedFunc WrapPackedFunc(BackendPackedCFunc faddr,
const ObjectPtr<Object>& sptr_to_self) {
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
Expand Down Expand Up @@ -108,9 +98,11 @@ void InitContextFunctions(std::function<void*(const char*)> fgetsymbol) {
/*!
* \brief Load and append module blob to module list
* \param mblob The module blob.
* \param module_list The module list to append to
* \param lib The library.
*
* \return Root Module.
*/
void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib) {
#ifndef _LIBCPP_SGX_CONFIG
CHECK(mblob != nullptr);
uint64_t nbytes = 0;
Expand All @@ -123,20 +115,59 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
dmlc::Stream* stream = &fs;
uint64_t size;
CHECK(stream->Read(&size));
std::vector<Module> modules;
bool has_import_tree = false;
for (uint64_t i = 0; i < size; ++i) {
std::string tkey;
CHECK(stream->Read(&tkey));
if (tkey == "c") continue;
std::string fkey = "module.loadbinary_" + tkey;
const PackedFunc* f = Registry::Get(fkey);
CHECK(f != nullptr)
// Currently, _lib is for DSOModule, but we
// don't have loadbinary function for it currently
if (tkey == "_lib") {
auto dso_module = Module(make_object<LibraryModuleNode>(lib));
// allow lookup of symbol from dso root (so all symbols are visible).
if (auto *ctx_addr =
reinterpret_cast<void**>(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = dso_module.operator->();
}
modules.emplace_back(dso_module);
} else if (tkey == "_import_tree") {
has_import_tree = true;
std::vector<std::vector<uint64_t>> import_tree;
stream->Read(&import_tree);
for (size_t i = 0; i < import_tree.size(); i++) {
for (size_t j = 0; j < import_tree[i].size(); j++) {
modules[i].Import(modules[import_tree[i][j]]);
}
}
} else {
std::string fkey = "module.loadbinary_" + tkey;
const PackedFunc* f = Registry::Get(fkey);
CHECK(f != nullptr)
<< "Loader of " << tkey << "("
<< fkey << ") is not presented.";
Module m = (*f)(static_cast<void*>(stream));
mlist->push_back(m);
Module m = (*f)(static_cast<void*>(stream));
modules.emplace_back(m);
}
}
// if we are using old dll, we don't have import tree
// so that we can't reconstruct module relationship using import tree
if (!has_import_tree) {
auto n = make_object<LibraryModuleNode>(lib);
for (const auto& m : modules) {
n->Import(m);
}
// allow lookup of symbol from dso root (so all symbols are visible).
if (auto *ctx_addr =
reinterpret_cast<void**>(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = n.operator->();
}
return Module(n);
}
CHECK(!modules.empty());
return modules[0];
#else
LOG(FATAL) << "SGX does not support ImportModuleBlob";
return Module();
#endif
}

Expand All @@ -149,17 +180,19 @@ Module CreateModuleFromLibrary(ObjectPtr<Library> lib) {
const char* dev_mblob =
reinterpret_cast<const char*>(
lib->GetSymbol(runtime::symbol::tvm_dev_mblob));
Module root_mod;
if (dev_mblob != nullptr) {
ImportModuleBlob(
dev_mblob, ModuleInternal::GetImportsAddr(n.operator->()));
root_mod = ProcessModuleBlob(dev_mblob, lib);
} else {
// Only have one single DSO Module
root_mod = Module(n);
// allow lookup of symbol from dso root (so all symbols are visible).
if (auto *ctx_addr =
reinterpret_cast<void**>(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = root_mod.operator->();
}
}

Module root_mod = Module(n);
// allow lookup of symbol from root(so all symbols are visible).
if (auto *ctx_addr =
reinterpret_cast<void**>(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = root_mod.operator->();
}
return root_mod;
}
} // namespace runtime
Expand Down
Loading

0 comments on commit d35d4c5

Please sign in to comment.