Skip to content

Commit

Permalink
Move predictions to attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Sep 15, 2021
1 parent a3f8200 commit 88b379c
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions yolort/relaying/ir_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Please link to Thomas's blog post or the original github source (linked from the
blog post) with the attribution notice.
"""
from collections import OrderedDict
from graphviz import Digraph


Expand All @@ -18,6 +19,7 @@ def __init__(self, module):
self.module = module
self.seen_edges = set()
self.seen_input_names = set()
self.predictions = OrderedDict()

self.unseen_ops = {
'aten::Int',
Expand Down Expand Up @@ -53,14 +55,13 @@ def get_function_name(node):
def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=None,
classes_to_visit=None, classes_found=None):
graph = module.graph
preds = {}

self_input = next(graph.inputs())
preds[self_input] = (set(), set()) # inps, ops
self.predictions[self_input] = (set(), set()) # inputs, ops

for nr, i in enumerate(list(graph.inputs())[1:]):
name = f'{prefix}input_{i.debugName()}'
preds[i] = {name}, set()
self.predictions[i] = {name}, set()
dot.node(name, shape='ellipse')
if input_preds is not None:
pred, op = input_preds[nr]
Expand Down Expand Up @@ -103,36 +104,36 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N
dot=sub_dot,
parent_dot=dot,
prefix=sub_prefix,
input_preds=[preds[i] for i in node_inputs[1:]],
input_preds=[self.predictions[i] for i in node_inputs[1:]],
classes_to_visit=classes_to_visit,
classes_found=classes_found,
)

for i, o in enumerate(node.outputs()):
preds[o] = {f'{sub_prefix}output_{i}'}, set()
self.predictions[o] = {f'{sub_prefix}output_{i}'}, set()
else:
dot.node(name, label=label, shape='box')
for i in relevant_inputs:
pred, op = preds[i]
pred, op = self.predictions[i]
self.make_edges(pred, prefix + i.debugName(), name, op, dot)
for o in node.outputs():
preds[o] = {name}, set()
self.predictions[o] = {name}, set()

elif node.kind() == 'prim::CallFunction':
name = f'{prefix}.{node.output().debugName()}'
fun_name = self.get_function_name(node_inputs[0])
dot.node(name, label=fun_name, shape='box')
for i in relevant_inputs:
pred, op = preds[i]
pred, op = self.predictions[i]
self.make_edges(pred, prefix + i.debugName(), name, op, dot)
for o in node.outputs():
preds[o] = {name}, set()
self.predictions[o] = {name}, set()

else:
label = node.kind().split('::')[-1].rstrip('_')
pred, op = set(), set()
for i in relevant_inputs:
apred, aop = preds[i]
apred, aop = self.predictions[i]
pred |= apred
op |= aop

Expand All @@ -143,12 +144,12 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N
and node.kind() not in self.unseen_ops):
op.add(label)
for o in node.outputs():
preds[o] = pred, op
self.predictions[o] = pred, op

for i, o in enumerate(graph.outputs()):
name = f'{prefix}output_{i}'
dot.node(name, shape='ellipse')
pred, op = preds[o]
pred, op = self.predictions[o]
self.make_edges(pred, f'input_{name}', name, op, dot)

def add_edge(self, dot, n1, n2):
Expand Down

0 comments on commit 88b379c

Please sign in to comment.