Skip to content

Commit

Permalink
[Object] Restore the StrMap behavior in JSON/SHash/SEqual (apache#5719)
Browse files Browse the repository at this point in the history
junrushao authored and Trevor Morris committed Jun 9, 2020
1 parent 932e434 commit ccd455c
Showing 5 changed files with 89 additions and 34 deletions.
5 changes: 3 additions & 2 deletions include/tvm/node/container.h
Original file line number Diff line number Diff line change
@@ -50,6 +50,7 @@ using runtime::ObjectRef;
using runtime::String;
using runtime::StringObj;

/*! \brief String-aware ObjectRef hash functor */
struct ObjectHash {
size_t operator()(const ObjectRef& a) const {
if (const auto* str = a.as<StringObj>()) {
@@ -59,6 +60,7 @@ struct ObjectHash {
}
};

/*! \brief String-aware ObjectRef equal functor */
struct ObjectEqual {
bool operator()(const ObjectRef& a, const ObjectRef& b) const {
if (a.same_as(b)) {
@@ -96,8 +98,7 @@ class MapNode : public Object {
* \tparam V The value NodeRef type.
*/
template <typename K, typename V,
typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value ||
std::is_base_of<std::string, K>::value>::type,
typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value>::type,
typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
class Map : public ObjectRef {
public:
1 change: 1 addition & 0 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
@@ -129,6 +129,7 @@ def _convert(item, nodes):
"relay.PassContext": _rename("transform.PassContext"),
"relay.ModulePass": _rename("transform.ModulePass"),
"relay.Sequential": _rename("transform.Sequential"),
"StrMap": _rename("Map"),
# TIR
"Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")],
"SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")],
59 changes: 35 additions & 24 deletions src/node/container.cc
Original file line number Diff line number Diff line change
@@ -247,40 +247,51 @@ struct MapNodeTrait {
}

static void SHashReduce(const MapNode* key, SHashReducer hash_reduce) {
if (key->data.empty()) {
hash_reduce(uint64_t(0));
return;
}
if (key->data.begin()->first->IsInstance<StringObj>()) {
bool is_str_map = std::all_of(key->data.begin(), key->data.end(), [](const auto& v) {
return v.first->template IsInstance<StringObj>();
});
if (is_str_map) {
SHashReduceForSMap(key, hash_reduce);
} else {
SHashReduceForOMap(key, hash_reduce);
}
}

static bool SEqualReduceForOMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) {
for (const auto& kv : lhs->data) {
// Only allow equal checking if the keys are already mapped
// This resolves common use cases where we want to store
// Map<Var, Value> where Var is defined in the function
// parameters.
ObjectRef rhs_key = equal->MapLhsToRhs(kv.first);
if (!rhs_key.defined()) return false;
auto it = rhs->data.find(rhs_key);
if (it == rhs->data.end()) return false;
if (!equal(kv.second, it->second)) return false;
}
return true;
}

static bool SEqualReduceForSMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) {
for (const auto& kv : lhs->data) {
auto it = rhs->data.find(kv.first);
if (it == rhs->data.end()) return false;
if (!equal(kv.second, it->second)) return false;
}
return true;
}

static bool SEqualReduce(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) {
if (rhs->data.size() != lhs->data.size()) return false;
if (rhs->data.size() == 0) return true;
if (lhs->data.begin()->first->IsInstance<StringObj>()) {
for (const auto& kv : lhs->data) {
auto it = rhs->data.find(kv.first);
if (it == rhs->data.end()) return false;
if (!equal(kv.second, it->second)) return false;
}
} else {
for (const auto& kv : lhs->data) {
// Only allow equal checking if the keys are already mapped
// This resolves common use cases where we want to store
// Map<Var, Value> where Var is defined in the function
// parameters.
ObjectRef rhs_key = equal->MapLhsToRhs(kv.first);
if (!rhs_key.defined()) return false;
auto it = rhs->data.find(rhs_key);
if (it == rhs->data.end()) return false;
if (!equal(kv.second, it->second)) return false;
}
bool ls = std::all_of(lhs->data.begin(), lhs->data.end(),
[](const auto& v) { return v.first->template IsInstance<StringObj>(); });
bool rs = std::all_of(rhs->data.begin(), rhs->data.end(),
[](const auto& v) { return v.first->template IsInstance<StringObj>(); });
if (ls != rs) {
return false;
}
return true;
return (ls && rs) ? SEqualReduceForSMap(lhs, rhs, equal) : SEqualReduceForOMap(lhs, rhs, equal);
}
};

29 changes: 21 additions & 8 deletions src/node/serialization.cc
Original file line number Diff line number Diff line change
@@ -110,11 +110,18 @@ class NodeIndexer : public AttrVisitor {
}
} else if (node->IsInstance<MapNode>()) {
MapNode* n = static_cast<MapNode*>(node);
for (const auto& kv : n->data) {
if (!kv.first->IsInstance<StringObj>()) {
bool is_str_map = std::all_of(n->data.begin(), n->data.end(), [](const auto& v) {
return v.first->template IsInstance<StringObj>();
});
if (is_str_map) {
for (const auto& kv : n->data) {
MakeIndex(const_cast<Object*>(kv.second.get()));
}
} else {
for (const auto& kv : n->data) {
MakeIndex(const_cast<Object*>(kv.first.get()));
MakeIndex(const_cast<Object*>(kv.second.get()));
}
MakeIndex(const_cast<Object*>(kv.second.get()));
}
} else {
// if the node already have repr bytes, no need to visit Attrs.
@@ -246,13 +253,19 @@ class JSONAttrGetter : public AttrVisitor {
}
} else if (node->IsInstance<MapNode>()) {
MapNode* n = static_cast<MapNode*>(node);
for (const auto& kv : n->data) {
if (const auto* str = kv.first.as<StringObj>()) {
node_->keys.push_back(std::string(str->data, str->size));
} else {
bool is_str_map = std::all_of(n->data.begin(), n->data.end(), [](const auto& v) {
return v.first->template IsInstance<StringObj>();
});
if (is_str_map) {
for (const auto& kv : n->data) {
node_->keys.push_back(Downcast<String>(kv.first));
node_->data.push_back(node_index_->at(const_cast<Object*>(kv.second.get())));
}
} else {
for (const auto& kv : n->data) {
node_->data.push_back(node_index_->at(const_cast<Object*>(kv.first.get())));
node_->data.push_back(node_index_->at(const_cast<Object*>(kv.second.get())));
}
node_->data.push_back(node_index_->at(const_cast<Object*>(kv.second.get())));
}
} else {
// recursively index normal object.
29 changes: 29 additions & 0 deletions tests/python/relay/test_json_compact.py
Original file line number Diff line number Diff line change
@@ -186,6 +186,34 @@ def test_tir_var():
assert y.name == "y"


def test_str_map():
nodes = [
{'type_key': ''},
{'type_key': 'StrMap', 'keys': ['z', 'x'], 'data': [2, 3]},
{'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '2'}},
{'type_key': 'Max', 'attrs': {'a': '4', 'b': '10', 'dtype': 'int32'}},
{'type_key': 'Add', 'attrs': {'a': '5', 'b': '9', 'dtype': 'int32'}},
{'type_key': 'Add', 'attrs': {'a': '6', 'b': '8', 'dtype': 'int32'}},
{'type_key': 'tir.Var', 'attrs': {'dtype': 'int32', 'name': '7', 'type_annotation': '0'}},
{'type_key': 'runtime.String', 'repr_str': 'x'},
{'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '1'}},
{'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '2'}},
{'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '100'}}
]
data = {
"root" : 1,
"nodes": nodes,
"attrs": {"tvm_version": "0.6.0"},
"b64ndarrays": [],
}
x = tvm.ir.load_json(json.dumps(data))
assert(isinstance(x, tvm.ir.container.Map))
assert(len(x) == 2)
assert('x' in x)
assert('z' in x)
assert(bool(x['z'] == 2))


if __name__ == "__main__":
test_op()
test_type_var()
@@ -194,3 +222,4 @@ def test_tir_var():
test_func_tuple_type()
test_global_var()
test_tir_var()
test_str_map()

0 comments on commit ccd455c

Please sign in to comment.