Skip to content

Commit

Permalink
Adding minor ultis
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Sep 15, 2021
1 parent 3519c1b commit a3f8200
Showing 1 changed file with 30 additions and 26 deletions.
56 changes: 30 additions & 26 deletions yolort/relaying/ir_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def render(self, classes_to_visit={'YOLO', 'YOLOHead'}):
def get_node_names(node):
return node.type().str().split('.')

@staticmethod
def get_function_name(node):
return node.type().__repr__().split('.')[-1]

def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=None,
classes_to_visit=None, classes_found=None):
graph = module.graph
Expand All @@ -59,21 +63,22 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N
preds[i] = {name}, set()
dot.node(name, shape='ellipse')
if input_preds is not None:
pr, op = input_preds[nr]
self.make_edges(pr, f'input_{name}', name, op, parent_dot)
pred, op = input_preds[nr]
self.make_edges(pred, f'input_{name}', name, op, parent_dot)

for node in graph.nodes():
node_inputs = list(node.inputs())
only_first_ops = {'aten::expand_as'}
rel_inp_end = 1 if node.kind() in only_first_ops else None

relevant_inputs = [i for i in list(node.inputs())[:rel_inp_end] if is_relevant_type(i.type())]
relevant_inputs = [i for i in node_inputs[:rel_inp_end] if is_relevant_type(i.type())]
relevant_outputs = [o for o in node.outputs() if is_relevant_type(o.type())]

if node.kind() == 'prim::CallMethod':
fq_submodule_name = '.'.join([nc for nc in self.get_node_names(
list(node.inputs())[0]) if not nc.startswith('__')])
submodule_type = self.get_node_names(list(node.inputs())[0])[-1]
submodule_name = find_name(list(node.inputs())[0], self_input)
node_names = self.get_node_names(node_inputs[0])
fq_submodule_name = '.'.join([nc for nc in node_names if not nc.startswith('__')])
submodule_type = node_names[-1]
submodule_name = find_name(node_inputs[0], self_input)
name = f'{prefix}.{node.output().debugName()}'
label = f'{prefix}{submodule_name} ({submodule_type})'

Expand All @@ -98,7 +103,7 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N
dot=sub_dot,
parent_dot=dot,
prefix=sub_prefix,
input_preds=[preds[i] for i in list(node.inputs())[1:]],
input_preds=[preds[i] for i in node_inputs[1:]],
classes_to_visit=classes_to_visit,
classes_found=classes_found,
)
Expand All @@ -108,52 +113,51 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N
else:
dot.node(name, label=label, shape='box')
for i in relevant_inputs:
pr, op = preds[i]
self.make_edges(pr, prefix + i.debugName(), name, op, dot)
pred, op = preds[i]
self.make_edges(pred, prefix + i.debugName(), name, op, dot)
for o in node.outputs():
preds[o] = {name}, set()

elif node.kind() == 'prim::CallFunction':
funcname = list(node.inputs())[0].type().__repr__().split('.')[-1]
name = f'{prefix}.{node.output().debugName()}'
label = funcname
dot.node(name, label=label, shape='box')
fun_name = self.get_function_name(node_inputs[0])
dot.node(name, label=fun_name, shape='box')
for i in relevant_inputs:
pr, op = preds[i]
self.make_edges(pr, prefix + i.debugName(), name, op, dot)
pred, op = preds[i]
self.make_edges(pred, prefix + i.debugName(), name, op, dot)
for o in node.outputs():
preds[o] = {name}, set()

else:
label = node.kind().split('::')[-1].rstrip('_')
pr, op = set(), set()
pred, op = set(), set()
for i in relevant_inputs:
apr, aop = preds[i]
pr |= apr
apred, aop = preds[i]
pred |= apred
op |= aop

if node.kind() in self.absorbing_ops:
pr, op = set(), set()
pred, op = set(), set()
elif (len(relevant_inputs) > 0
and len(relevant_outputs) > 0
and node.kind() not in self.unseen_ops):
op.add(label)
for o in node.outputs():
preds[o] = pr, op
preds[o] = pred, op

for i, o in enumerate(graph.outputs()):
name = f'{prefix}output_{i}'
dot.node(name, shape='ellipse')
pr, op = preds[o]
self.make_edges(pr, f'input_{name}', name, op, dot)
pred, op = preds[o]
self.make_edges(pred, f'input_{name}', name, op, dot)

def add_edge(self, dot, n1, n2):
if (n1, n2) not in self.seen_edges:
self.seen_edges.add((n1, n2))
dot.edge(n1, n2)

def make_edges(self, pr, input_name, name, op, edge_dot):
if op:
def make_edges(self, preds, input_name, name, op, edge_dot):
if len(op) > 0:
if input_name not in self.seen_input_names:
self.seen_input_names.add(input_name)
label_lines = [[]]
Expand All @@ -171,11 +175,11 @@ def make_edges(self, pr, input_name, name, op, edge_dot):
shape='box',
style='rounded',
)
for p in pr:
for p in preds:
self.add_edge(edge_dot, p, input_name)
self.add_edge(edge_dot, input_name, name)
else:
for p in pr:
for p in preds:
self.add_edge(edge_dot, p, name)


Expand Down

0 comments on commit a3f8200

Please sign in to comment.