Skip to content

Commit

Permalink
[CONTAINER] Introduce StrMap (apache#1292)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Jun 16, 2018
1 parent c970359 commit 146714a
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 39 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ if(GTEST_LIB)
add_executable(${__execname} ${__srcpath})
list(APPEND TEST_EXECS ${__execname})
target_link_libraries(${__execname}
tvm ${GTEST_LIB} ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS} pthread)
tvm ${GTEST_LIB} pthread)
set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_ALL 1)
set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1)
endforeach()
Expand Down
2 changes: 1 addition & 1 deletion HalideIR
19 changes: 19 additions & 0 deletions include/tvm/packed_func_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,25 @@ struct NodeTypeChecker<Array<T> > {
}
};

template<typename V>
struct NodeTypeChecker<Map<std::string, V> > {
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<StrMapNode>()) return false;
StrMapNode* n = static_cast<StrMapNode*>(sptr);
for (const auto& kv : n->data) {
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<string";
os << ',';
NodeTypeChecker<V>::PrintName(os);
os << '>';
}
};

template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > {
static inline bool Check(Node* sptr) {
Expand Down
1 change: 1 addition & 0 deletions python/tvm/_ffi/_ctypes/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def _return_node(x):
C_TO_PY_ARG_SWITCH[TypeCode.NODE_HANDLE] = _wrap_arg_func(
_return_node, TypeCode.NODE_HANDLE)


class NodeBase(object):
__slots__ = ["handle"]
# pylint: disable=no-member
Expand Down
6 changes: 5 additions & 1 deletion python/tvm/_ffi/node_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ def _set_class_node_base(cls):
global _CLASS_NODE_BASE
_CLASS_NODE_BASE = cls


class NodeGeneric(object):
"""Base class for all classes that can be converted to node."""
def asnode(self):
"""Convert value to node"""
raise NotImplementedError()


def convert_to_node(value):
"""Convert a python value to corresponding node type.
Expand Down Expand Up @@ -46,7 +48,8 @@ def convert_to_node(value):
elif isinstance(value, dict):
vlist = []
for item in value.items():
if not isinstance(item[0], _CLASS_NODE_BASE):
if (not isinstance(item[0], _CLASS_NODE_BASE) and
not isinstance(item[0], string_types)):
raise ValueError("key of map must already been a container type")
vlist.append(item[0])
vlist.append(convert_to_node(item[1]))
Expand All @@ -56,6 +59,7 @@ def convert_to_node(value):
else:
raise ValueError("don't know how to convert type %s to node" % type(value))


def const(value, dtype=None):
"""Construct a constant value for a given type.
Expand Down
17 changes: 14 additions & 3 deletions python/tvm/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@ class Map(NodeBase):
"""Map container of TVM.
You do not need to create Map explicitly.
Normally python dict will be converted automatically
to Array during tvm function call.
You may get Map in return values of TVM function call.
Normally python dict will be converted automaticall to Map during tvm function call.
You can use convert to create a dict[NodeBase-> NodeBase] into a Map
"""
def __getitem__(self, k):
return _api_internal._MapGetItem(self, k)
Expand All @@ -51,6 +50,18 @@ def __len__(self):
return _api_internal._MapSize(self)


@register_node
class StrMap(Map):
"""A special map container that has str as key.
You can use convert to create a dict[str->NodeBase] into a Map.
"""
def items(self):
"""Get the items from the map"""
akvs = _api_internal._MapItems(self)
return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)]


