From a2d8ac5b97cca18f3f051c9b139dcfbaecdaae1a Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 15 Jun 2018 20:08:54 -0700 Subject: [PATCH] [CONTAINER] Introduce StrMap (#1292) --- CMakeLists.txt | 2 +- HalideIR | 2 +- include/tvm/packed_func_ext.h | 19 ++++ python/tvm/_ffi/_ctypes/node.py | 1 + python/tvm/_ffi/node_generic.py | 6 +- python/tvm/container.py | 17 ++- src/api/api_lang.cc | 111 +++++++++++++------ src/lang/expr.cc | 1 + tests/cpp/container_test.cc | 9 ++ tests/python/unittest/test_lang_container.py | 13 +++ 10 files changed, 142 insertions(+), 39 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 972e2cbe712d..e9699b1f6736 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -196,7 +196,7 @@ if(GTEST_LIB) add_executable(${__execname} ${__srcpath}) list(APPEND TEST_EXECS ${__execname}) target_link_libraries(${__execname} - tvm ${GTEST_LIB} ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS} pthread) + tvm ${GTEST_LIB} pthread) set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_ALL 1) set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1) endforeach() diff --git a/HalideIR b/HalideIR index a3698398faff..0b7e25275138 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit a3698398faff7fec1c0fa4e4479357651382db75 +Subproject commit 0b7e25275138768bb05edb9b9db2c86d0fb09c9a diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h index 455717ce753c..95964547ef8e 100644 --- a/include/tvm/packed_func_ext.h +++ b/include/tvm/packed_func_ext.h @@ -60,6 +60,25 @@ struct NodeTypeChecker > { } }; +template +struct NodeTypeChecker > { + static inline bool Check(Node* sptr) { + if (sptr == nullptr) return false; + if (!sptr->is_type()) return false; + StrMapNode* n = static_cast(sptr); + for (const auto& kv : n->data) { + if (!NodeTypeChecker::Check(kv.second.get())) return false; + } + return true; + } + static inline void PrintName(std::ostringstream& os) { // NOLINT(*) + os << "map::PrintName(os); + os << '>'; + } +}; + template struct NodeTypeChecker > { static inline bool Check(Node* sptr) { diff --git a/python/tvm/_ffi/_ctypes/node.py b/python/tvm/_ffi/_ctypes/node.py index cb32b83291d1..01244519532b 100644 --- a/python/tvm/_ffi/_ctypes/node.py +++ b/python/tvm/_ffi/_ctypes/node.py @@ -30,6 +30,7 @@ def _return_node(x): C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func( _return_node, TypeCode.NODE_HANDLE) + class NodeBase(object): __slots__ = ["handle"] # pylint: disable=no-member diff --git a/python/tvm/_ffi/node_generic.py b/python/tvm/_ffi/node_generic.py index 7561097bf305..b7230f29da59 100644 --- a/python/tvm/_ffi/node_generic.py +++ b/python/tvm/_ffi/node_generic.py @@ -13,12 +13,14 @@ def _set_class_node_base(cls): global _CLASS_NODE_BASE _CLASS_NODE_BASE = cls + class NodeGeneric(object): """Base class for all classes that can be converted to node.""" def asnode(self): """Convert value to node""" raise NotImplementedError() + def convert_to_node(value): """Convert a python value to corresponding node type. @@ -46,7 +48,8 @@ def convert_to_node(value): elif isinstance(value, dict): vlist = [] for item in value.items(): - if not isinstance(item[0], _CLASS_NODE_BASE): + if (not isinstance(item[0], _CLASS_NODE_BASE) and + not isinstance(item[0], string_types)): raise ValueError("key of map must already been a container type") vlist.append(item[0]) vlist.append(convert_to_node(item[1])) @@ -56,6 +59,7 @@ def convert_to_node(value): else: raise ValueError("don't know how to convert type %s to node" % type(value)) + def const(value, dtype=None): """Construct a constant value for a given type. diff --git a/python/tvm/container.py b/python/tvm/container.py index d1d4546fd86a..27e533113926 100644 --- a/python/tvm/container.py +++ b/python/tvm/container.py @@ -32,9 +32,8 @@ class Map(NodeBase): """Map container of TVM. You do not need to create Map explicitly. - Normally python dict will be converted automatically - to Array during tvm function call. - You may get Map in return values of TVM function call. + Normally python dict will be converted automaticall to Map during tvm function call. + You can use convert to create a dict[NodeBase-> NodeBase] into a Map """ def __getitem__(self, k): return _api_internal._MapGetItem(self, k) @@ -51,6 +50,18 @@ def __len__(self): return _api_internal._MapSize(self) +@register_node +class StrMap(Map): + """A special map container that has str as key. + + You can use convert to create a dict[str->NodeBase] into a Map. + """ + def items(self): + """Get the items from the map""" + akvs = _api_internal._MapItems(self) + return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)] + + @register_node class Range(NodeBase): """Represent range in TVM. diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index ba158a7c3f79..00f8ba5f3847 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -76,63 +76,108 @@ TVM_REGISTER_API("_ArraySize") TVM_REGISTER_API("_Map") .set_body([](TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args.size() % 2, 0); - MapNode::ContainerType data; - for (int i = 0; i < args.num_args; i += 2) { - CHECK(args[i].type_code() == kNodeHandle) - << "need content of array to be NodeBase"; - CHECK(args[i + 1].type_code() == kNodeHandle) - << "need content of array to be NodeBase"; - data.emplace(std::make_pair(args[i].node_sptr(), - args[i + 1].node_sptr())); + if (args.size() != 0 && args[0].type_code() == kStr) { + // StrMap + StrMapNode::ContainerType data; + for (int i = 0; i < args.num_args; i += 2) { + CHECK(args[i].type_code() == kStr) + << "key of str map need to be str"; + CHECK(args[i + 1].type_code() == kNodeHandle) + << "value of the map to be NodeRef"; + data.emplace(std::make_pair(args[i].operator std::string(), + args[i + 1].node_sptr())); + } + auto node = std::make_shared(); + node->data = std::move(data); + *ret = node; + } else { + // Container node. + MapNode::ContainerType data; + for (int i = 0; i < args.num_args; i += 2) { + CHECK(args[i].type_code() == kNodeHandle) + << "key of str map need to be str"; + CHECK(args[i + 1].type_code() == kNodeHandle) + << "value of map to be NodeRef"; + data.emplace(std::make_pair(args[i].node_sptr(), + args[i + 1].node_sptr())); + } + auto node = std::make_shared(); + node->data = std::move(data); + *ret = node; } - auto node = std::make_shared(); - node->data = std::move(data); - *ret = node; }); TVM_REGISTER_API("_MapSize") .set_body([](TVMArgs args, TVMRetValue* ret) { auto& sptr = args[0].node_sptr(); - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); - *ret = static_cast(n->data.size()); + if (sptr->is_type()) { + auto* n = static_cast(sptr.get()); + *ret = static_cast(n->data.size()); + } else { + CHECK(sptr->is_type()); + auto* n = static_cast(sptr.get()); + *ret = static_cast(n->data.size()); + } }); TVM_REGISTER_API("_MapGetItem") .set_body([](TVMArgs args, TVMRetValue* ret) { CHECK(args[0].type_code() == kNodeHandle); - CHECK(args[1].type_code() == kNodeHandle); auto& sptr = args[0].node_sptr(); - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); - auto it = n->data.find(args[1].node_sptr()); - CHECK(it != n->data.end()) - << "cannot find the corresponding key in the Map"; - *ret = (*it).second; + if (sptr->is_type()) { + CHECK(args[1].type_code() == kNodeHandle); + auto* n = static_cast(sptr.get()); + auto it = n->data.find(args[1].node_sptr()); + CHECK(it != n->data.end()) + << "cannot find the corresponding key in the Map"; + *ret = (*it).second; + } else { + CHECK(sptr->is_type()); + auto* n = static_cast(sptr.get()); + auto it = n->data.find(args[1].operator std::string()); + CHECK(it != n->data.end()) + << "cannot find the corresponding key in the Map"; + *ret = (*it).second; + } }); TVM_REGISTER_API("_MapCount") .set_body([](TVMArgs args, TVMRetValue* ret) { CHECK(args[0].type_code() == kNodeHandle); - CHECK(args[1].type_code() == kNodeHandle); auto& sptr = args[0].node_sptr(); - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); - *ret = static_cast( - n->data.count(args[1].node_sptr())); + if (sptr->is_type()) { + auto* n = static_cast(sptr.get()); + CHECK(args[1].type_code() == kNodeHandle); + *ret = static_cast( + n->data.count(args[1].node_sptr())); + } else { + CHECK(sptr->is_type()); + auto* n = static_cast(sptr.get()); + *ret = static_cast( + n->data.count(args[1].operator std::string())); + } }); TVM_REGISTER_API("_MapItems") .set_body([](TVMArgs args, TVMRetValue* ret) { auto& sptr = args[0].node_sptr(); - CHECK(sptr->is_type()); - auto* n = static_cast(sptr.get()); - auto rkvs = std::make_shared(); - for (const auto& kv : n->data) { - rkvs->data.push_back(kv.first); - rkvs->data.push_back(kv.second); + if (sptr->is_type()) { + auto* n = static_cast(sptr.get()); + auto rkvs = std::make_shared(); + for (const auto& kv : n->data) { + rkvs->data.push_back(kv.first); + rkvs->data.push_back(kv.second); + } + *ret = rkvs; + } else { + auto* n = static_cast(sptr.get()); + auto rkvs = std::make_shared(); + for (const auto& kv : n->data) { + rkvs->data.push_back(ir::StringImm::make(kv.first).node_); + rkvs->data.push_back(kv.second); + } + *ret = rkvs; } - *ret = rkvs; }); TVM_REGISTER_API("Range") diff --git a/src/lang/expr.cc b/src/lang/expr.cc index 0fb783d70cb8..684211079e94 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -74,6 +74,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(ArrayNode); TVM_REGISTER_NODE_TYPE(MapNode); +TVM_REGISTER_NODE_TYPE(StrMapNode); TVM_REGISTER_NODE_TYPE(RangeNode); TVM_REGISTER_NODE_TYPE(IterVarNode); diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index b5141262c0d3..4a0500bf4faf 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -35,6 +35,15 @@ TEST(Map, Expr) { CHECK(!dict.count(zz)); } +TEST(StrMap, Expr) { + using namespace tvm; + Var x("x"); + auto z = max(x + 1 + 2, 100); + Map dict{{"x", z}, {"z", 2}}; + CHECK(dict.size() == 2); + CHECK(dict["x"].same_as(z)); +} + TEST(Map, Mutate) { using namespace tvm; Var x("x"); diff --git a/tests/python/unittest/test_lang_container.py b/tests/python/unittest/test_lang_container.py index d945fce31fd4..615c5ac0a8d5 100644 --- a/tests/python/unittest/test_lang_container.py +++ b/tests/python/unittest/test_lang_container.py @@ -10,6 +10,7 @@ def test_array_save_load_json(): a_loaded = tvm.load_json(json_str) assert(a[1].value == 2) + def test_map(): a = tvm.var('a') b = tvm.var('b') @@ -22,6 +23,17 @@ def test_map(): assert b in dd assert a + 1 not in amap + +def test_str_map(): + amap = tvm.convert({'a': 2, 'b': 3}) + assert 'a' in amap + assert len(amap) == 2 + dd = dict(amap.items()) + assert amap['a'].value == 2 + assert 'a' in dd + assert 'b' in dd + + def test_map_save_load_json(): a = tvm.var('a') b = tvm.var('b') @@ -35,6 +47,7 @@ def test_map_save_load_json(): if __name__ == "__main__": + test_str_map() test_array() test_map() test_array_save_load_json()