Skip to content

Commit

Permalink
separate serialize/deserialize and export/import APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
sunggg committed Aug 8, 2023
1 parent d335bfa commit 010ef02
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 80 deletions.
40 changes: 31 additions & 9 deletions include/tvm/target/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,37 @@ using runtime::TVMRetValue;
*/
runtime::Module Build(IRModule mod, Target target);

/*!
* \brief Serialize runtime module including its submodules
* \param mod The runtime module to serialize including its import tree.
* \param include_dso By default, include the info of DSOExportable modules. If disabled, an error
* will be raised when encountering DSO modules.
*/
std::string SerializeModuleToBytes(const runtime::Module& mod, bool include_dso = true);

/*!
* \brief Deserialize runtime module including its submodules
* \param blob byte stream, which are generated by `SerializeModuleToBytes`.
* \return runtime::Module runtime module constructed from the given stream
*/
runtime::Module DeserializeModuleFromBytes(std::string blob);

/*!
* \brief Export TVM runtime module to base64 stream including its submodules.
* Note that this targets modules that are binary serializable and DSOExportable.
* \param module The runtime module to export
* \return std::string The content of exported file
*/
std::string ExportModuleToBase64(tvm::runtime::Module module);

/*!
* \brief Import TVM runtime module from base64 stream
* Note that this targets modules that are binary serializable and DSOExportable.
* \param base64str base64 stream, which are generated by `ExportModuleToBase64`.
* \return runtime::Module runtime module constructed from the given stream
*/
runtime::Module ImportModuleFromBase64(std::string base64str);

/*!
* \brief Pack imported device library to a C file.
* Compile the C file and link with the host library
Expand Down Expand Up @@ -78,15 +109,6 @@ runtime::Module PackImportsToLLVM(const runtime::Module& m, bool system_lib,
const std::string& target_triple,
const std::string& c_symbol_prefix = "");

/*
* Encode TVM runtime module to base64 stream
*/
std::string SerializeModuleToBase64(tvm::runtime::Module module);

/*
* Decode TVM runtime module from base64 stream
*/
runtime::Module DeserializeModuleFromBase64(std::string state);
} // namespace codegen
} // namespace tvm
#endif // TVM_TARGET_CODEGEN_H_
2 changes: 1 addition & 1 deletion python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ def num_threads() -> int:
_set_class_module(Module)


@register_func("serialize_runtime_module")
@register_func("export_runtime_module")
def save_to_base64(obj) -> bytes:
with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile:
obj.export_library(tmpfile.name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,14 @@ size_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeMod
}

char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module) {
std::string std = tvm::codegen::SerializeModuleToBase64(runtime_module->mod);
std::string std = tvm::codegen::ExportModuleToBase64(runtime_module->mod);
char* ret = new char[std.length() + 1];
snprintf(ret, std.length() + 1, "%s", std.c_str());
return ret;
}

TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state) {
tvm::runtime::Module ret = tvm::codegen::DeserializeModuleFromBase64(state);
tvm::runtime::Module ret = tvm::codegen::ImportModuleFromBase64(state);
return new TVMContribTorchRuntimeModule(ret);
}

Expand Down
25 changes: 21 additions & 4 deletions src/node/structural_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -363,18 +363,35 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait);

struct ModuleNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static constexpr std::nullptr_t SHashReduce = nullptr;
static constexpr std::nullptr_t SEqualReduce = nullptr;
static void SHashReduce(const runtime::ModuleNode* key, SHashReducer hash_reduce) {
const auto* rtmod = static_cast<const runtime::ModuleNode*>(key);
runtime::String str_key =
codegen::SerializeModuleToBytes(GetRef<runtime::Module>(rtmod), /*include_dso*/ false);
hash_reduce->SHashReduceHashedValue(
runtime::String::StableHashBytes(str_key->data, str_key->size));
}

static bool SEqualReduce(const runtime::ModuleNode* lhs, const runtime::ModuleNode* rhs,
SEqualReducer equal) {
if (lhs == rhs) return true;
const auto* lhs_mod = static_cast<const runtime::ModuleNode*>(lhs);
const auto* rhs_mod = static_cast<const runtime::ModuleNode*>(rhs);
runtime::String lhs_str =
codegen::SerializeModuleToBytes(GetRef<runtime::Module>(lhs_mod), /*include_dso*/ false);
runtime::String rhs_str =
codegen::SerializeModuleToBytes(GetRef<runtime::Module>(rhs_mod), /*include_dso*/ false);
return lhs_str == rhs_str;
}
};

