Skip to content

Commit

Permalink
Make get_node_names explicit
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Sep 15, 2021
1 parent 7227319 commit 3519c1b
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions yolort/relaying/ir_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,15 @@ def __init__(self, module):

def render(self, classes_to_visit={'YOLO', 'YOLOHead'}):
model_input = next(self.module.graph.inputs())
model_type = model_input.type().str().split('.')[-1]
model_type = self.get_node_names(model_input)[-1]
dot = Digraph(format='svg', graph_attr={'label': model_type, 'labelloc': 't'})
self.make_graph(self.module, dot=dot, classes_to_visit=classes_to_visit)
return dot

@staticmethod
def get_node_names(node):
return node.type().str().split('.')

def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=None,
classes_to_visit=None, classes_found=None):
graph = module.graph
Expand All @@ -66,9 +70,9 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N
relevant_outputs = [o for o in node.outputs() if is_relevant_type(o.type())]

if node.kind() == 'prim::CallMethod':
fq_submodule_name = '.'.join([
nc for nc in list(node.inputs())[0].type().str().split('.') if not nc.startswith('__')])
submodule_type = list(node.inputs())[0].type().str().split('.')[-1]
fq_submodule_name = '.'.join([nc for nc in self.get_node_names(
list(node.inputs())[0]) if not nc.startswith('__')])
submodule_type = self.get_node_names(list(node.inputs())[0])[-1]
submodule_name = find_name(list(node.inputs())[0], self_input)
name = f'{prefix}.{node.output().debugName()}'
label = f'{prefix}{submodule_name} ({submodule_type})'
Expand Down

0 comments on commit 3519c1b

Please sign in to comment.