From 06621d425d21f2f37160b968d0b9f24dbab281c5 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 11 Jan 2017 21:41:09 -0800 Subject: [PATCH] [LANG] Enable json load/save and pickle --- include/tvm/base.h | 40 +++- include/tvm/c_api.h | 9 +- python/tvm/_ctypes/_api.py | 24 ++- python/tvm/function.py | 32 +++ src/base/common.h | 42 ++++ src/base/saveload_json.cc | 306 ++++++++++++++++++++++++++++ src/c_api/c_api_function.cc | 12 ++ src/c_api/c_api_registry.h | 28 +-- src/lang/expr.cc | 11 +- src/schedule/schedule_lang.cc | 1 + tests/cpp/expr_test.cc | 3 + tests/python/test_lang_basic.py | 11 + tests/python/test_lang_container.py | 20 ++ tests/python/test_lang_schedule.py | 13 ++ tests/python/test_lang_tensor.py | 6 +- 15 files changed, 521 insertions(+), 37 deletions(-) create mode 100644 src/base/common.h create mode 100644 src/base/saveload_json.cc diff --git a/include/tvm/base.h b/include/tvm/base.h index 7c28f39054cd..47ed2861a6aa 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -21,6 +21,41 @@ using ::tvm::Node; using ::tvm::NodeRef; using ::tvm::AttrVisitor; +/*! + * \brief save the node as well as all the node it depends on as json. + * This can be used to serialize any TVM object + * + * \return the string representation of the node. + */ +std::string SaveJSON(const NodeRef& node); + +/*! + * \brief Internal implementation of LoadJSON + * Load tvm Node object from json and return a shared_ptr of Node. + * \param json_str The json string to load from. + * + * \return The shared_ptr of the Node. + */ +std::shared_ptr LoadJSON_(std::string json_str); + +/*! + * \brief Load the node from json string. + * This can be used to deserialize any TVM object. + * + * \param json_str The json string to load from. + * + * \tparam NodeType the nodetype + * + * \code + * Expr e = LoadJSON(json_str); + * \endcode + */ +template::value>::type > +inline NodeType LoadJSON(const std::string& json_str) { + return NodeType(LoadJSON_(json_str)); +} + /*! \brief typedef the factory function of data iterator */ using NodeFactory = std::function ()>; /*! @@ -32,8 +67,9 @@ struct NodeFactoryReg }; #define TVM_REGISTER_NODE_TYPE(TypeName) \ - DMLC_REGISTRY_REGISTER(::tvm::NodeFactoryReg, NodeFactoryReg, TypeName) \ - .set_body([]() { return std::make_shared(); }) + static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \ + ::dmlc::Registry<::tvm::NodeFactoryReg>::Get()->__REGISTER__(TypeName::_type_key) \ + .set_body([]() { return std::make_shared(); }) } // namespace tvm #endif // TVM_BASE_H_ diff --git a/include/tvm/c_api.h b/include/tvm/c_api.h index 1f06b40fb56b..5f788508764e 100644 --- a/include/tvm/c_api.h +++ b/include/tvm/c_api.h @@ -15,14 +15,15 @@ /*! \brief TVM_DLL prefix for windows */ #ifdef _WIN32 #ifdef TVM_EXPORTS -#define TVM_DLL TVM_EXTERN_C __declspec(dllexport) +#define TVM_DLL __declspec(dllexport) #else -#define TVM_DLL TVM_EXTERN_C __declspec(dllimport) +#define TVM_DLL __declspec(dllimport) #endif #else -#define TVM_DLL TVM_EXTERN_C +#define TVM_DLL #endif +TVM_EXTERN_C { /*! \brief handle to functions */ typedef void* FunctionHandle; /*! \brief handle to node */ @@ -147,5 +148,5 @@ TVM_DLL int TVMNodeGetAttr(NodeHandle handle, TVM_DLL int TVMNodeListAttrNames(NodeHandle handle, int *out_size, const char*** out_array); - +} // TVM_EXTERN_C #endif // TVM_C_API_H_ diff --git a/python/tvm/_ctypes/_api.py b/python/tvm/_ctypes/_api.py index 41014e867da4..c9d980928967 100644 --- a/python/tvm/_ctypes/_api.py +++ b/python/tvm/_ctypes/_api.py @@ -89,7 +89,6 @@ def __getattr__(self, name): "'%s' object has no attribute '%s'" % (str(type(self)), name)) return value - def __hash__(self): return _function_internal._raw_ptr(self) @@ -111,6 +110,29 @@ def __dir__(self): names.append(py_str(plist[i])) return names + def __reduce__(self): + return (type(self), (None,), self.__getstate__()) + + def __getstate__(self): + handle = self.handle + if handle is not None: + return {'handle': _function_internal._save_json(self)} + else: + return {'handle': None} + + def __setstate__(self, state): + # pylint: disable=assigning-non-slot + handle = state['handle'] + if handle is not None: + json_str = handle + _push_arg(json_str) + other = _function_internal._load_json(json_str) + self.handle = other.handle + other.handle = None + else: + self.handle = None + + def const(value, dtype=None): """construct a constant""" if dtype is None: diff --git a/python/tvm/function.py b/python/tvm/function.py index 22bfb3555fa1..43e688276362 100644 --- a/python/tvm/function.py +++ b/python/tvm/function.py @@ -19,6 +19,38 @@ def const(value, dtype=None): return _function_internal._const(value, dtype) +def load_json(json_str): + """Load tvm object from json_str. + + Parameters + ---------- + json_str : str + The json string + + Returns + ------- + node : Node + The loaded tvm node. + """ + return _function_internal._load_json(json_str) + + +def save_json(node): + """Load tvm object as json string. + + Parameters + ---------- + node : Node + A TVM Node object to be saved. + + Returns + ------- + json_str : str + Saved json string. + """ + return _function_internal._save_json(node) + + def Var(name="tindex", dtype=int32): """Create a new variable with specified name and dtype diff --git a/src/base/common.h b/src/base/common.h new file mode 100644 index 000000000000..66ffffef5cc2 --- /dev/null +++ b/src/base/common.h @@ -0,0 +1,42 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file common.h + * \brief Common utilities + */ +#ifndef TVM_BASE_COMMON_H_ +#define TVM_BASE_COMMON_H_ + +#include +#include + +namespace tvm { + +inline std::string Type2String(const Type& t) { + std::ostringstream os; + os << t; + return os.str(); +} + +inline Type String2Type(std::string s) { + std::istringstream is(s); + halide_type_code_t code = Type::Int; + if (s.substr(0, 3) == "int") { + code = Type::Int; s = s.substr(3); + } else if (s.substr(0, 4) == "uint") { + code = Type::UInt; s = s.substr(4); + } else if (s.substr(0, 5) == "float") { + code = Type::Float; s = s.substr(5); + } else if (s.substr(0, 5) == "float") { + code = Type::Float; s = s.substr(5); + } else { + LOG(FATAL) << "unknown type " << s; + } + int bits = 32, lanes = 1; + if (sscanf(s.c_str(), "%dx%d", &bits, &lanes) == 0) { + LOG(FATAL) << "unknown type " << s; + } + return Type(code, bits, lanes); +} + +} // namespace tvm +#endif // TVM_BASE_COMMON_H_ diff --git a/src/base/saveload_json.cc b/src/base/saveload_json.cc new file mode 100644 index 000000000000..6a877caf4678 --- /dev/null +++ b/src/base/saveload_json.cc @@ -0,0 +1,306 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file saveload_json.cc + * \brief Utilities to save/load TVM objects. + */ +#include +#include +#include +#include +#include "./common.h" + +namespace tvm { + +// indexer to index all the ndoes +class NodeIndexer : public AttrVisitor { + public: + std::unordered_map node_index{{nullptr, 0}}; + std::vector node_list{nullptr}; + + void Visit(const char* key, double* value) final {} + void Visit(const char* key, int64_t* value) final {} + void Visit(const char* key, uint64_t* value) final {} + void Visit(const char* key, int* value) final {} + void Visit(const char* key, bool* value) final {} + void Visit(const char* key, std::string* value) final {} + void Visit(const char* key, Type* value) final {} + void Visit(const char* key, NodeRef* value) final { + MakeIndex(value->node_.get()); + } + + // make index of all the children of node + void MakeIndex(Node* node) { + if (node == nullptr) return; + if (node_index.count(node)) return; + CHECK_EQ(node_index.size(), node_list.size()); + node_index[node] = node_list.size(); + node_list.push_back(node); + + if (node->is_type()) { + ArrayNode* n = static_cast(node); + for (const auto& sp : n->data) { + MakeIndex(sp.get()); + } + } else if (node->is_type()) { + MapNode* n = static_cast(node); + for (const auto& kv : n->data) { + MakeIndex(kv.first.get()); + MakeIndex(kv.second.get()); + } + } else { + node->VisitAttrs(this); + } + } +}; + +// use map so attributes are ordered. +using AttrMap = std::map; + +// A Node structure for JSON node. +struct JSONNode { + // The type key of the data + std::string type_key; + // the attributes + AttrMap attrs; + // container data + std::vector data; + + void Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("type_key", type_key); + if (attrs.size() != 0) { + writer->WriteObjectKeyValue("attrs", attrs); + } + if (data.size() != 0) { + writer->WriteObjectKeyValue("data", data); + } + writer->EndObject(); + } + + void Load(dmlc::JSONReader *reader) { + attrs.clear(); + data.clear(); + type_key.clear(); + dmlc::JSONObjectReadHelper helper; + helper.DeclareOptionalField("type_key", &type_key); + helper.DeclareOptionalField("attrs", &attrs); + helper.DeclareOptionalField("data", &data); + helper.ReadAllFields(reader); + } +}; + +class JSONAttrGetter : public AttrVisitor { + public: + const std::unordered_map* node_index_; + JSONNode* node_; + + void Visit(const char* key, double* value) final { + node_->attrs[key] = std::to_string(*value); + } + void Visit(const char* key, int64_t* value) final { + node_->attrs[key] = std::to_string(*value); + } + void Visit(const char* key, uint64_t* value) final { + node_->attrs[key] = std::to_string(*value); + } + void Visit(const char* key, int* value) final { + node_->attrs[key] = std::to_string(*value); + } + void Visit(const char* key, bool* value) final { + node_->attrs[key] = std::to_string(*value); + } + void Visit(const char* key, std::string* value) final { + node_->attrs[key] = *value; + } + void Visit(const char* key, Type* value) final { + node_->attrs[key] = Type2String(*value); + } + void Visit(const char* key, NodeRef* value) final { + node_->attrs[key] = std::to_string( + node_index_->at(value->node_.get())); + } + // Get the node + void Get(Node* node) { + if (node == nullptr) { + node_->type_key.clear(); + return; + } + node_->type_key = node->type_key(); + node_->attrs.clear(); + node_->data.clear(); + if (node->is_type()) { + ArrayNode* n = static_cast(node); + for (size_t i = 0; i < n->data.size(); ++i) { + node_->data.push_back( + node_index_->at(n->data[i].get())); + } + } else if (node->is_type()) { + MapNode* n = static_cast(node); + std::vector > elems; + for (const auto& kv : n->data) { + node_->data.push_back( + node_index_->at(kv.first.get())); + node_->data.push_back( + node_index_->at(kv.second.get())); + } + } else { + node->VisitAttrs(this); + } + } +}; + +class JSONAttrSetter : public AttrVisitor { + public: + const std::vector >* node_list_; + JSONNode* node_; + + std::string GetValue(const char* key) const { + auto it = node_->attrs.find(key); + if (it == node_->attrs.end()) { + LOG(FATAL) << "JSONReader: cannot find field " << key; + } + return it->second; + } + template + void ParseValue(const char* key, T* value) const { + std::istringstream is(GetValue(key)); + is >> *value; + if (is.fail()) { + LOG(FATAL) << "Wrong value format for field " << key; + } + } + void Visit(const char* key, double* value) final { + ParseValue(key, value); + } + void Visit(const char* key, int64_t* value) final { + ParseValue(key, value); + } + void Visit(const char* key, uint64_t* value) final { + ParseValue(key, value); + } + void Visit(const char* key, int* value) final { + ParseValue(key, value); + } + void Visit(const char* key, bool* value) final { + ParseValue(key, value); + } + void Visit(const char* key, std::string* value) final { + *value = GetValue(key); + } + void Visit(const char* key, Type* value) final { + std::string stype = GetValue(key); + *value = String2Type(stype); + } + void Visit(const char* key, NodeRef* value) final { + size_t index; + ParseValue(key, &index); + value->node_ = node_list_->at(index); + } + + // Get the node + void Set(Node* node) { + if (node == nullptr) return; + if (node->is_type()) { + ArrayNode* n = static_cast(node); + n->data.clear(); + for (size_t index : node_->data) { + n->data.push_back(node_list_->at(index)); + } + } else if (node->is_type()) { + MapNode* n = static_cast(node); + CHECK_EQ(node_->data.size() % 2, 0U); + for (size_t i = 0; i < node_->data.size(); i += 2) { + n->data[node_list_->at(node_->data[i])] + = node_list_->at(node_->data[i + 1]); + } + } else { + node->VisitAttrs(this); + } + } +}; + +// json graph structure to store node +struct JSONGraph { + // the root of the graph + size_t root; + // the nodes of the graph + std::vector nodes; + // global attributes + AttrMap attrs; + + void Save(dmlc::JSONWriter *writer) const { + writer->BeginObject(); + writer->WriteObjectKeyValue("root", root); + writer->WriteObjectKeyValue("nodes", nodes); + if (attrs.size() != 0) { + writer->WriteObjectKeyValue("attrs", attrs); + } + writer->EndObject(); + } + + void Load(dmlc::JSONReader *reader) { + attrs.clear(); + dmlc::JSONObjectReadHelper helper; + helper.DeclareField("root", &root); + helper.DeclareField("nodes", &nodes); + helper.DeclareOptionalField("attrs", &attrs); + helper.ReadAllFields(reader); + } + + static JSONGraph Create(const NodeRef& root) { + JSONGraph g; + NodeIndexer indexer; + indexer.MakeIndex(root.node_.get()); + JSONAttrGetter getter; + getter.node_index_ = &indexer.node_index; + for (Node* n : indexer.node_list) { + JSONNode jnode; + getter.node_ = &jnode; + getter.Get(n); + g.nodes.emplace_back(std::move(jnode)); + } + g.attrs["tvm_version"] = "0.1.0"; + g.root = indexer.node_index.at(root.node_.get()); + return g; + } +}; + +std::string SaveJSON(const NodeRef& n) { + auto jgraph = JSONGraph::Create(n); + std::ostringstream os; + dmlc::JSONWriter writer(&os); + jgraph.Save(&writer); + return os.str(); +} + +std::shared_ptr LoadJSON_(std::string json_str) { + std::istringstream is(json_str); + dmlc::JSONReader reader(&is); + JSONGraph jgraph; + // load in json graph. + jgraph.Load(&reader); + std::vector > nodes; + // node 0 is always null + nodes.reserve(jgraph.nodes.size()); + for (const JSONNode& jnode : jgraph.nodes) { + if (jnode.type_key.length() != 0) { + auto* f = dmlc::Registry::Find(jnode.type_key); + CHECK(f != nullptr) + << "Node type \'" << jnode.type_key << "\' is not registered in TVM"; + nodes.emplace_back(f->body()); + } else { + nodes.emplace_back(std::shared_ptr()); + } + } + CHECK_EQ(nodes.size(), jgraph.nodes.size()); + JSONAttrSetter setter; + setter.node_list_ = &nodes; + + for (size_t i = 0; i < nodes.size(); ++i) { + setter.node_ = &jgraph.nodes[i]; + setter.Set(nodes[i].get()); + } + return nodes.at(jgraph.root); +} + +} // namespace tvm diff --git a/src/c_api/c_api_function.cc b/src/c_api/c_api_function.cc index 9327e61e6901..4cdaa5358e1e 100644 --- a/src/c_api/c_api_function.cc +++ b/src/c_api/c_api_function.cc @@ -34,4 +34,16 @@ TVM_REGISTER_API(_raw_ptr) }) .add_argument("src", "NodeBase", "the node base"); +TVM_REGISTER_API(_save_json) +.set_body([](const ArgStack& args, RetValue *ret) { + *ret = SaveJSON(args.at(0)); + }) +.add_argument("src", "json_str", "the node "); + +TVM_REGISTER_API(_load_json) +.set_body([](const ArgStack& args, RetValue *ret) { + *ret = NodeRef(LoadJSON_(args.at(0))); + }) +.add_argument("src", "NodeBase", "the node"); + } // namespace tvm diff --git a/src/c_api/c_api_registry.h b/src/c_api/c_api_registry.h index 885c45b14432..0baa1fbd9a30 100644 --- a/src/c_api/c_api_registry.h +++ b/src/c_api/c_api_registry.h @@ -13,36 +13,10 @@ #include #include #include +#include "../base/common.h" namespace tvm { -inline std::string Type2String(const Type& t) { - std::ostringstream os; - os << t; - return os.str(); -} - -inline Type String2Type(std::string s) { - std::istringstream is(s); - halide_type_code_t code = Type::Int; - if (s.substr(0, 3) == "int") { - code = Type::Int; s = s.substr(3); - } else if (s.substr(0, 4) == "uint") { - code = Type::UInt; s = s.substr(4); - } else if (s.substr(0, 5) == "float") { - code = Type::Float; s = s.substr(5); - } else if (s.substr(0, 5) == "float") { - code = Type::Float; s = s.substr(5); - } else { - LOG(FATAL) << "unknown type " << s; - } - int bits = 32, lanes = 1; - if (sscanf(s.c_str(), "%dx%d", &bits, &lanes) == 0) { - LOG(FATAL) << "unknown type " << s; - } - return Type(code, bits, lanes); -} - inline const char* TypeId2Str(ArgVariantID type_id) { switch (type_id) { case kNull: return "Null"; diff --git a/src/lang/expr.cc b/src/lang/expr.cc index c6ec66fcaa24..24e91c6384c9 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -13,8 +13,11 @@ DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg); } // namespace dmlc namespace tvm { + +using Halide::IR::RangeNode; + Range::Range(Expr begin, Expr end) - : Range(std::make_shared( + : Range(std::make_shared( begin, is_zero(begin) ? end : (end - begin))) { } @@ -67,10 +70,14 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const Halide::IR::RangeNode *op, IRPrinter *p) { +.set_dispatch([](const Halide::IR::RangeNode *op, IRPrinter *p) { p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; }); + +TVM_REGISTER_NODE_TYPE(ArrayNode); +TVM_REGISTER_NODE_TYPE(MapNode); +TVM_REGISTER_NODE_TYPE(RangeNode); TVM_REGISTER_NODE_TYPE(IterVarNode); } // namespace tvm diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc index 22b3d10d8b9b..1f266126efed 100644 --- a/src/schedule/schedule_lang.cc +++ b/src/schedule/schedule_lang.cc @@ -206,5 +206,6 @@ IterVarRelation FuseNode::make( TVM_REGISTER_NODE_TYPE(StageNode); TVM_REGISTER_NODE_TYPE(SplitNode); TVM_REGISTER_NODE_TYPE(FuseNode); +TVM_REGISTER_NODE_TYPE(ScheduleNode); } // namespace tvm diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc index f64fc8651fb1..fb8685695013 100644 --- a/tests/cpp/expr_test.cc +++ b/tests/cpp/expr_test.cc @@ -6,8 +6,11 @@ TEST(Expr, Basic) { using namespace tvm; Var x("x"); auto z = max(x + 1 + 2, 100); + NodeRef tmp = z; + Expr zz(tmp.node_); std::ostringstream os; os << z; + CHECK(zz.same_as(z)); CHECK(os.str() == "max(((x + 1) + 2), 100)"); } diff --git a/tests/python/test_lang_basic.py b/tests/python/test_lang_basic.py index 3368a74e77fb..b1eb2990a6d0 100644 --- a/tests/python/test_lang_basic.py +++ b/tests/python/test_lang_basic.py @@ -5,6 +5,16 @@ def test_const(): assert x.dtype == 'int32' assert isinstance(x, tvm.expr.IntImm) +def test_const_saveload_json(): + # save load json + x = tvm.const(1) + y = tvm.const(10) + z = x + y + z = z + z + json_str = tvm.save_json(z) + zz = tvm.load_json(json_str) + assert tvm.save_json(zz) == tvm.save_json(z) + def test_make(): x = tvm.const(1) y = tvm.make.IntImm('int32', 1) @@ -57,6 +67,7 @@ def test_stmt(): if __name__ == "__main__": test_attr() test_const() + test_const_saveload_json() test_make() test_ir() test_basic() diff --git a/tests/python/test_lang_container.py b/tests/python/test_lang_container.py index 3333f60b41af..d8e12badb88c 100644 --- a/tests/python/test_lang_container.py +++ b/tests/python/test_lang_container.py @@ -4,6 +4,12 @@ def test_array(): a = tvm.convert([1,2,3]) assert len(a) == 3 +def test_array_save_load_json(): + a = tvm.convert([1,2,3]) + json_str = tvm.save_json(a) + a_loaded = tvm.load_json(json_str) + assert(a[1].value == 2) + def test_map(): a = tvm.Var('a') b = tvm.Var('b') @@ -15,6 +21,20 @@ def test_map(): assert str(dd) == str(amap) assert a + 1 not in amap +def test_map_save_load_json(): + a = tvm.Var('a') + b = tvm.Var('b') + amap = tvm.convert({a: 2, + b: 3}) + json_str = tvm.save_json(amap) + amap = tvm.load_json(json_str) + assert len(amap) == 2 + dd = {kv[0].name : kv[1].value for kv in amap.items()} + assert(dd == {"a": 2, "b": 3}) + + if __name__ == "__main__": test_array() test_map() + test_array_save_load_json() + test_map_save_load_json() diff --git a/tests/python/test_lang_schedule.py b/tests/python/test_lang_schedule.py index 7e3c2f3ce64d..fcb573dab4c3 100644 --- a/tests/python/test_lang_schedule.py +++ b/tests/python/test_lang_schedule.py @@ -1,4 +1,5 @@ import tvm +import pickle as pkl def test_schedule_create(): m = tvm.Var('m') @@ -17,6 +18,18 @@ def test_schedule_create(): s[T].reorder(xi2, xi1) assert T.op.axis[1] in s[T].leaf_iter_vars + # save load json + json_str = tvm.save_json(s) + s_loaded = tvm.load_json(json_str) + assert isinstance(s_loaded, tvm.schedule.Schedule) + assert(str(s_loaded.roots[0].body) == str(s.roots[0].body)) + + # pickle unpickle + dump = pkl.dumps(s) + s_loaded = pkl.loads(dump) + assert isinstance(s_loaded, tvm.schedule.Schedule) + assert(str(s_loaded.roots[0].body) == str(s.roots[0].body)) + def test_reorder(): m = tvm.Var('m') A = tvm.placeholder((m,), name='A') diff --git a/tests/python/test_lang_tensor.py b/tests/python/test_lang_tensor.py index ca695813d7a2..af0632866404 100644 --- a/tests/python/test_lang_tensor.py +++ b/tests/python/test_lang_tensor.py @@ -27,7 +27,11 @@ def test_tensor_reduce(): T = tvm.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k]) rv = tvm.IterVar((0, A.shape[1]), name="k") C = tvm.compute((m, n), lambda i, j: tvm.sum(T(i, j, rv+1), rdom=rv)) - print(C.op.body) + # json load save + C_json = tvm.save_json(C) + C_loaded = tvm.load_json(C_json) + assert(isinstance(C_loaded, tvm.tensor.Tensor)) + assert(str(C_loaded) == str(C)) if __name__ == "__main__": test_tensor()