TVM_REGISTER_REFLECTION_VTABLE(runtime::ModuleNode, ModuleNodeTrait)
.set_creator([](const std::string& blob) {
runtime::Module rtmod = codegen::DeserializeModuleFromBase64(blob);
runtime::Module rtmod = codegen::DeserializeModuleFromBytes(blob);
return RefToObjectPtr::Get(rtmod);
})
.set_repr_bytes([](const Object* n) -> std::string {
const auto* rtmod = static_cast<const runtime::ModuleNode*>(n);
return codegen::SerializeModuleToBase64(GetRef<runtime::Module>(rtmod));
return codegen::SerializeModuleToBytes(GetRef<runtime::Module>(rtmod), /*include_dso*/ false);
});

void NDArrayHash(const runtime::NDArray::Container* arr, SHashReducer* hash_reduce,
Expand Down
9 changes: 0 additions & 9 deletions src/runtime/library_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,6 @@ class LibraryModuleNode final : public ModuleNode {
PackedFuncWrapper packed_func_wrapper_;
};

/*!
* \brief Helper classes to get into internal of a module.
*/
class ModuleInternal {
public:
// Get mutable reference of imports.
static std::vector<Module>* GetImportsAddr(ModuleNode* node) { return &(node->imports_); }
};

PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>& sptr_to_self) {
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
TVMValue ret_value;
Expand Down
9 changes: 9 additions & 0 deletions src/runtime/library_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>&
*/
void InitContextFunctions(std::function<void*(const char*)> fgetsymbol);

/*!
* \brief Helper classes to get into internal of a module.
*/
class ModuleInternal {
public:
// Get mutable reference of imports.
static std::vector<Module>* GetImportsAddr(ModuleNode* node) { return &(node->imports_); }
};

/*!
* \brief Type alias for function to wrap a TVMBackendPackedCFunc.
* \param The function address imported from a module.
Expand Down
81 changes: 58 additions & 23 deletions src/target/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include <unordered_set>
#include <vector>

#include "../runtime/library_module.h"
#include "../support/base64.h"

namespace tvm {
Expand Down Expand Up @@ -65,13 +66,13 @@ class ModuleSerializer {
public:
explicit ModuleSerializer(runtime::Module mod) : mod_(mod) { Init(); }

void SerializeModuleToBytes(dmlc::Stream* stream, bool export_dso) {
void SerializeModuleToBytes(dmlc::Stream* stream, bool include_dso) {
// Only have one DSO module and it is in the root, then
// we will not produce import_tree_.
bool has_import_tree = true;

if (mod_->IsDSOExportable()) {
ICHECK(export_dso) << "`export_dso` should be enabled for DSOExportable modules";
ICHECK(include_dso) << "`include_dso` should be enabled for DSOExportable modules";
has_import_tree = !mod_->imports().empty();
}

Expand All @@ -94,7 +95,7 @@ class ModuleSerializer {
stream->Write(mod_type_key);
group[0]->SaveToBinary(stream);
} else if (group[0]->IsDSOExportable()) {
ICHECK(export_dso) << "`export_dso` should be enabled for DSOExportable modules";
ICHECK(include_dso) << "`include_dso` should be enabled for DSOExportable modules";
// DSOExportable: do not need binary
if (has_import_tree) {
std::string mod_type_key = "_lib";
Expand Down Expand Up @@ -235,25 +236,56 @@ class ModuleSerializer {
std::vector<uint64_t> import_tree_child_indices_;
};

/*!
* \brief Serialize runtime module including
*
* \param mod The runtime module to serialize including its import tree.
* \param export_mode By default, allow export of DSOExportable modules. If disabled, an error will
* be reaised when encountering DSO.
*/
namespace {
std::string SerializeModuleToBytes(const runtime::Module& mod, bool export_dso = true) {
std::string SerializeModuleToBytes(const runtime::Module& mod, bool include_dso) {
std::string bin;
dmlc::MemoryStringStream ms(&bin);
dmlc::Stream* stream = &ms;

ModuleSerializer module_serializer(mod);
module_serializer.SerializeModuleToBytes(stream, export_dso);

module_serializer.SerializeModuleToBytes(stream, include_dso);
return bin;
}
} // namespace

runtime::Module DeserializeModuleFromBytes(std::string blob) {
dmlc::MemoryStringStream ms(&blob);
dmlc::Stream* stream = &ms;

uint64_t size;
ICHECK(stream->Read(&size));
std::vector<runtime::Module> modules;
std::vector<uint64_t> import_tree_row_ptr;
std::vector<uint64_t> import_tree_child_indices;

for (uint64_t i = 0; i < size; ++i) {
std::string tkey;
ICHECK(stream->Read(&tkey));
// "_lib" serves as a placeholder in the module import tree to indicate where
// to place the DSOModule
ICHECK(tkey != "_lib") << "Should not contain any placeholder for DSOModule.";
if (tkey == "_import_tree") {
ICHECK(stream->Read(&import_tree_row_ptr));
ICHECK(stream->Read(&import_tree_child_indices));
} else {
auto m = runtime::LoadModuleFromBinary(tkey, stream);
modules.emplace_back(m);
}
}

for (size_t i = 0; i < modules.size(); ++i) {
for (size_t j = import_tree_row_ptr[i]; j < import_tree_row_ptr[i + 1]; ++j) {
auto module_import_addr = runtime::ModuleInternal::GetImportsAddr(modules[i].operator->());
auto child_index = import_tree_child_indices[j];
ICHECK(child_index < modules.size());
module_import_addr->emplace_back(modules[child_index]);
}
}

ICHECK(!modules.empty()) << "modules cannot be empty when import tree is present";
// invariance: root module is always at location 0.
// The module order is collected via DFS
runtime::Module root_mod = modules[0];
return root_mod;
}

std::string PackImportsToC(const runtime::Module& mod, bool system_lib,
const std::string& c_symbol_prefix) {
Expand Down Expand Up @@ -349,20 +381,25 @@ struct Deleter { // deleter
std::string file_name;
};

std::string SerializeModuleToBase64(tvm::runtime::Module module) {
static const runtime::PackedFunc* f_to_str = runtime::Registry::Get("serialize_runtime_module");
std::string ExportModuleToBase64(tvm::runtime::Module module) {
static const runtime::PackedFunc* f_to_str = runtime::Registry::Get("export_runtime_module");
ICHECK(f_to_str) << "IndexError: Cannot find the packed function "
"`serialize_runtime_module` in the global registry";
"`export_runtime_module` in the global registry";
return (*f_to_str)(module);
}

tvm::runtime::Module DeserializeModuleFromBase64(std::string base64str) {
tvm::runtime::Module ImportModuleFromBase64(std::string base64str) {
auto length = tvm::support::b64strlen(base64str);

std::vector<u_char> bytes(length); // bytes stream
tvm::support::b64decode(base64str, bytes.data());
const std::string name = tmpnam(NULL);
auto file_name = name + ".so";

auto now = std::chrono::system_clock::now();
auto in_time_t = std::chrono::system_clock::to_time_t(now);
std::stringstream datetime;
datetime << std::put_time(std::localtime(&in_time_t), "%Y-%m-%d-%X");
const std::string file_name = "tmp-module-" + datetime.str() + ".so";
LOG(INFO) << file_name;
std::unique_ptr<FILE, Deleter> pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name));
fwrite(bytes.data(), sizeof(u_char), length, pFile.get());
fflush(pFile.get());
Expand All @@ -373,8 +410,6 @@ tvm::runtime::Module DeserializeModuleFromBase64(std::string base64str) {
<< " resolved to (" << load_f_name << ") in the global registry."
<< "Ensure that you have loaded the correct runtime code, and"
<< "that you are on the correct hardware architecture.";

LOG(INFO) << "Run " << load_f_name;
tvm::runtime::Module ret = (*f)(file_name, "");
return ret;
}
Expand Down
83 changes: 51 additions & 32 deletions tests/python/unittest/test_roundtrip_runtime_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import tvm.testing
from tvm import TVMError
from tvm import relay
import numpy as np
from tvm.contrib.graph_executor import GraphModule


def test_csource_module():
Expand All @@ -44,44 +42,65 @@ def test_aot_module():
tvm.ir.load_json(tvm.ir.save_json(mod))


@tvm.testing.requires_cuda
def test_recursive_imports():
def get_test_mod():
x = relay.var("x", shape=(1, 10), dtype="float32")
y = relay.var("y", shape=(1, 10), dtype="float32")
z = relay.add(x, y)
func = relay.Function([x, y], z)
mod = relay.build_module._build_module_no_factory(func, target="cuda")
return relay.build_module._build_module_no_factory(func, target="cuda")


def get_cuda_mod():
# Get Cuda module which is binary serializable
return get_test_mod().imported_modules[0].imported_modules[0]


@tvm.testing.requires_cuda
def test_cuda_module():
mod = get_cuda_mod()
assert mod.is_binary_serializable
# GraphExecutorFactory Module contains LLVM Module and LLVM Module contains cuda Module.
assert mod.type_key == "GraphExecutorFactory"
assert mod.imported_modules[0].type_key == "llvm"
assert mod.imported_modules[0].imported_modules[0].type_key == "cuda"
new_mod = tvm.ir.load_json(tvm.ir.save_json(mod))
tvm.ir.structural_equal(mod, new_mod)


@tvm.testing.requires_cuda
def test_valid_submodules():
mod, mod2, mod3, mod4 = get_cuda_mod(), get_cuda_mod(), get_cuda_mod(), get_cuda_mod()

# Create the nested cuda module
mod.import_module(mod2)
mod2.import_module(mod3)
mod2.import_module(mod4)

# Root module and all submodules should be binary serializable since they are cuda module
assert mod.is_binary_serializable
assert mod.imported_modules[0].is_binary_serializable
assert mod.imported_modules[0].imported_modules[0].is_binary_serializable
assert mod.imported_modules[0].imported_modules[1].is_binary_serializable

new_mod = tvm.ir.load_json(tvm.ir.save_json(mod))
assert new_mod.is_binary_serializable
# GraphExecutorFactory Module contains LLVM Module and LLVM Module contains cuda Module.
assert new_mod.type_key == "GraphExecutorFactory"
# This type key is now `library` rather than llvm.
assert new_mod.imported_modules[0].type_key == "library"
assert new_mod.imported_modules[0].imported_modules[0].type_key == "cuda"

dev = tvm.cuda()
data_x = tvm.nd.array(np.random.rand(1, 10).astype("float32"), dev)
data_y = tvm.nd.array(np.random.rand(1, 10).astype("float32"), dev)

graph_mod = GraphModule(mod["default"](dev))
graph_mod.set_input("x", data_x)
graph_mod.set_input("y", data_y)
graph_mod.run()
expected = graph_mod.get_output(0)

graph_mod = GraphModule(new_mod["default"](dev))
graph_mod.set_input("x", data_x)
graph_mod.set_input("y", data_y)
graph_mod.run()
output = graph_mod.get_output(0)
tvm.testing.assert_allclose(output.numpy(), expected.numpy(), atol=1e-5, rtol=1e-5)
tvm.ir.structural_equal(mod, new_mod)


@tvm.testing.requires_cuda
def test_invalid_submodules():
mod, mod2, mod3 = get_cuda_mod(), get_cuda_mod(), get_cuda_mod()
mod4 = tvm.get_global_func("relay.build_module._AOTExecutorCodegen")()

# Create the nested cuda module
mod.import_module(mod2)
mod2.import_module(mod3)
mod2.import_module(mod4)

# One of submodules is not binary serializable.
assert mod.is_binary_serializable
assert mod.imported_modules[0].is_binary_serializable
assert mod.imported_modules[0].imported_modules[0].is_binary_serializable
assert not mod.imported_modules[0].imported_modules[1].is_binary_serializable

# Therefore, we cannot roundtrip.
with pytest.raises(TVMError):
tvm.ir.load_json(tvm.ir.save_json(mod))


if __name__ == "__main__":
Expand Down

0 comments on commit 010ef02

Please sign in to comment.