diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index 18dfa129fa39..9ed87df46618 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -98,17 +98,17 @@ class ReflectionVTable { typedef void (*FSHashReduce)(const Object* self, SHashReducer hash_reduce); /*! * \brief creator function. - * \param global_key Key that identifies a global single object. - * If this is not empty then FGlobalKey must be defined for the object. + * \param repr_bytes Repr bytes to create the object. + * If this is not empty then FReprBytes must be defined for the object. * \return The created function. */ - typedef ObjectPtr (*FCreate)(const std::string& global_key); + typedef ObjectPtr (*FCreate)(const std::string& repr_bytes); /*! - * \brief Global key function, only needed by global objects. + * \brief Function to get a byte representation that can be used to recover the object. * \param node The node pointer. - * \return node The global key to the node. + * \return bytes The bytes that can be used to recover the object. */ - typedef std::string (*FGlobalKey)(const Object* self); + typedef std::string (*FReprBytes)(const Object* self); /*! * \brief Dispatch the VisitAttrs function. * \param self The pointer to the object. @@ -116,11 +116,13 @@ class ReflectionVTable { */ inline void VisitAttrs(Object* self, AttrVisitor* visitor) const; /*! - * \brief Get global key of the object, if any. + * \brief Get repr bytes if any. * \param self The pointer to the object. - * \return the global key if object has one, otherwise return empty string. + * \param repr_bytes The output repr bytes, can be null, in which case the function + * simply queries if the ReprBytes function exists for the type. + * \return Whether repr bytes exists */ - inline std::string GetGlobalKey(Object* self) const; + inline bool GetReprBytes(const Object* self, std::string* repr_bytes) const; /*! * \brief Dispatch the SEqualReduce function. * \param self The pointer to the object. @@ -141,10 +143,10 @@ class ReflectionVTable { * by type_key and global key. * * \param type_key The type key of the object. - * \param global_key A global key that can be used to uniquely identify the object if any. + * \param repr_bytes Bytes representation of the object if any. */ TVM_DLL ObjectPtr CreateInitObject(const std::string& type_key, - const std::string& global_key = "") const; + const std::string& repr_bytes = "") const; /*! * \brief Get an field object by the attr name. * \param self The pointer to the object. @@ -176,8 +178,8 @@ class ReflectionVTable { std::vector fshash_reduce_; /*! \brief Creation function. */ std::vector fcreate_; - /*! \brief Global key function. */ - std::vector fglobal_key_; + /*! \brief ReprBytes function. */ + std::vector frepr_bytes_; }; /*! \brief Registry of a reflection table. */ @@ -196,13 +198,13 @@ class ReflectionVTable::Registry { return *this; } /*! - * \brief Set global_key function. - * \param f The creator function. + * \brief Set bytes repr function. + * \param f The ReprBytes function. * \return rference to self. */ - Registry& set_global_key(FGlobalKey f) { // NOLINT(*) - CHECK_LT(type_index_, parent_->fglobal_key_.size()); - parent_->fglobal_key_[type_index_] = f; + Registry& set_repr_bytes(FReprBytes f) { // NOLINT(*) + CHECK_LT(type_index_, parent_->frepr_bytes_.size()); + parent_->frepr_bytes_[type_index_] = f; return *this; } @@ -365,7 +367,7 @@ ReflectionVTable::Register() { if (tindex >= fvisit_attrs_.size()) { fvisit_attrs_.resize(tindex + 1, nullptr); fcreate_.resize(tindex + 1, nullptr); - fglobal_key_.resize(tindex + 1, nullptr); + frepr_bytes_.resize(tindex + 1, nullptr); fsequal_reduce_.resize(tindex + 1, nullptr); fshash_reduce_.resize(tindex + 1, nullptr); } @@ -392,12 +394,16 @@ VisitAttrs(Object* self, AttrVisitor* visitor) const { fvisit_attrs_[tindex](self, visitor); } -inline std::string ReflectionVTable::GetGlobalKey(Object* self) const { +inline bool ReflectionVTable::GetReprBytes(const Object* self, + std::string* repr_bytes) const { uint32_t tindex = self->type_index(); - if (tindex < fglobal_key_.size() && fglobal_key_[tindex] != nullptr) { - return fglobal_key_[tindex](self); + if (tindex < frepr_bytes_.size() && frepr_bytes_[tindex] != nullptr) { + if (repr_bytes != nullptr) { + *repr_bytes = frepr_bytes_[tindex](self); + } + return true; } else { - return std::string(); + return false; } } diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index aa43df5a6697..e091cd12a208 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -79,8 +79,16 @@ def _convert(item, _): return item return _convert + def _update_global_key(item, _): + item["repr_str"] = item["global_key"] + del item["global_key"] + return item + node_map = { # Base IR + "SourceName": _update_global_key, + "EnvFunc": _update_global_key, + "relay.Op": _update_global_key, "relay.TypeVar": _ftype_var, "relay.GlobalTypeVar": _ftype_var, "relay.Type": _rename("Type"), diff --git a/src/ir/env_func.cc b/src/ir/env_func.cc index 3e85c5f47b52..4d3ed30bc032 100644 --- a/src/ir/env_func.cc +++ b/src/ir/env_func.cc @@ -69,7 +69,7 @@ TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc") TVM_REGISTER_NODE_TYPE(EnvFuncNode) .set_creator(CreateEnvNode) -.set_global_key([](const Object* n) -> std::string { +.set_repr_bytes([](const Object* n) -> std::string { return static_cast(n)->name; }); diff --git a/src/ir/op.cc b/src/ir/op.cc index 54374eb8a526..6a50240ee7a1 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -223,7 +223,7 @@ ObjectPtr CreateOp(const std::string& name) { TVM_REGISTER_NODE_TYPE(OpNode) .set_creator(CreateOp) -.set_global_key([](const Object* n) { +.set_repr_bytes([](const Object* n) { return static_cast(n)->name; }); diff --git a/src/ir/span.cc b/src/ir/span.cc index d03903c2d3a5..f84353de2a8b 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -56,7 +56,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(SourceNameNode) .set_creator(GetSourceNameNode) -.set_global_key([](const Object* n) { +.set_repr_bytes([](const Object* n) { return static_cast(n)->name; }); diff --git a/src/node/container.cc b/src/node/container.cc index 8fff151ce605..e7e497946b6f 100644 --- a/src/node/container.cc +++ b/src/node/container.cc @@ -48,7 +48,21 @@ struct StringObjTrait { } }; -TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait); +struct RefToObjectPtr : public ObjectRef { + static ObjectPtr Get(const ObjectRef& ref) { + return GetDataPtr(ref); + } +}; + +TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait) +.set_creator([](const std::string& bytes) { + return RefToObjectPtr::Get(runtime::String(bytes)); +}) +.set_repr_bytes([](const Object* n) -> std::string { + return GetRef( + static_cast(n)).operator std::string(); +}); + struct ADTObjTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; diff --git a/src/node/reflection.cc b/src/node/reflection.cc index 824874f24ab0..08a914ff38f9 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -178,13 +178,13 @@ ReflectionVTable* ReflectionVTable::Global() { ObjectPtr ReflectionVTable::CreateInitObject(const std::string& type_key, - const std::string& global_key) const { + const std::string& repr_bytes) const { uint32_t tindex = Object::TypeKey2Index(type_key); if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) { LOG(FATAL) << "TypeError: " << type_key << " is not registered via TVM_REGISTER_NODE_TYPE"; } - return fcreate_[tindex](global_key); + return fcreate_[tindex](repr_bytes); } class NodeAttrSetter : public AttrVisitor { diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 11c9e8fc8cb6..ee6072d77c1c 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -32,6 +32,7 @@ #include #include +#include #include #include "../support/base64.h" @@ -46,6 +47,26 @@ inline DataType String2Type(std::string s) { return DataType(runtime::String2DLDataType(s)); } +inline std::string Base64Decode(std::string s) { + dmlc::MemoryStringStream mstrm(&s); + support::Base64InStream b64strm(&mstrm); + std::string output; + b64strm.InitPosition(); + dmlc::Stream* strm = &b64strm; + strm->Read(&output); + return output; +} + +inline std::string Base64Encode(std::string s) { + std::string blob; + dmlc::MemoryStringStream mstrm(&blob); + support::Base64OutStream b64strm(&mstrm); + dmlc::Stream* strm = &b64strm; + strm->Write(s); + b64strm.Finish(); + return blob; +} + // indexer to index all the nodes class NodeIndexer : public AttrVisitor { public: @@ -103,7 +124,10 @@ class NodeIndexer : public AttrVisitor { MakeIndex(const_cast(kv.second.get())); } } else { - reflection_->VisitAttrs(node, this); + // if the node already have repr bytes, no need to visit Attrs. + if (!reflection_->GetReprBytes(node, nullptr)) { + reflection_->VisitAttrs(node, this); + } } } }; @@ -115,8 +139,8 @@ using AttrMap = std::map; struct JSONNode { /*! \brief The type of key of the object. */ std::string type_key; - /*! \brief The global key for global object. */ - std::string global_key; + /*! \brief The str repr representation. */ + std::string repr_bytes; /*! \brief the attributes */ AttrMap attrs; /*! \brief keys of a map. */ @@ -127,8 +151,15 @@ struct JSONNode { void Save(dmlc::JSONWriter *writer) const { writer->BeginObject(); writer->WriteObjectKeyValue("type_key", type_key); - if (global_key.size() != 0) { - writer->WriteObjectKeyValue("global_key", global_key); + if (repr_bytes.size() != 0) { + // choose to use str representation or base64, based on whether + // the byte representation is printable. + if (std::all_of(repr_bytes.begin(), repr_bytes.end(), + [](char ch) { return std::isprint(ch); })) { + writer->WriteObjectKeyValue("repr_str", repr_bytes); + } else { + writer->WriteObjectKeyValue("repr_b64", Base64Encode(repr_bytes)); + } } if (attrs.size() != 0) { writer->WriteObjectKeyValue("attrs", attrs); @@ -145,15 +176,24 @@ struct JSONNode { void Load(dmlc::JSONReader *reader) { attrs.clear(); data.clear(); - global_key.clear(); + repr_bytes.clear(); type_key.clear(); + std::string repr_b64, repr_str; dmlc::JSONObjectReadHelper helper; helper.DeclareOptionalField("type_key", &type_key); - helper.DeclareOptionalField("global_key", &global_key); + helper.DeclareOptionalField("repr_b64", &repr_b64); + helper.DeclareOptionalField("repr_str", &repr_str); helper.DeclareOptionalField("attrs", &attrs); helper.DeclareOptionalField("keys", &keys); helper.DeclareOptionalField("data", &data); helper.ReadAllFields(reader); + + if (repr_str.size() != 0) { + CHECK_EQ(repr_b64.size(), 0U); + repr_bytes = std::move(repr_str); + } else if (repr_b64.size() != 0) { + repr_bytes = Base64Decode(repr_b64); + } } }; @@ -212,10 +252,8 @@ class JSONAttrGetter : public AttrVisitor { return; } node_->type_key = node->GetTypeKey(); - node_->global_key = reflection_->GetGlobalKey(node); - // No need to recursively visit fields of global singleton - // They are registered via the environment. - if (node_->global_key.length() != 0) return; + // do not need to print additional things once we have repr bytes. + if (reflection_->GetReprBytes(node, &(node_->repr_bytes))) return; // populates the fields. node_->attrs.clear(); @@ -434,7 +472,7 @@ ObjectRef LoadJSON(std::string json_str) { for (const JSONNode& jnode : jgraph.nodes) { if (jnode.type_key.length() != 0) { ObjectPtr node = - reflection->CreateInitObject(jnode.type_key, jnode.global_key); + reflection->CreateInitObject(jnode.type_key, jnode.repr_bytes); nodes.emplace_back(node); } else { nodes.emplace_back(ObjectPtr()); @@ -447,9 +485,12 @@ ObjectRef LoadJSON(std::string json_str) { for (size_t i = 0; i < nodes.size(); ++i) { setter.node_ = &jgraph.nodes[i]; - // do not need to recover content of global singleton object - // they are registered via the environment - if (setter.node_->global_key.length() == 0) { + // Skip the nodes that has an repr bytes representation. + // NOTE: the second condition is used to guard the case + // where the repr bytes itself is an empty string "". + if (setter.node_->repr_bytes.length() == 0 && + nodes[i] != nullptr && + !reflection->GetReprBytes(nodes[i].get(), nullptr)) { setter.Set(nodes[i].get()); } } diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index 54812be62d9b..16d02d2cc224 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -16,6 +16,7 @@ # under the License. import tvm +from tvm import relay from tvm import te import json @@ -108,6 +109,22 @@ def test_global_var(): assert isinstance(tvar, tvm.ir.GlobalVar) +def test_op(): + nodes = [ + {"type_key": ""}, + {"type_key": "relay.Op", + "global_key": "nn.conv2d"} + ] + data = { + "root" : 1, + "nodes": nodes, + "attrs": {"tvm_version": "0.6.0"}, + "b64ndarrays": [], + } + op = tvm.ir.load_json(json.dumps(data)) + assert op == relay.op.get("nn.conv2d") + + def test_tir_var(): nodes = [ {"type_key": ""}, @@ -132,6 +149,7 @@ def test_tir_var(): if __name__ == "__main__": + test_op() test_type_var() test_incomplete_type() test_func_tuple_type() diff --git a/tests/python/unittest/test_node_reflection.py b/tests/python/unittest/test_node_reflection.py index f2848ff0ef50..975192293d87 100644 --- a/tests/python/unittest/test_node_reflection.py +++ b/tests/python/unittest/test_node_reflection.py @@ -89,7 +89,20 @@ def test(x): assert x.func(10) == 11 +def test_string(): + # non printable str, need to store by b64 + s1 = tvm.runtime.String("xy\x01z") + s2 = tvm.ir.load_json(tvm.ir.save_json(s1)) + tvm.ir.assert_structural_equal(s1, s2) + + # printable str, need to store by repr_str + s1 = tvm.runtime.String("xyz") + s2 = tvm.ir.load_json(tvm.ir.save_json(s1)) + tvm.ir.assert_structural_equal(s1, s2) + + if __name__ == "__main__": + test_string() test_env_func() test_make_node() test_make_smap()