Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[1.x] Fixed setting attributes in reviewSubgraph #19274

Merged
merged 7 commits into from
Oct 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion example/extensions/lib_subgraph/subgraph_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,16 @@ MXReturnValue mySupportedOps(const mxnet::ext::Graph* graph,
}

MXReturnValue myReviewSubgraph(const mxnet::ext::Graph *subgraph, int subgraph_id, bool* accept,
const std::unordered_map<std::string, std::string>& options) {
const std::unordered_map<std::string, std::string>& options,
std::unordered_map<std::string, std::string>* attrs) {
for (auto kv : options) {
std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl;
}

std::string sg = subgraph->toString();
std::cout << "subgraph " << subgraph_id << ": " << std::endl;
std::cout << sg << std::endl;

// check if option `reject` was specified, and if so check if value is 'True'
if(options.count("reject") > 0 && options.at("reject").compare("True") == 0) {
// if specified, reject the subgraph. this is only used for testing
Expand All @@ -223,6 +228,9 @@ MXReturnValue myReviewSubgraph(const mxnet::ext::Graph *subgraph, int subgraph_i
*accept = true;
std::cout << "accepting subgraph" << std::endl;
}

attrs->emplace("myKey","myVal");

return MX_SUCCESS;
}

Expand Down
8 changes: 5 additions & 3 deletions include/mxnet/lib_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -594,10 +594,10 @@ class Graph {
static Graph* fromJson(JsonVal val);

/* \brief convert graph object back to JSON object */
JsonVal toJson();
JsonVal toJson() const;

/* \brief convert graph object to JSON string */
std::string toString();
std::string toString() const;

/* \brief visits a node "n" */
void _dfs_util(Node* n, std::unordered_set<Node*>* to_visit,
Expand Down Expand Up @@ -819,7 +819,9 @@ typedef MXReturnValue (*createSelector_t)(const mxnet::ext::Graph *graph,
typedef MXReturnValue (*reviewSubgraph_t)(const mxnet::ext::Graph *subgraph, int subgraph_id,
bool* accept,
const std::unordered_map<std::string,
std::string>& options);
std::string>& options,
std::unordered_map<std::string,
std::string>* attrs);

/*!
* \brief An abstract class for subgraph property
Expand Down
40 changes: 24 additions & 16 deletions src/lib_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,8 @@ mxnet::ext::JsonVal mxnet::ext::JsonVal::parse(const std::string& json) {
mxnet::ext::JsonVal mxnet::ext::JsonVal::parse_string(const std::string& json, unsigned int* idx) {
JsonVal ret(STR);
while (*idx < json.size()) {
if (json[*idx] == '"') {
if (json[*idx] == '"' && (ret.str.size() == 0 ||
(ret.str.size() > 0 && ret.str.back() != '\\'))) {
++(*idx);
return ret;
} else {
Expand Down Expand Up @@ -561,7 +562,7 @@ mxnet::ext::Graph* mxnet::ext::Graph::fromJson(mxnet::ext::JsonVal val) {
}

/* \brief convert graph object back to JSON object */
mxnet::ext::JsonVal mxnet::ext::Graph::toJson() {
mxnet::ext::JsonVal mxnet::ext::Graph::toJson() const {
// top level object is a map
JsonVal val(MAP);

Expand Down Expand Up @@ -646,7 +647,7 @@ mxnet::ext::JsonVal mxnet::ext::Graph::toJson() {
}

/* \brief convert graph object to JSON string */
std::string mxnet::ext::Graph::toString() {
std::string mxnet::ext::Graph::toString() const {
return toJson().dump();
}

Expand Down Expand Up @@ -725,6 +726,7 @@ void mxnet::ext::Graph::print(int indent) const {
/* \brief add a new node to this graph */
mxnet::ext::Node* mxnet::ext::Graph::addNode(const std::string& name, const std::string& op) {
Node* n = new Node();
nodes.push_back(n);
n->name = name;
n->op = op;
if (res)
Expand Down Expand Up @@ -766,10 +768,14 @@ void mxnet::ext::Graph::_setParams(std::unordered_map<std::string, mxnet::ext::M
std::unordered_map<std::string, mxnet::ext::MXTensor>* aux) {
// set params for each input node
for (Node* node : inputs) {
if (args->count(node->name) > 0)
node->tensor = &args->at(node->name);
else if (aux->count(node->name) > 0)
node->tensor = &aux->at(node->name);
std::string name = node->name;
if (node->attrs.count("isArg") > 0 && node->attrs["isArg"].compare("True") == 0)
// mapping name back to original node name from subgraph input name
name = node->attrs["argName"];
if (args->count(name) > 0)
node->tensor = &args->at(name);
else if (aux->count(name) > 0)
node->tensor = &aux->at(name);
}
}

Expand Down Expand Up @@ -1494,26 +1500,27 @@ MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph,
}

subgraph->_setParams(&args, &aux);

std::unordered_map<std::string, std::string> attrs;
mxnet::ext::MXReturnValue retval = reviewSubgraph(subgraph, subgraph_id, &accept_bool,
opts);
opts, &attrs);
if (!retval) return retval;

*accept = accept_bool;

if (subgraph->attrs.size() > 0) {
*num_attrs = subgraph->attrs.size();
if (attrs.size() > 0) {
*num_attrs = attrs.size();
// allocate space for attributes
*attr_keys = static_cast<char**>(malloc (*num_attrs * sizeof(char*))); // NOLINT
*attr_vals = static_cast<char**>(malloc (*num_attrs * sizeof(char*))); // NOLINT

// copy attributes
int i = 0;
for (auto kv : subgraph->attrs) {
for (auto kv : attrs) {
(*attr_keys)[i] = static_cast<char*>(malloc ((kv.first.size()+1) * sizeof(char))); // NOLINT
std::string val = kv.second.dump(); // convert JsonVal back to string
(*attr_vals)[i] = static_cast<char*>(malloc ((val.size()+1) * sizeof(char))); // NOLINT
(*attr_vals)[i] = static_cast<char*>(malloc ((kv.second.size()+1) * sizeof(char))); // NOLINT
snprintf((*attr_keys)[i], kv.first.size()+1, "%s", kv.first.c_str());
snprintf((*attr_vals)[i], val.size()+1, "%s", val.c_str());
snprintf((*attr_vals)[i], kv.second.size()+1, "%s", kv.second.c_str());
i++;
}
}
Expand Down Expand Up @@ -1587,8 +1594,9 @@ MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, const char *jso
mxnet::ext::MXReturnValue retval = graphPass(graph, opts);
if (!retval) return retval;

std::string *tmp = new std::string(graph->toString());
*out_graph = const_cast<char*>(tmp->c_str());
std::string tmp = graph->toString();
*out_graph = static_cast<char*>(malloc ((tmp.size()+1) * sizeof(char))); // NOLINT
snprintf((*out_graph), tmp.size()+1, "%s", tmp.c_str());
return retval;
}

Expand Down