diff --git a/python/tvm/contrib/debugger/debug_result.py b/python/tvm/contrib/debugger/debug_result.py index 882364dd3971..e0c9f44c8481 100644 --- a/python/tvm/contrib/debugger/debug_result.py +++ b/python/tvm/contrib/debugger/debug_result.py @@ -204,7 +204,7 @@ def dump_graph_json(self, graph): with open(os.path.join(self._dump_path, graph_dump_file_name), 'w') as outfile: json.dump(graph, outfile, indent=4, sort_keys=False) - def display_debug_result(self): + def display_debug_result(self, sort_by_time=True): """Displays the debugger result" """ header = ["Node Name", "Ops", "Time(us)", "Time(%)", "Shape", "Inputs", "Outputs"] @@ -228,6 +228,14 @@ def display_debug_result(self): node_data = [name, op, time_us, time_percent, shape, inputs, outputs] data.append(node_data) eid += 1 + + if sort_by_time: + # Sort on the basis of execution time. Prints the most expensive ops in the start. + data = sorted(data, key=lambda x: x[2], reverse=True) + # Insert a row for total time at the end. + rounded_total_time = round(total_time * 1000000, 3) + data.append(["Total_time", "-", rounded_total_time, "-", "-", "-", "-", "-"]) + fmt = "" for i, _ in enumerate(header): max_len = len(header[i])