Skip to content

Commit

Permalink
Fix status initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Sep 16, 2021
1 parent 31698fc commit a084845
Showing 1 changed file with 35 additions and 20 deletions.
55 changes: 35 additions & 20 deletions yolort/relaying/ir_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ class TorchScriptVisualizer:
def __init__(self, module):

self.module = module
self.seen_edges = set()
self.seen_input_names = set()
self.predictions = OrderedDict()

self.unseen_ops = {
'prim::ListConstruct', 'prim::ListUnpack',
Expand All @@ -37,13 +34,31 @@ def __init__(self, module):
# probably also partially absorbing ops. :/
self.absorbing_ops = ('aten::size', 'aten::_shape_as_tensor')

def render(self, classes_to_visit={'YOLO', 'YOLOHead'}):
def render(
self,
classes_to_visit={'YOLO', 'YOLOHead'},
format='svg',
labelloc='t',
attr_size='8,7',
):
self.clean_status()

model_input = next(self.module.graph.inputs())
model_type = self.get_node_names(model_input)[-1]
dot = Digraph(format='svg', graph_attr={'label': model_type, 'labelloc': 't'})
dot = Digraph(
format=format,
graph_attr={'label': model_type, 'labelloc': labelloc},
)
self.make_graph(self.module, dot=dot, classes_to_visit=classes_to_visit)

dot.attr(size=attr_size)
return dot

def clean_status(self):
self._seen_edges = set()
self._seen_input_names = set()
self._predictions = OrderedDict()

@staticmethod
def get_node_names(node):
return node.type().str().split('.')
Expand All @@ -57,11 +72,11 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N
graph = module.graph

self_input = next(graph.inputs())
self.predictions[self_input] = (set(), set()) # inputs, ops
self._predictions[self_input] = (set(), set()) # Stand for `input` and `op` respectively

for nr, i in enumerate(list(graph.inputs())[1:]):
name = f'{prefix}input_{i.debugName()}'
self.predictions[i] = {name}, set()
self._predictions[i] = {name}, set()
dot.node(name, shape='ellipse')
if input_preds is not None:
pred, op = input_preds[nr]
Expand Down Expand Up @@ -95,7 +110,7 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N
sub_prefix = f'{prefix}{submodule_name}.'

for i, o in enumerate(node.outputs()):
self.predictions[o] = {f'{sub_prefix}output_{i}'}, set()
self._predictions[o] = {f'{sub_prefix}output_{i}'}, set()

with dot.subgraph(name=f'cluster_{name}') as sub_dot:
sub_dot.attr(label=label)
Expand All @@ -108,34 +123,34 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N
dot=sub_dot,
parent_dot=dot,
prefix=sub_prefix,
input_preds=[self.predictions[i] for i in node_inputs[1:]],
input_preds=[self._predictions[i] for i in node_inputs[1:]],
classes_to_visit=classes_to_visit,
classes_found=classes_found,
)

else:
dot.node(name, label=label, shape='box')
for i in relevant_inputs:
pred, op = self.predictions[i]
pred, op = self._predictions[i]
self.make_edges(pred, prefix + i.debugName(), name, op, dot)
for o in node.outputs():
self.predictions[o] = {name}, set()
self._predictions[o] = {name}, set()

elif node.kind() == 'prim::CallFunction':
name = f'{prefix}.{node.output().debugName()}'
fun_name = self.get_function_name(node_inputs[0])
dot.node(name, label=fun_name, shape='box')
for i in relevant_inputs:
pred, op = self.predictions[i]
pred, op = self._predictions[i]
self.make_edges(pred, prefix + i.debugName(), name, op, dot)
for o in node.outputs():
self.predictions[o] = {name}, set()
self._predictions[o] = {name}, set()

else:
label = node.kind().split('::')[-1].rstrip('_')
pred, op = set(), set()
for i in relevant_inputs:
apred, aop = self.predictions[i]
apred, aop = self._predictions[i]
pred |= apred
op |= aop

Expand All @@ -146,23 +161,23 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N
and node.kind() not in self.unseen_ops):
op.add(label)
for o in node.outputs():
self.predictions[o] = pred, op
self._predictions[o] = pred, op

for i, o in enumerate(graph.outputs()):
name = f'{prefix}output_{i}'
dot.node(name, shape='ellipse')
pred, op = self.predictions[o]
pred, op = self._predictions[o]
self.make_edges(pred, f'input_{name}', name, op, dot)

def add_edge(self, dot, n1, n2):
if (n1, n2) not in self.seen_edges:
self.seen_edges.add((n1, n2))
if (n1, n2) not in self._seen_edges:
self._seen_edges.add((n1, n2))
dot.edge(n1, n2)

def make_edges(self, preds, input_name, name, op, edge_dot):
if len(op) > 0:
if input_name not in self.seen_input_names:
self.seen_input_names.add(input_name)
if input_name not in self._seen_input_names:
self._seen_input_names.add(input_name)
label_lines = [[]]
line_len = 0
for w in op:
Expand Down

0 comments on commit a084845

Please sign in to comment.