diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py index fba4d11aaf72..e31b44df81ba 100644 --- a/python/tvm/relay/backend/graph_runtime_codegen.py +++ b/python/tvm/relay/backend/graph_runtime_codegen.py @@ -20,7 +20,7 @@ from __future__ import absolute_import import json -from collections import defaultdict +from collections import defaultdict, OrderedDict import attr from . import _backend from . import compile_engine @@ -348,13 +348,31 @@ def _get_json(self): attrs["device_index"] = ["list_int", device_types] attrs["dltype"] = ["list_str", dltypes] - json_dict = { + # Metadata definitions + def nested_defaultdict(): + return defaultdict(nested_defaultdict) + metadata = nested_defaultdict() + for node_id in arg_nodes: + node_name = nodes[node_id]['name'] + if node_name not in self.params: + metadata['signatures']['default']['inputs'][node_name]['id'] = node_id + metadata['signatures']['default']['inputs'][node_name]['dtype'] = dltypes[node_id] + metadata['signatures']['default']['inputs'][node_name]['shape'] = shapes[node_id] + for node_id in heads: + node_name = nodes[node_id[0]]['name'] + metadata['signatures']['default']['outputs'][node_name]['id'] = node_id[0] + metadata['signatures']['default']['outputs'][node_name]['dtype'] = dltypes[node_id[0]] + metadata['signatures']['default']['outputs'][node_name]['shape'] = shapes[node_id[0]] + + # Keep 'metadata' always at end + json_dict = OrderedDict({ "nodes": nodes, "arg_nodes": arg_nodes, "heads": heads, "attrs": attrs, - "node_row_ptr": node_row_ptr - } + "node_row_ptr": node_row_ptr, + "metadata": metadata + }) return json.dumps(json_dict, indent=2) diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index dbc02ad980ac..cdd236a09d31 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -318,6 +318,8 @@ class GraphRuntime : public ModuleNode { } else if (key == "attrs") { reader->Read(&attrs_); bitmask |= 16; + } else if (key == "metadata") { + break; } else { LOG(FATAL) << "key " << key << " is not supported"; }