Skip to content

Commit

Permalink
[NODE] General serialzation of leaf objects into bytes.
Browse files Browse the repository at this point in the history
This PR refactors the serialization mechanism to support general
serialization of leaf objects into bytes.

The new feature superceded the original GetGlobalKey feature for singletons.
Added serialization support for runtime::String.
  • Loading branch information
tqchen committed Apr 10, 2020
1 parent a4321e0 commit ac386a7
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 44 deletions.
52 changes: 29 additions & 23 deletions include/tvm/node/reflection.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,29 +98,31 @@ class ReflectionVTable {
typedef void (*FSHashReduce)(const Object* self, SHashReducer hash_reduce);
/*!
* \brief creator function.
* \param global_key Key that identifies a global single object.
* If this is not empty then FGlobalKey must be defined for the object.
* \param repr_bytes Repr bytes to create the object.
* If this is not empty then FReprBytes must be defined for the object.
* \return The created function.
*/
typedef ObjectPtr<Object> (*FCreate)(const std::string& global_key);
typedef ObjectPtr<Object> (*FCreate)(const std::string& repr_bytes);
/*!
* \brief Global key function, only needed by global objects.
* \brief Function to get a byte representation that can be used to recover the object.
* \param node The node pointer.
* \return node The global key to the node.
* \return bytes The bytes that can be used to recover the object.
*/
typedef std::string (*FGlobalKey)(const Object* self);
typedef std::string (*FReprBytes)(const Object* self);
/*!
* \brief Dispatch the VisitAttrs function.
* \param self The pointer to the object.
* \param visitor The attribute visitor.
*/
inline void VisitAttrs(Object* self, AttrVisitor* visitor) const;
/*!
* \brief Get global key of the object, if any.
* \brief Get repr bytes if any.
* \param self The pointer to the object.
* \return the global key if object has one, otherwise return empty string.
* \param repr_bytes The output repr bytes, can be null, in which case the function
* simply queries if the ReprBytes function exists for the type.
* \return Whether repr bytes exists
*/
inline std::string GetGlobalKey(Object* self) const;
inline bool GetReprBytes(const Object* self, std::string* repr_bytes) const;
/*!
* \brief Dispatch the SEqualReduce function.
* \param self The pointer to the object.
Expand All @@ -141,10 +143,10 @@ class ReflectionVTable {
* by type_key and global key.
*
* \param type_key The type key of the object.
* \param global_key A global key that can be used to uniquely identify the object if any.
* \param repr_bytes Bytes representation of the object if any.
*/
TVM_DLL ObjectPtr<Object> CreateInitObject(const std::string& type_key,
const std::string& global_key = "") const;
const std::string& repr_bytes = "") const;
/*!
* \brief Get an field object by the attr name.
* \param self The pointer to the object.
Expand Down Expand Up @@ -176,8 +178,8 @@ class ReflectionVTable {
std::vector<FSHashReduce> fshash_reduce_;
/*! \brief Creation function. */
std::vector<FCreate> fcreate_;
/*! \brief Global key function. */
std::vector<FGlobalKey> fglobal_key_;
/*! \brief ReprBytes function. */
std::vector<FReprBytes> frepr_bytes_;
};

