From 721f1151b16ca57cc92267d794fadc7c39d97c6c Mon Sep 17 00:00:00 2001 From: zhaojinxi Date: Thu, 5 Jan 2023 21:14:17 +0800 Subject: [PATCH] [BugFix][Runtime] Fix Incorrect node information (#13693) * [BugFix][Runtime] Fix Incorrect node information * 1 * 1 --- python/tvm/contrib/debugger/debug_result.py | 24 ++++++++++------- .../unittest/test_runtime_graph_debug.py | 26 ++++++++++++++++++- 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/python/tvm/contrib/debugger/debug_result.py b/python/tvm/contrib/debugger/debug_result.py index 006edd345802..8a1089f843cd 100644 --- a/python/tvm/contrib/debugger/debug_result.py +++ b/python/tvm/contrib/debugger/debug_result.py @@ -73,21 +73,25 @@ def _update_graph_json(self): """update the nodes_list with name, shape and data type, for temporarily storing the output. """ - nodes_len = len(self._nodes_list) - for i in range(nodes_len): - node = self._nodes_list[i] + eid = 0 + for node in self._nodes_list: input_list = [] - for input_node in node["inputs"]: - input_list.append(self._nodes_list[input_node[0]]["name"]) - node["inputs"] = input_list - dtype = str("type: " + self._dtype_list[1][i]) - if "attrs" not in node: + if node["op"] == "null": node["attrs"] = {} node["op"] = "param" - else: + num_outputs = 1 + elif node["op"] == "tvm_op": + for input_node in node["inputs"]: + input_list.append(self._nodes_list[input_node[0]]["name"]) node["op"] = node["attrs"]["func_name"] + num_outputs = int(node["attrs"]["num_outputs"]) + else: + raise ValueError("") + node["inputs"] = input_list + dtype = str("type: " + self._dtype_list[1][eid]) node["attrs"].update({"T": dtype}) - node["shape"] = self._shapes_list[1][i] + node["shape"] = self._shapes_list[1][eid] + eid += num_outputs def _cleanup_tensors(self): """Remove the tensor dump file (graph wont be removed)""" diff --git a/tests/python/unittest/test_runtime_graph_debug.py b/tests/python/unittest/test_runtime_graph_debug.py index bc0e96f50b45..9111ed38db33 100644 --- a/tests/python/unittest/test_runtime_graph_debug.py +++ b/tests/python/unittest/test_runtime_graph_debug.py @@ -29,7 +29,7 @@ from tvm._ffi.base import TVMError from tvm.contrib import utils from tvm.contrib.debugger import debug_executor - +from tvm import relay # Constants for creating simple graphs, fixtures to avoid free globals @pytest.fixture @@ -275,5 +275,29 @@ def test_run_single_node(graph, n, A, myadd): mod.run_individual_node(2) +@tvm.testing.requires_llvm +def test_multiple_output(): + x = relay.var("x", shape=(1, 3, 48, 16), dtype="float32") + t = relay.split(x, [12, 16, 32], 2).astuple() + x0 = relay.TupleGetItem(t, 0) + x1 = relay.TupleGetItem(t, 1) + x2 = relay.TupleGetItem(t, 2) + x3 = relay.TupleGetItem(t, 3) + p0 = relay.const(np.random.uniform(-1, 1, (3, 3, 1, 1)).astype("float32")) + y = relay.nn.conv2d(x2, p0, kernel_size=(1, 1), kernel_layout="OIHW", out_dtype="float32") + x3 + + func = relay.Function([x], relay.Tuple([x0, x1, y])) + mod = tvm.IRModule.from_expr(func) + mod = relay.transform.InferType()(mod) + target = tvm.target.Target("llvm") + device = tvm.cpu() + lib = relay.build(mod, target=target) + m = debug_executor.GraphModuleDebug( + lib["debug_create"]("default", device), [device], lib.get_graph_json(), None + ) + nodes = m.debug_datum.get_graph_nodes() + assert nodes[2]["shape"] == [3, 3, 1, 1] + + if __name__ == "__main__": tvm.testing.main()