Skip to content

Commit

Permalink
[GRAPH] Include default metadata description in graph. (#2770)
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 authored and tqchen committed Mar 14, 2019
1 parent df72239 commit 3741447
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
26 changes: 22 additions & 4 deletions python/tvm/relay/backend/graph_runtime_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from __future__ import absolute_import
import json
from collections import defaultdict
from collections import defaultdict, OrderedDict
import attr
from . import _backend
from . import compile_engine
Expand Down Expand Up @@ -348,13 +348,31 @@ def _get_json(self):
attrs["device_index"] = ["list_int", device_types]
attrs["dltype"] = ["list_str", dltypes]

json_dict = {
# Metadata definitions
def nested_defaultdict():
return defaultdict(nested_defaultdict)
metadata = nested_defaultdict()
for node_id in arg_nodes:
node_name = nodes[node_id]['name']
if node_name not in self.params:
metadata['signatures']['default']['inputs'][node_name]['id'] = node_id
metadata['signatures']['default']['inputs'][node_name]['dtype'] = dltypes[node_id]
metadata['signatures']['default']['inputs'][node_name]['shape'] = shapes[node_id]
for node_id in heads:
node_name = nodes[node_id[0]]['name']
metadata['signatures']['default']['outputs'][node_name]['id'] = node_id[0]
metadata['signatures']['default']['outputs'][node_name]['dtype'] = dltypes[node_id[0]]
metadata['signatures']['default']['outputs'][node_name]['shape'] = shapes[node_id[0]]

# Keep 'metadata' always at end
json_dict = OrderedDict({
"nodes": nodes,
"arg_nodes": arg_nodes,
"heads": heads,
"attrs": attrs,
"node_row_ptr": node_row_ptr
}
"node_row_ptr": node_row_ptr,
"metadata": metadata
})

return json.dumps(json_dict, indent=2)

Expand Down
2 changes: 2 additions & 0 deletions src/runtime/graph/graph_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ class GraphRuntime : public ModuleNode {
} else if (key == "attrs") {
reader->Read(&attrs_);
bitmask |= 16;
} else if (key == "metadata") {
break;
} else {
LOG(FATAL) << "key " << key << " is not supported";
}
Expand Down

0 comments on commit 3741447

Please sign in to comment.