[Runtime] Serialization/Deserialization of runtime module #15244

Merged: 20 commits, Aug 28, 2023
Changes from 15 commits
5 changes: 5 additions & 0 deletions include/tvm/runtime/module.h
@@ -232,6 +232,11 @@ class TVM_DLL ModuleNode : public Object {
return (GetPropertyMask() & ModulePropertyMask::kDSOExportable) != 0;
}

/*! \brief Returns true if this module is 'Binary Serializable'. */
bool IsBinarySerializable() const {
return (GetPropertyMask() & ModulePropertyMask::kBinarySerializable) != 0;
}

/*!
* \brief Returns true if this module has a definition for a function of \p name. If
* \p query_imports is true, also search in any imported modules.
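The new property check can be used to guard serialization. A minimal usage sketch (not part of this patch), assuming a module obtained elsewhere; the helper name and error message are illustrative:

```cpp
// Sketch only: guard serialization on the new IsBinarySerializable() property.
// The helper name and error message are illustrative; SerializeModuleToBytes is
// declared below in include/tvm/target/codegen.h.
#include <tvm/runtime/logging.h>
#include <tvm/runtime/module.h>
#include <tvm/target/codegen.h>

#include <string>

std::string SerializeIfPossible(tvm::runtime::Module mod) {
  ICHECK(mod->IsBinarySerializable())
      << "Module of type " << mod->GetTypeKey() << " is not binary serializable";
  return tvm::codegen::SerializeModuleToBytes(mod);
}
```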
32 changes: 32 additions & 0 deletions include/tvm/target/codegen.h
@@ -47,6 +47,37 @@ using runtime::TVMRetValue;
*/
runtime::Module Build(IRModule mod, Target target);

/*!
 * \brief Serialize a runtime module, including its submodules, into a byte stream.
 * \param mod The runtime module to serialize, including its import tree.
 * \param include_dso Whether to include the information of DSO-exportable modules (default: true).
 *        If disabled, an error is raised when a DSO module is encountered.
 * \return The serialized byte stream.
 */
std::string SerializeModuleToBytes(const runtime::Module& mod, bool include_dso = true);

/*!
 * \brief Deserialize a runtime module, including its submodules, from a byte stream.
 * \param blob The byte stream, as generated by `SerializeModuleToBytes`.
 * \return The runtime module constructed from the given stream.
 */
runtime::Module DeserializeModuleFromBytes(std::string blob);

/*!
 * \brief Export a TVM runtime module, including its submodules, to a base64 stream.
 * Note that this targets modules that are binary serializable and DSO exportable.
 * \param module The runtime module to export.
 * \return The content of the exported file as a base64-encoded string.
 */
std::string ExportModuleToBase64(tvm::runtime::Module module);

/*!
 * \brief Import a TVM runtime module from a base64 stream.
 * Note that this targets modules that are binary serializable and DSO exportable.
 * \param base64str The base64 stream, as generated by `ExportModuleToBase64`.
 * \return The runtime module constructed from the given stream.
 */
runtime::Module ImportModuleFromBase64(std::string base64str);

/*!
* \brief Pack imported device library to a C file.
* Compile the C file and link with the host library
@@ -77,6 +108,7 @@ std::string PackImportsToC(const runtime::Module& m, bool system_lib,
runtime::Module PackImportsToLLVM(const runtime::Module& m, bool system_lib,
const std::string& target_triple,
const std::string& c_symbol_prefix = "");

} // namespace codegen
} // namespace tvm
#endif // TVM_TARGET_CODEGEN_H_
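To show how the four new entry points compose, here is a minimal round-trip sketch (not part of this patch); the function names are illustrative and the module is assumed to be binary serializable:

```cpp
// Sketch only: exercising the serialization entry points declared above.
// `mod` is assumed to be a binary-serializable module built elsewhere.
#include <tvm/target/codegen.h>

#include <string>

tvm::runtime::Module RoundTripBytes(const tvm::runtime::Module& mod) {
  // Serialize the module together with its import tree (include_dso defaults to true).
  std::string blob = tvm::codegen::SerializeModuleToBytes(mod);
  // Rebuild an equivalent module from the byte stream.
  return tvm::codegen::DeserializeModuleFromBytes(blob);
}

tvm::runtime::Module RoundTripBase64(tvm::runtime::Module mod) {
  // Base64 variant, aimed at modules that are binary serializable or DSO exportable.
  std::string b64 = tvm::codegen::ExportModuleToBase64(mod);
  return tvm::codegen::ImportModuleFromBase64(b64);
}
```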
12 changes: 2 additions & 10 deletions python/tvm/contrib/torch/optimize_torch.py
@@ -25,7 +25,7 @@
which is used to optimize the `torch.nn.module` by TVM metaSchedule,
and returns a custom TorchScript operator
"""
import base64

import contextlib
import tempfile
from typing import Optional, Tuple, Union
@@ -35,7 +35,7 @@
import tvm
from tvm import meta_schedule as ms
from tvm import relay
from tvm._ffi import get_global_func, register_func
from tvm._ffi import get_global_func
from tvm.target import Target


@@ -51,14 +51,6 @@ def forward(self, *torch_inputs: Tuple[torch.Tensor]):
return ret


@register_func("script_torch.save_to_base64")
def save_to_base64(obj) -> bytes:
    with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile:
        obj.export_library(tmpfile.name)
        with open(tmpfile.name, "rb") as temp_file:
            return base64.b64encode(temp_file.read())


def optimize_torch(
func,
example_inputs,
11 changes: 11 additions & 0 deletions python/tvm/runtime/module.py
@@ -21,9 +21,12 @@
import ctypes
import struct
from typing import Sequence
import base64
import tempfile
import numpy as np

import tvm._ffi
from tvm._ffi import register_func
from tvm._ffi.base import _LIB, check_call, c_str, string_types, _RUNTIME_ONLY
from tvm._ffi.libinfo import find_include_path
from .packed_func import PackedFunc, PackedFuncHandle, _set_class_module
@@ -713,3 +716,11 @@ def num_threads() -> int:


_set_class_module(Module)


@register_func("export_runtime_module")
def save_to_base64(obj) -> bytes:
    with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile:
        obj.export_library(tmpfile.name)
        with open(tmpfile.name, "rb") as temp_file:
            return base64.b64encode(temp_file.read())
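For context, a minimal sketch (not part of this patch) of how the newly registered global function could be called from C++; the helper name is illustrative and the lookup pattern mirrors the contrib/torch wrapper later in this diff:

```cpp
// Sketch only: invoking the Python-side "export_runtime_module" packed function.
// The helper name is illustrative; the Registry::Get + ICHECK pattern mirrors the
// contrib/torch wrapper code shown later in this diff.
#include <tvm/runtime/logging.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>

#include <string>

std::string ExportRuntimeModuleToBase64(tvm::runtime::Module mod) {
  const tvm::runtime::PackedFunc* f = tvm::runtime::Registry::Get("export_runtime_module");
  ICHECK(f != nullptr) << "Cannot find `export_runtime_module` in the global registry";
  // The Python side returns the base64-encoded contents of an exported .so file.
  std::string b64 = (*f)(mod);
  return b64;
}
```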
75 changes: 0 additions & 75 deletions src/contrib/torch/base64.h

This file was deleted.

52 changes: 2 additions & 50 deletions src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc
@@ -46,54 +46,6 @@ struct ThreadLocalStore {
}
};

/*
* Encode TVM runtime module to base64 stream
*/
std::string serialize(tvm::runtime::Module module) {
static const runtime::PackedFunc* f_to_str =
runtime::Registry::Get("script_torch.save_to_base64");
ICHECK(f_to_str) << "IndexError: Cannot find the packed function "
"`script_torch.save_to_base64` in the global registry";
return (*f_to_str)(module);
}

struct Deleter { // deleter
explicit Deleter(std::string file_name) { this->file_name = file_name; }
void operator()(FILE* p) const {
fclose(p);
ICHECK(remove(file_name.c_str()) == 0)
<< "remove temporary file (" << file_name << ") unsuccessfully";
}
std::string file_name;
};

/*
* Decode TVM runtime module from base64 stream
*/
tvm::runtime::Module deserialize(std::string state) {
auto length = tvm::support::b64strlen(state);

std::vector<u_char> bytes(length); // bytes stream
tvm::support::b64decode(state, bytes.data());

const std::string name = tmpnam(NULL);
auto file_name = name + ".so";
std::unique_ptr<FILE, Deleter> pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name));
fwrite(bytes.data(), sizeof(u_char), length, pFile.get());
fflush(pFile.get());

std::string load_f_name = "runtime.module.loadfile_so";
const PackedFunc* f = runtime::Registry::Get(load_f_name);
ICHECK(f != nullptr) << "Loader for `.so` files is not registered,"
<< " resolved to (" << load_f_name << ") in the global registry."
<< "Ensure that you have loaded the correct runtime code, and"
<< "that you are on the correct hardware architecture.";

tvm::runtime::Module ret = (*f)(file_name, "");

return ret;
}

TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime::Module mod) {
ThreadLocalStore::ThreadLocal()->mod = mod;
});
@@ -243,14 +195,14 @@ size_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeMod
}

char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module) {
std::string std = tvm::contrib::serialize(runtime_module->mod);
std::string std = tvm::codegen::ExportModuleToBase64(runtime_module->mod);
char* ret = new char[std.length() + 1];
snprintf(ret, std.length() + 1, "%s", std.c_str());
return ret;
}

TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state) {
tvm::runtime::Module ret = tvm::contrib::deserialize(state);
tvm::runtime::Module ret = tvm::codegen::ImportModuleFromBase64(state);
return new TVMContribTorchRuntimeModule(ret);
}

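A brief round-trip sketch (not part of this patch) of the C wrapper pair touched above; the helper name is an assumption, and the delete[] matches the new[] allocation inside tvm_contrib_torch_encode:

```cpp
// Sketch only: round-tripping a module through the C wrapper API above,
// assuming the wrapper declarations are in scope. The helper name is illustrative.
TVMContribTorchRuntimeModule* RoundTripViaBase64(TVMContribTorchRuntimeModule* mod) {
  char* state = tvm_contrib_torch_encode(mod);  // heap-allocated base64 C string
  TVMContribTorchRuntimeModule* copy = tvm_contrib_torch_decode(state);
  delete[] state;  // matches the new char[] allocation in tvm_contrib_torch_encode
  return copy;
}
```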
34 changes: 34 additions & 0 deletions src/node/structural_hash.cc
@@ -28,6 +28,7 @@
#include <tvm/runtime/container/adt.h>
#include <tvm/runtime/profiling.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/codegen.h>

#include <algorithm>
#include <unordered_map>
@@ -360,6 +361,39 @@ struct ADTObjTrait {

TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait);

struct ModuleNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
static void SHashReduce(const runtime::ModuleNode* key, SHashReducer hash_reduce) {
const auto* rtmod = static_cast<const runtime::ModuleNode*>(key);
runtime::String str_key =
codegen::SerializeModuleToBytes(GetRef<runtime::Module>(rtmod), /*include_dso*/ false);
hash_reduce->SHashReduceHashedValue(
runtime::String::StableHashBytes(str_key->data, str_key->size));
}

static bool SEqualReduce(const runtime::ModuleNode* lhs, const runtime::ModuleNode* rhs,
SEqualReducer equal) {
if (lhs == rhs) return true;
const auto* lhs_mod = static_cast<const runtime::ModuleNode*>(lhs);
const auto* rhs_mod = static_cast<const runtime::ModuleNode*>(rhs);
runtime::String lhs_str =
codegen::SerializeModuleToBytes(GetRef<runtime::Module>(lhs_mod), /*include_dso*/ false);
runtime::String rhs_str =
codegen::SerializeModuleToBytes(GetRef<runtime::Module>(rhs_mod), /*include_dso*/ false);
return lhs_str == rhs_str;
}
};

TVM_REGISTER_REFLECTION_VTABLE(runtime::ModuleNode, ModuleNodeTrait)
.set_creator([](const std::string& blob) {
runtime::Module rtmod = codegen::DeserializeModuleFromBytes(blob);
return RefToObjectPtr::Get(rtmod);
})
.set_repr_bytes([](const Object* n) -> std::string {
const auto* rtmod = static_cast<const runtime::ModuleNode*>(n);
return codegen::SerializeModuleToBytes(GetRef<runtime::Module>(rtmod), /*include_dso*/ false);
});

void NDArrayHash(const runtime::NDArray::Container* arr, SHashReducer* hash_reduce,
bool hash_data) {
ICHECK_EQ(arr->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor";
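To illustrate what the trait registration above enables, here is a minimal sketch (not part of this patch): structural hash and equality over runtime modules now reduce to their serialized (non-DSO) bytes. The helper name and the consistency check are illustrative:

```cpp
// Sketch only: with ModuleNodeTrait registered, runtime modules participate in
// structural hash/equality via their serialized bytes. `lhs`/`rhs` are assumed
// to be binary-serializable modules.
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/module.h>

bool ModulesStructurallyEqual(const tvm::runtime::Module& lhs, const tvm::runtime::Module& rhs) {
  bool equal = tvm::StructuralEqual()(lhs, rhs);
  // Structurally equal modules must also agree on their structural hash.
  ICHECK(!equal || tvm::StructuralHash()(lhs) == tvm::StructuralHash()(rhs));
  return equal;
}
```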
9 changes: 0 additions & 9 deletions src/runtime/library_module.cc
@@ -67,15 +67,6 @@ class LibraryModuleNode final : public ModuleNode {
PackedFuncWrapper packed_func_wrapper_;
};

/*!
 * \brief Helper class to access the internals of a module.
 */
class ModuleInternal {
public:
// Get mutable reference of imports.
static std::vector<Module>* GetImportsAddr(ModuleNode* node) { return &(node->imports_); }
};

PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>& sptr_to_self) {
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
TVMValue ret_value;
10 changes: 10 additions & 0 deletions src/runtime/library_module.h
@@ -30,6 +30,7 @@

#include <functional>
#include <string>
#include <vector>

namespace tvm {
namespace runtime {
@@ -78,6 +79,15 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr<Object>&
*/
void InitContextFunctions(std::function<void*(const char*)> fgetsymbol);

/*!
 * \brief Helper class to access the internals of a module.
 */
class ModuleInternal {
public:
// Get mutable reference of imports.
static std::vector<Module>* GetImportsAddr(ModuleNode* node) { return &(node->imports_); }
};

/*!
* \brief Type alias for function to wrap a TVMBackendPackedCFunc.
* \param The function address imported from a module.
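A short sketch (not part of this patch) of why ModuleInternal moves into this header: other translation units, such as the new deserialization path, can re-attach imported modules after reconstruction. The function name and loop below are assumptions:

```cpp
// Sketch only: re-attaching imports to a reconstructed module via the now-shared
// ModuleInternal helper. The function name is illustrative.
#include <vector>

#include "library_module.h"  // path relative to src/runtime/, illustrative

namespace tvm {
namespace runtime {

void AttachImports(ModuleNode* root, const std::vector<Module>& imports) {
  std::vector<Module>* slot = ModuleInternal::GetImportsAddr(root);
  for (const Module& m : imports) {
    slot->push_back(m);  // append directly, bypassing ModuleNode::Import
  }
}

}  // namespace runtime
}  // namespace tvm
```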
34 changes: 34 additions & 0 deletions src/support/base64.h
@@ -293,6 +293,40 @@ class Base64OutStream : public dmlc::Stream {
}
}
};

inline size_t b64strlen(const std::string b64str) {
ICHECK(b64str.size() % 4 == 0) << "invalid base64 encoding";
Review comment (Member): Move these two functions to the callers, since they are not necessarily the optimal way of decoding base64; the b64 stream is the preferred way. So avoiding them in support/ (and gradually moving the contrib implementation over, or keeping them self-contained to a submodule scope) is better.

Reply (Contributor Author): Ah, I see. I moved them to the contrib side (src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc) since it is currently the only user.

size_t length = b64str.size() / 4 * 3;
if (b64str[b64str.size() - 2] == '=') {
length -= 2;
} else if (b64str[b64str.size() - 1] == '=') {
length -= 1;
}
return length;
}

inline void b64decode(const std::string b64str, uint8_t* ret) {
size_t index = 0;
const auto length = b64str.size();
for (size_t i = 0; i < length; i += 4) {
int8_t ch0 = base64::DecodeTable[(int32_t)b64str[i]];
int8_t ch1 = base64::DecodeTable[(int32_t)b64str[i + 1]];
int8_t ch2 = base64::DecodeTable[(int32_t)b64str[i + 2]];
int8_t ch3 = base64::DecodeTable[(int32_t)b64str[i + 3]];
uint8_t st1 = (ch0 << 2) + (ch1 >> 4);
ret[index++] = st1;
if (b64str[i + 2] != '=') {
uint8_t st2 = ((ch1 & 0b1111) << 4) + (ch2 >> 2);
ret[index++] = st2;
if (b64str[i + 3] != '=') {
uint8_t st3 = ((ch2 & 0b11) << 6) + ch3;
ret[index++] = st3;
}
}
}
ICHECK(b64strlen(b64str) == index) << "base64 decoding fails";
}

} // namespace support
} // namespace tvm
#endif // TVM_SUPPORT_BASE64_H_
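A short usage sketch (not part of this patch) of the helpers added above, following the same strlen-then-decode pattern used by the deleted contrib/torch deserializer; the include path and helper name are illustrative:

```cpp
// Sketch only: decode a base64 payload with b64strlen/b64decode.
// Include path and helper name are illustrative.
#include <cstdint>
#include <string>
#include <vector>

#include "support/base64.h"  // relative to the TVM src/ directory

std::vector<uint8_t> DecodeBase64(const std::string& b64str) {
  size_t length = tvm::support::b64strlen(b64str);  // decoded length inferred from padding
  std::vector<uint8_t> bytes(length);
  tvm::support::b64decode(b64str, bytes.data());    // fills the preallocated buffer
  return bytes;
}
```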