@register_node
class Range(NodeBase):
"""Represent range in TVM.
Expand Down
111 changes: 78 additions & 33 deletions src/api/api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,63 +76,108 @@ TVM_REGISTER_API("_ArraySize")
TVM_REGISTER_API("_Map")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size() % 2, 0);
MapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kNodeHandle)
<< "need content of array to be NodeBase";
CHECK(args[i + 1].type_code() == kNodeHandle)
<< "need content of array to be NodeBase";
data.emplace(std::make_pair(args[i].node_sptr(),
args[i + 1].node_sptr()));
if (args.size() != 0 && args[0].type_code() == kStr) {
// StrMap
StrMapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kStr)
<< "key of str map need to be str";
CHECK(args[i + 1].type_code() == kNodeHandle)
<< "value of the map to be NodeRef";
data.emplace(std::make_pair(args[i].operator std::string(),
args[i + 1].node_sptr()));
}
auto node = std::make_shared<StrMapNode>();
node->data = std::move(data);
*ret = node;
} else {
// Container node.
MapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kNodeHandle)
<< "key of str map need to be str";
CHECK(args[i + 1].type_code() == kNodeHandle)
<< "value of map to be NodeRef";
data.emplace(std::make_pair(args[i].node_sptr(),
args[i + 1].node_sptr()));
}
auto node = std::make_shared<MapNode>();
node->data = std::move(data);
*ret = node;
}
auto node = std::make_shared<MapNode>();
node->data = std::move(data);
*ret = node;
});

TVM_REGISTER_API("_MapSize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
*ret = static_cast<int64_t>(n->data.size());
if (sptr->is_type<MapNode>()) {
auto* n = static_cast<const MapNode*>(sptr.get());
*ret = static_cast<int64_t>(n->data.size());
} else {
CHECK(sptr->is_type<StrMapNode>());
auto* n = static_cast<const StrMapNode*>(sptr.get());
*ret = static_cast<int64_t>(n->data.size());
}
});

TVM_REGISTER_API("_MapGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK(args[0].type_code() == kNodeHandle);
CHECK(args[1].type_code() == kNodeHandle);
auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
auto it = n->data.find(args[1].node_sptr());
CHECK(it != n->data.end())
<< "cannot find the corresponding key in the Map";
*ret = (*it).second;
if (sptr->is_type<MapNode>()) {
CHECK(args[1].type_code() == kNodeHandle);
auto* n = static_cast<const MapNode*>(sptr.get());
auto it = n->data.find(args[1].node_sptr());
CHECK(it != n->data.end())
<< "cannot find the corresponding key in the Map";
*ret = (*it).second;
} else {
CHECK(sptr->is_type<StrMapNode>());
auto* n = static_cast<const StrMapNode*>(sptr.get());
auto it = n->data.find(args[1].operator std::string());
CHECK(it != n->data.end())
<< "cannot find the corresponding key in the Map";
*ret = (*it).second;
}
});

TVM_REGISTER_API("_MapCount")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK(args[0].type_code() == kNodeHandle);
CHECK(args[1].type_code() == kNodeHandle);
auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
*ret = static_cast<int64_t>(
n->data.count(args[1].node_sptr()));
if (sptr->is_type<MapNode>()) {
auto* n = static_cast<const MapNode*>(sptr.get());
CHECK(args[1].type_code() == kNodeHandle);
*ret = static_cast<int64_t>(
n->data.count(args[1].node_sptr()));
} else {
CHECK(sptr->is_type<StrMapNode>());
auto* n = static_cast<const StrMapNode*>(sptr.get());
*ret = static_cast<int64_t>(
n->data.count(args[1].operator std::string()));
}
});

TVM_REGISTER_API("_MapItems")
.set_body([](TVMArgs args, TVMRetValue* ret) {
auto& sptr = args[0].node_sptr();
CHECK(sptr->is_type<MapNode>());
auto* n = static_cast<const MapNode*>(sptr.get());
auto rkvs = std::make_shared<ArrayNode>();
for (const auto& kv : n->data) {
rkvs->data.push_back(kv.first);
rkvs->data.push_back(kv.second);
if (sptr->is_type<MapNode>()) {
auto* n = static_cast<const MapNode*>(sptr.get());
auto rkvs = std::make_shared<ArrayNode>();
for (const auto& kv : n->data) {
rkvs->data.push_back(kv.first);
rkvs->data.push_back(kv.second);
}
*ret = rkvs;
} else {
auto* n = static_cast<const StrMapNode*>(sptr.get());
auto rkvs = std::make_shared<ArrayNode>();
for (const auto& kv : n->data) {
rkvs->data.push_back(ir::StringImm::make(kv.first).node_);
rkvs->data.push_back(kv.second);
}
*ret = rkvs;
}
*ret = rkvs;
});

TVM_REGISTER_API("Range")
Expand Down
1 change: 1 addition & 0 deletions src/lang/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)

TVM_REGISTER_NODE_TYPE(ArrayNode);
TVM_REGISTER_NODE_TYPE(MapNode);
TVM_REGISTER_NODE_TYPE(StrMapNode);
TVM_REGISTER_NODE_TYPE(RangeNode);
TVM_REGISTER_NODE_TYPE(IterVarNode);

Expand Down
9 changes: 9 additions & 0 deletions tests/cpp/container_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ TEST(Map, Expr) {
CHECK(!dict.count(zz));
}

TEST(StrMap, Expr) {
using namespace tvm;
Var x("x");
auto z = max(x + 1 + 2, 100);
Map<std::string, Expr> dict{{"x", z}, {"z", 2}};
CHECK(dict.size() == 2);
CHECK(dict["x"].same_as(z));
}

TEST(Map, Mutate) {
using namespace tvm;
Var x("x");
Expand Down
13 changes: 13 additions & 0 deletions tests/python/unittest/test_lang_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def test_array_save_load_json():
a_loaded = tvm.load_json(json_str)
assert(a[1].value == 2)


def test_map():
a = tvm.var('a')
b = tvm.var('b')
Expand All @@ -22,6 +23,17 @@ def test_map():
assert b in dd
assert a + 1 not in amap


def test_str_map():
amap = tvm.convert({'a': 2, 'b': 3})
assert 'a' in amap
assert len(amap) == 2
dd = dict(amap.items())
assert amap['a'].value == 2
assert 'a' in dd
assert 'b' in dd


def test_map_save_load_json():
a = tvm.var('a')
b = tvm.var('b')
Expand All @@ -35,6 +47,7 @@ def test_map_save_load_json():


if __name__ == "__main__":
test_str_map()
test_array()
test_map()
test_array_save_load_json()
Expand Down

0 comments on commit 146714a

Please sign in to comment.