From 3519c1b2c53c24c549bc881161afbcc9b84df29c Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Wed, 15 Sep 2021 12:10:39 -0400 Subject: [PATCH] Make get_node_names explicit --- yolort/relaying/ir_visualizer.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/yolort/relaying/ir_visualizer.py b/yolort/relaying/ir_visualizer.py index 9e36e053f..761bafc6a 100644 --- a/yolort/relaying/ir_visualizer.py +++ b/yolort/relaying/ir_visualizer.py @@ -37,11 +37,15 @@ def __init__(self, module): def render(self, classes_to_visit={'YOLO', 'YOLOHead'}): model_input = next(self.module.graph.inputs()) - model_type = model_input.type().str().split('.')[-1] + model_type = self.get_node_names(model_input)[-1] dot = Digraph(format='svg', graph_attr={'label': model_type, 'labelloc': 't'}) self.make_graph(self.module, dot=dot, classes_to_visit=classes_to_visit) return dot + @staticmethod + def get_node_names(node): + return node.type().str().split('.') + def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=None, classes_to_visit=None, classes_found=None): graph = module.graph @@ -66,9 +70,9 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N relevant_outputs = [o for o in node.outputs() if is_relevant_type(o.type())] if node.kind() == 'prim::CallMethod': - fq_submodule_name = '.'.join([ - nc for nc in list(node.inputs())[0].type().str().split('.') if not nc.startswith('__')]) - submodule_type = list(node.inputs())[0].type().str().split('.')[-1] + fq_submodule_name = '.'.join([nc for nc in self.get_node_names( + list(node.inputs())[0]) if not nc.startswith('__')]) + submodule_type = self.get_node_names(list(node.inputs())[0])[-1] submodule_name = find_name(list(node.inputs())[0], self_input) name = f'{prefix}.{node.output().debugName()}' label = f'{prefix}{submodule_name} ({submodule_type})'