From d59b345dacb8024217def4679c75b794b261ecd6 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 2 Oct 2020 04:27:02 +0000 Subject: [PATCH 1/7] initial commit --- .../extensions/lib_subgraph/subgraph_lib.cc | 10 +++++++++- include/mxnet/lib_api.h | 8 +++++--- src/lib_api.cc | 19 ++++++++++--------- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/example/extensions/lib_subgraph/subgraph_lib.cc b/example/extensions/lib_subgraph/subgraph_lib.cc index df071b96798d..7bdc036445da 100644 --- a/example/extensions/lib_subgraph/subgraph_lib.cc +++ b/example/extensions/lib_subgraph/subgraph_lib.cc @@ -209,11 +209,16 @@ MXReturnValue mySupportedOps(const mxnet::ext::Graph* graph, } MXReturnValue myReviewSubgraph(const mxnet::ext::Graph *subgraph, int subgraph_id, bool* accept, - const std::unordered_map& options) { + const std::unordered_map& options, + std::unordered_map* attrs) { for (auto kv : options) { std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl; } + std::string sg = subgraph->toString(); + std::cout << "subgraph " << subgraph_id << ": " << std::endl; + std::cout << sg << std::endl; + // check if option `reject` was specified, and if so check if value is 'True' if(options.count("reject") > 0 && options.at("reject").compare("True") == 0) { // if specified, reject the subgraph. this is only used for testing @@ -223,6 +228,9 @@ MXReturnValue myReviewSubgraph(const mxnet::ext::Graph *subgraph, int subgraph_i *accept = true; std::cout << "accepting subgraph" << std::endl; } + + attrs->emplace("myKey","myVal"); + return MX_SUCCESS; } diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h index 0213557fdc92..db93dbe6ff41 100644 --- a/include/mxnet/lib_api.h +++ b/include/mxnet/lib_api.h @@ -594,10 +594,10 @@ class Graph { static Graph* fromJson(JsonVal val); /* \brief convert graph object back to JSON object */ - JsonVal toJson(); + JsonVal toJson() const; /* \brief convert graph object to JSON string */ - std::string toString(); + std::string toString() const; /* \brief visits a node "n" */ void _dfs_util(Node* n, std::unordered_set* to_visit, @@ -819,7 +819,9 @@ typedef MXReturnValue (*createSelector_t)(const mxnet::ext::Graph *graph, typedef MXReturnValue (*reviewSubgraph_t)(const mxnet::ext::Graph *subgraph, int subgraph_id, bool* accept, const std::unordered_map& options); + std::string>& options, + std::unordered_map* attrs); /*! * \brief An abstract class for subgraph property diff --git a/src/lib_api.cc b/src/lib_api.cc index 20ae280acf6c..aee2bad3df47 100644 --- a/src/lib_api.cc +++ b/src/lib_api.cc @@ -561,7 +561,7 @@ mxnet::ext::Graph* mxnet::ext::Graph::fromJson(mxnet::ext::JsonVal val) { } /* \brief convert graph object back to JSON object */ -mxnet::ext::JsonVal mxnet::ext::Graph::toJson() { +mxnet::ext::JsonVal mxnet::ext::Graph::toJson() const { // top level object is a map JsonVal val(MAP); @@ -646,7 +646,7 @@ mxnet::ext::JsonVal mxnet::ext::Graph::toJson() { } /* \brief convert graph object to JSON string */ -std::string mxnet::ext::Graph::toString() { +std::string mxnet::ext::Graph::toString() const { return toJson().dump(); } @@ -1494,26 +1494,27 @@ MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph, } subgraph->_setParams(&args, &aux); + + std::unordered_map attrs; mxnet::ext::MXReturnValue retval = reviewSubgraph(subgraph, subgraph_id, &accept_bool, - opts); + opts, &attrs); if (!retval) return retval; *accept = accept_bool; - if (subgraph->attrs.size() > 0) { - *num_attrs = subgraph->attrs.size(); + if (attrs.size() > 0) { + *num_attrs = attrs.size(); // allocate space for attributes *attr_keys = static_cast(malloc (*num_attrs * sizeof(char*))); // NOLINT *attr_vals = static_cast(malloc (*num_attrs * sizeof(char*))); // NOLINT // copy attributes int i = 0; - for (auto kv : subgraph->attrs) { + for (auto kv : attrs) { (*attr_keys)[i] = static_cast(malloc ((kv.first.size()+1) * sizeof(char))); // NOLINT - std::string val = kv.second.dump(); // convert JsonVal back to string - (*attr_vals)[i] = static_cast(malloc ((val.size()+1) * sizeof(char))); // NOLINT + (*attr_vals)[i] = static_cast(malloc ((kv.second.size()+1) * sizeof(char))); // NOLINT snprintf((*attr_keys)[i], kv.first.size()+1, "%s", kv.first.c_str()); - snprintf((*attr_vals)[i], val.size()+1, "%s", val.c_str()); + snprintf((*attr_vals)[i], kv.second.size()+1, "%s", kv.second.c_str()); i++; } } From c197365e69e1b83feb8f5a5fd125fa50473abd35 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 2 Oct 2020 22:35:04 +0000 Subject: [PATCH 2/7] fixed mapping from top level param names to subgraph input names --- src/lib_api.cc | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/lib_api.cc b/src/lib_api.cc index aee2bad3df47..d4ab1824970c 100644 --- a/src/lib_api.cc +++ b/src/lib_api.cc @@ -766,10 +766,14 @@ void mxnet::ext::Graph::_setParams(std::unordered_map* aux) { // set params for each input node for (Node* node : inputs) { - if (args->count(node->name) > 0) - node->tensor = &args->at(node->name); - else if (aux->count(node->name) > 0) - node->tensor = &aux->at(node->name); + std::string name = node->name; + if(node->attrs.count("isArg") > 0 && node->attrs["isArg"].compare("True") == 0) + // mapping name back to original node name from subgraph input name + name = node->attrs["argName"]; + if (args->count(name) > 0) + node->tensor = &args->at(name); + else if (aux->count(name) > 0) + node->tensor = &aux->at(name); } } From 03818a59002528b22dc4c858188ed7f1051b5f5a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 2 Oct 2020 23:02:32 +0000 Subject: [PATCH 3/7] fixed sanity --- src/lib_api.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib_api.cc b/src/lib_api.cc index d4ab1824970c..cc8c6d9d97d7 100644 --- a/src/lib_api.cc +++ b/src/lib_api.cc @@ -767,7 +767,7 @@ void mxnet::ext::Graph::_setParams(std::unordered_mapname; - if(node->attrs.count("isArg") > 0 && node->attrs["isArg"].compare("True") == 0) + if (node->attrs.count("isArg") > 0 && node->attrs["isArg"].compare("True") == 0) // mapping name back to original node name from subgraph input name name = node->attrs["argName"]; if (args->count(name) > 0) From baa24f996495b0ed3b1d4547106e93868e1d466f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 3 Oct 2020 08:36:41 +0000 Subject: [PATCH 4/7] support escape characters when parsing strings --- src/lib_api.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lib_api.cc b/src/lib_api.cc index cc8c6d9d97d7..f30810a7bf81 100644 --- a/src/lib_api.cc +++ b/src/lib_api.cc @@ -348,7 +348,8 @@ mxnet::ext::JsonVal mxnet::ext::JsonVal::parse(const std::string& json) { mxnet::ext::JsonVal mxnet::ext::JsonVal::parse_string(const std::string& json, unsigned int* idx) { JsonVal ret(STR); while (*idx < json.size()) { - if (json[*idx] == '"') { + if (json[*idx] == '"' && (ret.str.size() == 0 || + (ret.str.size() > 0 && ret.str.back() != '\\'))) { ++(*idx); return ret; } else { From c9c8c2498f97a96ab09a8111d4220ae29030947a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 3 Oct 2020 23:20:07 +0000 Subject: [PATCH 5/7] add node to graph nodes array --- src/lib_api.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lib_api.cc b/src/lib_api.cc index f30810a7bf81..1a09dd6151c0 100644 --- a/src/lib_api.cc +++ b/src/lib_api.cc @@ -726,6 +726,7 @@ void mxnet::ext::Graph::print(int indent) const { /* \brief add a new node to this graph */ mxnet::ext::Node* mxnet::ext::Graph::addNode(const std::string& name, const std::string& op) { Node* n = new Node(); + g->nodes.push_back(n); n->name = name; n->op = op; if (res) From 2137207a7c1132ce95f31f72369da8546e5ca361 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 3 Oct 2020 23:21:51 +0000 Subject: [PATCH 6/7] changed string allocation from new to malloc to match free --- src/lib_api.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lib_api.cc b/src/lib_api.cc index 1a09dd6151c0..5c47e459da4f 100644 --- a/src/lib_api.cc +++ b/src/lib_api.cc @@ -1594,8 +1594,9 @@ MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, const char *jso mxnet::ext::MXReturnValue retval = graphPass(graph, opts); if (!retval) return retval; - std::string *tmp = new std::string(graph->toString()); - *out_graph = const_cast(tmp->c_str()); + std::string tmp = graph->toString(); + *out_graph = static_cast(malloc ((tmp.size()+1) * sizeof(char))); // NOLINT + snprintf((*out_graph), tmp.size()+1, "%s", tmp.c_str()); return retval; } From a8046390214abdb2dd1db6bd735df25e4d94bae4 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 3 Oct 2020 23:39:37 +0000 Subject: [PATCH 7/7] fixed add nodes --- src/lib_api.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib_api.cc b/src/lib_api.cc index 5c47e459da4f..c273678dcd1a 100644 --- a/src/lib_api.cc +++ b/src/lib_api.cc @@ -726,7 +726,7 @@ void mxnet::ext::Graph::print(int indent) const { /* \brief add a new node to this graph */ mxnet::ext::Node* mxnet::ext::Graph::addNode(const std::string& name, const std::string& op) { Node* n = new Node(); - g->nodes.push_back(n); + nodes.push_back(n); n->name = name; n->op = op; if (res)