/*! \brief Registry of a reflection table. */
Expand All @@ -196,13 +198,13 @@ class ReflectionVTable::Registry {
return *this;
}
/*!
* \brief Set global_key function.
* \param f The creator function.
* \brief Set bytes repr function.
* \param f The ReprBytes function.
* \return rference to self.
*/
Registry& set_global_key(FGlobalKey f) { // NOLINT(*)
CHECK_LT(type_index_, parent_->fglobal_key_.size());
parent_->fglobal_key_[type_index_] = f;
Registry& set_repr_bytes(FReprBytes f) { // NOLINT(*)
CHECK_LT(type_index_, parent_->frepr_bytes_.size());
parent_->frepr_bytes_[type_index_] = f;
return *this;
}

Expand Down Expand Up @@ -365,7 +367,7 @@ ReflectionVTable::Register() {
if (tindex >= fvisit_attrs_.size()) {
fvisit_attrs_.resize(tindex + 1, nullptr);
fcreate_.resize(tindex + 1, nullptr);
fglobal_key_.resize(tindex + 1, nullptr);
frepr_bytes_.resize(tindex + 1, nullptr);
fsequal_reduce_.resize(tindex + 1, nullptr);
fshash_reduce_.resize(tindex + 1, nullptr);
}
Expand All @@ -392,12 +394,16 @@ VisitAttrs(Object* self, AttrVisitor* visitor) const {
fvisit_attrs_[tindex](self, visitor);
}

inline std::string ReflectionVTable::GetGlobalKey(Object* self) const {
inline bool ReflectionVTable::GetReprBytes(const Object* self,
std::string* repr_bytes) const {
uint32_t tindex = self->type_index();
if (tindex < fglobal_key_.size() && fglobal_key_[tindex] != nullptr) {
return fglobal_key_[tindex](self);
if (tindex < frepr_bytes_.size() && frepr_bytes_[tindex] != nullptr) {
if (repr_bytes != nullptr) {
*repr_bytes = frepr_bytes_[tindex](self);
}
return true;
} else {
return std::string();
return false;
}
}

Expand Down
8 changes: 8 additions & 0 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,16 @@ def _convert(item, _):
return item
return _convert

def _update_global_key(item, _):
item["repr_str"] = item["global_key"]
del item["global_key"]
return item

node_map = {
# Base IR
"SourceName": _update_global_key,
"EnvFunc": _update_global_key,
"relay.Op": _update_global_key,
"relay.TypeVar": _ftype_var,
"relay.GlobalTypeVar": _ftype_var,
"relay.Type": _rename("Type"),
Expand Down
2 changes: 1 addition & 1 deletion src/ir/env_func.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc")

TVM_REGISTER_NODE_TYPE(EnvFuncNode)
.set_creator(CreateEnvNode)
.set_global_key([](const Object* n) -> std::string {
.set_repr_bytes([](const Object* n) -> std::string {
return static_cast<const EnvFuncNode*>(n)->name;
});

Expand Down
2 changes: 1 addition & 1 deletion src/ir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ ObjectPtr<Object> CreateOp(const std::string& name) {

TVM_REGISTER_NODE_TYPE(OpNode)
.set_creator(CreateOp)
.set_global_key([](const Object* n) {
.set_repr_bytes([](const Object* n) {
return static_cast<const OpNode*>(n)->name;
});

Expand Down
2 changes: 1 addition & 1 deletion src/ir/span.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

TVM_REGISTER_NODE_TYPE(SourceNameNode)
.set_creator(GetSourceNameNode)
.set_global_key([](const Object* n) {
.set_repr_bytes([](const Object* n) {
return static_cast<const SourceNameNode*>(n)->name;
});

Expand Down
16 changes: 15 additions & 1 deletion src/node/container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,21 @@ struct StringObjTrait {
}
};

TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait);
struct RefToObjectPtr : public ObjectRef {
static ObjectPtr<Object> Get(const ObjectRef& ref) {
return GetDataPtr<Object>(ref);
}
};

TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait)
.set_creator([](const std::string& bytes) {
return RefToObjectPtr::Get(runtime::String(bytes));
})
.set_repr_bytes([](const Object* n) -> std::string {
return GetRef<runtime::String>(
static_cast<const runtime::StringObj*>(n)).operator std::string();
});


struct ADTObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
Expand Down
4 changes: 2 additions & 2 deletions src/node/reflection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,13 @@ ReflectionVTable* ReflectionVTable::Global() {

ObjectPtr<Object>
ReflectionVTable::CreateInitObject(const std::string& type_key,
const std::string& global_key) const {
const std::string& repr_bytes) const {
uint32_t tindex = Object::TypeKey2Index(type_key);
if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) {
LOG(FATAL) << "TypeError: " << type_key
<< " is not registered via TVM_REGISTER_NODE_TYPE";
}
return fcreate_[tindex](global_key);
return fcreate_[tindex](repr_bytes);
}

class NodeAttrSetter : public AttrVisitor {
Expand Down
71 changes: 56 additions & 15 deletions src/node/serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <tvm/ir/attrs.h>

#include <string>
#include <cctype>
#include <map>

#include "../support/base64.h"
Expand All @@ -46,6 +47,26 @@ inline DataType String2Type(std::string s) {
return DataType(runtime::String2DLDataType(s));
}

inline std::string Base64Decode(std::string s) {
dmlc::MemoryStringStream mstrm(&s);
support::Base64InStream b64strm(&mstrm);
std::string output;
b64strm.InitPosition();
dmlc::Stream* strm = &b64strm;
strm->Read(&output);
return output;
}

inline std::string Base64Encode(std::string s) {
std::string blob;
dmlc::MemoryStringStream mstrm(&blob);
support::Base64OutStream b64strm(&mstrm);
dmlc::Stream* strm = &b64strm;
strm->Write(s);
b64strm.Finish();
return blob;
}

// indexer to index all the nodes
class NodeIndexer : public AttrVisitor {
public:
Expand Down Expand Up @@ -103,7 +124,10 @@ class NodeIndexer : public AttrVisitor {
MakeIndex(const_cast<Object*>(kv.second.get()));
}
} else {
reflection_->VisitAttrs(node, this);
// if the node already have repr bytes, no need to visit Attrs.
if (!reflection_->GetReprBytes(node, nullptr)) {
reflection_->VisitAttrs(node, this);
}
}
}
};
Expand All @@ -115,8 +139,8 @@ using AttrMap = std::map<std::string, std::string>;
struct JSONNode {
/*! \brief The type of key of the object. */
std::string type_key;
/*! \brief The global key for global object. */
std::string global_key;
/*! \brief The str repr representation. */
std::string repr_bytes;
/*! \brief the attributes */
AttrMap attrs;
/*! \brief keys of a map. */
Expand All @@ -127,8 +151,15 @@ struct JSONNode {
void Save(dmlc::JSONWriter *writer) const {
writer->BeginObject();
writer->WriteObjectKeyValue("type_key", type_key);
if (global_key.size() != 0) {
writer->WriteObjectKeyValue("global_key", global_key);
if (repr_bytes.size() != 0) {
// choose to use str representation or base64, based on whether
// the byte representation is printable.
if (std::all_of(repr_bytes.begin(), repr_bytes.end(),
[](char ch) { return std::isprint(ch); })) {
writer->WriteObjectKeyValue("repr_str", repr_bytes);
} else {
writer->WriteObjectKeyValue("repr_b64", Base64Encode(repr_bytes));
}
}
if (attrs.size() != 0) {
writer->WriteObjectKeyValue("attrs", attrs);
Expand All @@ -145,15 +176,24 @@ struct JSONNode {
void Load(dmlc::JSONReader *reader) {
attrs.clear();
data.clear();
global_key.clear();
repr_bytes.clear();
type_key.clear();
std::string repr_b64, repr_str;
dmlc::JSONObjectReadHelper helper;
helper.DeclareOptionalField("type_key", &type_key);
helper.DeclareOptionalField("global_key", &global_key);
helper.DeclareOptionalField("repr_b64", &repr_b64);
helper.DeclareOptionalField("repr_str", &repr_str);
helper.DeclareOptionalField("attrs", &attrs);
helper.DeclareOptionalField("keys", &keys);
helper.DeclareOptionalField("data", &data);
helper.ReadAllFields(reader);

if (repr_str.size() != 0) {
CHECK_EQ(repr_b64.size(), 0U);
repr_bytes = std::move(repr_str);
} else if (repr_b64.size() != 0) {
repr_bytes = Base64Decode(repr_b64);
}
}
};

Expand Down Expand Up @@ -212,10 +252,8 @@ class JSONAttrGetter : public AttrVisitor {
return;
}
node_->type_key = node->GetTypeKey();
node_->global_key = reflection_->GetGlobalKey(node);
// No need to recursively visit fields of global singleton
// They are registered via the environment.
if (node_->global_key.length() != 0) return;
// do not need to print additional things once we have repr bytes.
if (reflection_->GetReprBytes(node, &(node_->repr_bytes))) return;

// populates the fields.
node_->attrs.clear();
Expand Down Expand Up @@ -434,7 +472,7 @@ ObjectRef LoadJSON(std::string json_str) {
for (const JSONNode& jnode : jgraph.nodes) {
if (jnode.type_key.length() != 0) {
ObjectPtr<Object> node =
reflection->CreateInitObject(jnode.type_key, jnode.global_key);
reflection->CreateInitObject(jnode.type_key, jnode.repr_bytes);
nodes.emplace_back(node);
} else {
nodes.emplace_back(ObjectPtr<Object>());
Expand All @@ -447,9 +485,12 @@ ObjectRef LoadJSON(std::string json_str) {

for (size_t i = 0; i < nodes.size(); ++i) {
setter.node_ = &jgraph.nodes[i];
// do not need to recover content of global singleton object
// they are registered via the environment
if (setter.node_->global_key.length() == 0) {
// Skip the nodes that has an repr bytes representation.
// NOTE: the second condition is used to guard the case
// where the repr bytes itself is an empty string "".
if (setter.node_->repr_bytes.length() == 0 &&
nodes[i] != nullptr &&
!reflection->GetReprBytes(nodes[i].get(), nullptr)) {
setter.Set(nodes[i].get());
}
}
Expand Down
18 changes: 18 additions & 0 deletions tests/python/relay/test_json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

import tvm
from tvm import relay
from tvm import te
import json

Expand Down Expand Up @@ -108,6 +109,22 @@ def test_global_var():
assert isinstance(tvar, tvm.ir.GlobalVar)


def test_op():
nodes = [
{"type_key": ""},
{"type_key": "relay.Op",
"global_key": "nn.conv2d"}
]
data = {
"root" : 1,
"nodes": nodes,
"attrs": {"tvm_version": "0.6.0"},
"b64ndarrays": [],
}
op = tvm.ir.load_json(json.dumps(data))
assert op == relay.op.get("nn.conv2d")


def test_tir_var():
nodes = [
{"type_key": ""},
Expand All @@ -132,6 +149,7 @@ def test_tir_var():


if __name__ == "__main__":
test_op()
test_type_var()
test_incomplete_type()
test_func_tuple_type()
Expand Down
13 changes: 13 additions & 0 deletions tests/python/unittest/test_node_reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,20 @@ def test(x):
assert x.func(10) == 11


def test_string():
# non printable str, need to store by b64
s1 = tvm.runtime.String("xy\x01z")
s2 = tvm.ir.load_json(tvm.ir.save_json(s1))
tvm.ir.assert_structural_equal(s1, s2)

# printable str, need to store by repr_str
s1 = tvm.runtime.String("xyz")
s2 = tvm.ir.load_json(tvm.ir.save_json(s1))
tvm.ir.assert_structural_equal(s1, s2)


if __name__ == "__main__":
test_string()
test_env_func()
test_make_node()
test_make_smap()
Expand Down

0 comments on commit ac386a7

Please sign in to comment.