diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index 1a7a8dfef685..a3cfdaf267ac 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -50,6 +50,7 @@ using runtime::ObjectRef; using runtime::String; using runtime::StringObj; +/*! \brief String-aware ObjectRef hash functor */ struct ObjectHash { size_t operator()(const ObjectRef& a) const { if (const auto* str = a.as()) { @@ -59,6 +60,7 @@ struct ObjectHash { } }; +/*! \brief String-aware ObjectRef equal functor */ struct ObjectEqual { bool operator()(const ObjectRef& a, const ObjectRef& b) const { if (a.same_as(b)) { @@ -96,8 +98,7 @@ class MapNode : public Object { * \tparam V The value NodeRef type. */ template ::value || - std::is_base_of::value>::type, + typename = typename std::enable_if::value>::type, typename = typename std::enable_if::value>::type> class Map : public ObjectRef { public: diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index 6fc24c0acfdc..2facc79b3af4 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -129,6 +129,7 @@ def _convert(item, nodes): "relay.PassContext": _rename("transform.PassContext"), "relay.ModulePass": _rename("transform.ModulePass"), "relay.Sequential": _rename("transform.Sequential"), + "StrMap": _rename("Map"), # TIR "Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")], "SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")], diff --git a/src/node/container.cc b/src/node/container.cc index f8bad0070c55..bc212a83d87e 100644 --- a/src/node/container.cc +++ b/src/node/container.cc @@ -247,40 +247,51 @@ struct MapNodeTrait { } static void SHashReduce(const MapNode* key, SHashReducer hash_reduce) { - if (key->data.empty()) { - hash_reduce(uint64_t(0)); - return; - } - if (key->data.begin()->first->IsInstance()) { + bool is_str_map = std::all_of(key->data.begin(), key->data.end(), [](const auto& v) { + return v.first->template IsInstance(); + }); + if (is_str_map) { SHashReduceForSMap(key, hash_reduce); } else { SHashReduceForOMap(key, hash_reduce); } } + static bool SEqualReduceForOMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { + for (const auto& kv : lhs->data) { + // Only allow equal checking if the keys are already mapped + // This resolves common use cases where we want to store + // Map where Var is defined in the function + // parameters. + ObjectRef rhs_key = equal->MapLhsToRhs(kv.first); + if (!rhs_key.defined()) return false; + auto it = rhs->data.find(rhs_key); + if (it == rhs->data.end()) return false; + if (!equal(kv.second, it->second)) return false; + } + return true; + } + + static bool SEqualReduceForSMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { + for (const auto& kv : lhs->data) { + auto it = rhs->data.find(kv.first); + if (it == rhs->data.end()) return false; + if (!equal(kv.second, it->second)) return false; + } + return true; + } + static bool SEqualReduce(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { if (rhs->data.size() != lhs->data.size()) return false; if (rhs->data.size() == 0) return true; - if (lhs->data.begin()->first->IsInstance()) { - for (const auto& kv : lhs->data) { - auto it = rhs->data.find(kv.first); - if (it == rhs->data.end()) return false; - if (!equal(kv.second, it->second)) return false; - } - } else { - for (const auto& kv : lhs->data) { - // Only allow equal checking if the keys are already mapped - // This resolves common use cases where we want to store - // Map where Var is defined in the function - // parameters. - ObjectRef rhs_key = equal->MapLhsToRhs(kv.first); - if (!rhs_key.defined()) return false; - auto it = rhs->data.find(rhs_key); - if (it == rhs->data.end()) return false; - if (!equal(kv.second, it->second)) return false; - } + bool ls = std::all_of(lhs->data.begin(), lhs->data.end(), + [](const auto& v) { return v.first->template IsInstance(); }); + bool rs = std::all_of(rhs->data.begin(), rhs->data.end(), + [](const auto& v) { return v.first->template IsInstance(); }); + if (ls != rs) { + return false; } - return true; + return (ls && rs) ? SEqualReduceForSMap(lhs, rhs, equal) : SEqualReduceForOMap(lhs, rhs, equal); } }; diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 9845a6fb8c95..386653349904 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -110,11 +110,18 @@ class NodeIndexer : public AttrVisitor { } } else if (node->IsInstance()) { MapNode* n = static_cast(node); - for (const auto& kv : n->data) { - if (!kv.first->IsInstance()) { + bool is_str_map = std::all_of(n->data.begin(), n->data.end(), [](const auto& v) { + return v.first->template IsInstance(); + }); + if (is_str_map) { + for (const auto& kv : n->data) { + MakeIndex(const_cast(kv.second.get())); + } + } else { + for (const auto& kv : n->data) { MakeIndex(const_cast(kv.first.get())); + MakeIndex(const_cast(kv.second.get())); } - MakeIndex(const_cast(kv.second.get())); } } else { // if the node already have repr bytes, no need to visit Attrs. @@ -246,13 +253,19 @@ class JSONAttrGetter : public AttrVisitor { } } else if (node->IsInstance()) { MapNode* n = static_cast(node); - for (const auto& kv : n->data) { - if (const auto* str = kv.first.as()) { - node_->keys.push_back(std::string(str->data, str->size)); - } else { + bool is_str_map = std::all_of(n->data.begin(), n->data.end(), [](const auto& v) { + return v.first->template IsInstance(); + }); + if (is_str_map) { + for (const auto& kv : n->data) { + node_->keys.push_back(Downcast(kv.first)); + node_->data.push_back(node_index_->at(const_cast(kv.second.get()))); + } + } else { + for (const auto& kv : n->data) { node_->data.push_back(node_index_->at(const_cast(kv.first.get()))); + node_->data.push_back(node_index_->at(const_cast(kv.second.get()))); } - node_->data.push_back(node_index_->at(const_cast(kv.second.get()))); } } else { // recursively index normal object. diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index c961f991c8a7..00d41f0ffc3e 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -186,6 +186,34 @@ def test_tir_var(): assert y.name == "y" +def test_str_map(): + nodes = [ + {'type_key': ''}, + {'type_key': 'StrMap', 'keys': ['z', 'x'], 'data': [2, 3]}, + {'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '2'}}, + {'type_key': 'Max', 'attrs': {'a': '4', 'b': '10', 'dtype': 'int32'}}, + {'type_key': 'Add', 'attrs': {'a': '5', 'b': '9', 'dtype': 'int32'}}, + {'type_key': 'Add', 'attrs': {'a': '6', 'b': '8', 'dtype': 'int32'}}, + {'type_key': 'tir.Var', 'attrs': {'dtype': 'int32', 'name': '7', 'type_annotation': '0'}}, + {'type_key': 'runtime.String', 'repr_str': 'x'}, + {'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '1'}}, + {'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '2'}}, + {'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '100'}} + ] + data = { + "root" : 1, + "nodes": nodes, + "attrs": {"tvm_version": "0.6.0"}, + "b64ndarrays": [], + } + x = tvm.ir.load_json(json.dumps(data)) + assert(isinstance(x, tvm.ir.container.Map)) + assert(len(x) == 2) + assert('x' in x) + assert('z' in x) + assert(bool(x['z'] == 2)) + + if __name__ == "__main__": test_op() test_type_var() @@ -194,3 +222,4 @@ def test_tir_var(): test_func_tuple_type() test_global_var() test_tir_var() + test_str_map()