From a0848451d5b8f780f72fe23ded121de36753fe92 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Thu, 16 Sep 2021 03:21:30 -0400 Subject: [PATCH] Fix status initialization --- yolort/relaying/ir_visualizer.py | 55 ++++++++++++++++++++------------ 1 file changed, 35 insertions(+), 20 deletions(-) diff --git a/yolort/relaying/ir_visualizer.py b/yolort/relaying/ir_visualizer.py index e6bdf958..dc68d779 100644 --- a/yolort/relaying/ir_visualizer.py +++ b/yolort/relaying/ir_visualizer.py @@ -17,9 +17,6 @@ class TorchScriptVisualizer: def __init__(self, module): self.module = module - self.seen_edges = set() - self.seen_input_names = set() - self.predictions = OrderedDict() self.unseen_ops = { 'prim::ListConstruct', 'prim::ListUnpack', @@ -37,13 +34,31 @@ def __init__(self, module): # probably also partially absorbing ops. :/ self.absorbing_ops = ('aten::size', 'aten::_shape_as_tensor') - def render(self, classes_to_visit={'YOLO', 'YOLOHead'}): + def render( + self, + classes_to_visit={'YOLO', 'YOLOHead'}, + format='svg', + labelloc='t', + attr_size='8,7', + ): + self.clean_status() + model_input = next(self.module.graph.inputs()) model_type = self.get_node_names(model_input)[-1] - dot = Digraph(format='svg', graph_attr={'label': model_type, 'labelloc': 't'}) + dot = Digraph( + format=format, + graph_attr={'label': model_type, 'labelloc': labelloc}, + ) self.make_graph(self.module, dot=dot, classes_to_visit=classes_to_visit) + + dot.attr(size=attr_size) return dot + def clean_status(self): + self._seen_edges = set() + self._seen_input_names = set() + self._predictions = OrderedDict() + @staticmethod def get_node_names(node): return node.type().str().split('.') @@ -57,11 +72,11 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N graph = module.graph self_input = next(graph.inputs()) - self.predictions[self_input] = (set(), set()) # inputs, ops + self._predictions[self_input] = (set(), set()) # Stand for `input` and `op` respectively for nr, i in enumerate(list(graph.inputs())[1:]): name = f'{prefix}input_{i.debugName()}' - self.predictions[i] = {name}, set() + self._predictions[i] = {name}, set() dot.node(name, shape='ellipse') if input_preds is not None: pred, op = input_preds[nr] @@ -95,7 +110,7 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N sub_prefix = f'{prefix}{submodule_name}.' for i, o in enumerate(node.outputs()): - self.predictions[o] = {f'{sub_prefix}output_{i}'}, set() + self._predictions[o] = {f'{sub_prefix}output_{i}'}, set() with dot.subgraph(name=f'cluster_{name}') as sub_dot: sub_dot.attr(label=label) @@ -108,7 +123,7 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N dot=sub_dot, parent_dot=dot, prefix=sub_prefix, - input_preds=[self.predictions[i] for i in node_inputs[1:]], + input_preds=[self._predictions[i] for i in node_inputs[1:]], classes_to_visit=classes_to_visit, classes_found=classes_found, ) @@ -116,26 +131,26 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N else: dot.node(name, label=label, shape='box') for i in relevant_inputs: - pred, op = self.predictions[i] + pred, op = self._predictions[i] self.make_edges(pred, prefix + i.debugName(), name, op, dot) for o in node.outputs(): - self.predictions[o] = {name}, set() + self._predictions[o] = {name}, set() elif node.kind() == 'prim::CallFunction': name = f'{prefix}.{node.output().debugName()}' fun_name = self.get_function_name(node_inputs[0]) dot.node(name, label=fun_name, shape='box') for i in relevant_inputs: - pred, op = self.predictions[i] + pred, op = self._predictions[i] self.make_edges(pred, prefix + i.debugName(), name, op, dot) for o in node.outputs(): - self.predictions[o] = {name}, set() + self._predictions[o] = {name}, set() else: label = node.kind().split('::')[-1].rstrip('_') pred, op = set(), set() for i in relevant_inputs: - apred, aop = self.predictions[i] + apred, aop = self._predictions[i] pred |= apred op |= aop @@ -146,23 +161,23 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N and node.kind() not in self.unseen_ops): op.add(label) for o in node.outputs(): - self.predictions[o] = pred, op + self._predictions[o] = pred, op for i, o in enumerate(graph.outputs()): name = f'{prefix}output_{i}' dot.node(name, shape='ellipse') - pred, op = self.predictions[o] + pred, op = self._predictions[o] self.make_edges(pred, f'input_{name}', name, op, dot) def add_edge(self, dot, n1, n2): - if (n1, n2) not in self.seen_edges: - self.seen_edges.add((n1, n2)) + if (n1, n2) not in self._seen_edges: + self._seen_edges.add((n1, n2)) dot.edge(n1, n2) def make_edges(self, preds, input_name, name, op, edge_dot): if len(op) > 0: - if input_name not in self.seen_input_names: - self.seen_input_names.add(input_name) + if input_name not in self._seen_input_names: + self._seen_input_names.add(input_name) label_lines = [[]] line_len = 0 for w in op: