diff --git a/yolort/relaying/__init__.py b/yolort/relaying/__init__.py index e065479d2..828222256 100644 --- a/yolort/relaying/__init__.py +++ b/yolort/relaying/__init__.py @@ -1,2 +1,2 @@ -# Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved. +# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved. from .trace_wrapper import get_trace_module diff --git a/yolort/relaying/graph_utils.py b/yolort/relaying/graph_utils.py deleted file mode 100644 index 5726fc8a3..000000000 --- a/yolort/relaying/graph_utils.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright (c) 2020, Thomas Viehmann -""" -Visualizing JIT Modules - -Mostly copy-paste from https://github.com/t-vi/pytorch-tvmisc/tree/master/hacks -with license under the CC-BY-SA 4.0. - -Please link to Thomas's blog post or the original github source (linked from the -blog post) with the attribution notice. -""" -from graphviz import Digraph - - -def make_graph(mod, classes_to_visit=None, classes_found=None, dot=None, - prefix="", input_preds=None, parent_dot=None): - preds = {} - - def find_name(i, self_input, suffix=None): - if i == self_input: - return suffix - cur = i.node().s('name') - if suffix is not None: - cur = f'{cur}.{suffix}' - of = next(i.node().inputs()) - return find_name(of, self_input, suffix=cur) - - gr = mod.graph - # list(traced_model.graph.nodes())[0] - self_input = next(gr.inputs()) - self_type = self_input.type().str().split('.')[-1] - preds[self_input] = (set(), set()) # inps, ops - - if dot is None: - dot = Digraph(format='svg', graph_attr={'label': self_type, 'labelloc': 't'}) - # dot.attr('node', shape='box') - - seen_inpnames = set() - seen_edges = set() - - def add_edge(dot, n1, n2): - if (n1, n2) not in seen_edges: - seen_edges.add((n1, n2)) - dot.edge(n1, n2) - - def make_edges(pr, inpname, name, op, edge_dot=dot): - if op: - if inpname not in seen_inpnames: - seen_inpnames.add(inpname) - label_lines = [[]] - line_len = 0 - for w in op: - if line_len >= 20: - label_lines.append([]) - line_len = 0 - label_lines[-1].append(w) - line_len += len(w) + 1 - - edge_dot.node( - inpname, - label='\n'.join([' '.join(w) for w in label_lines]), - shape='box', - style='rounded', - ) - for p in pr: - add_edge(edge_dot, p, inpname) - add_edge(edge_dot, inpname, name) - else: - for p in pr: - add_edge(edge_dot, p, name) - - for nr, i in enumerate(list(gr.inputs())[1:]): - name = f'{prefix}inp_{i.debugName()}' - preds[i] = {name}, set() - dot.node(name, shape='ellipse') - if input_preds is not None: - pr, op = input_preds[nr] - make_edges(pr, f'inp_{name}', name, op, edge_dot=parent_dot) - - def is_relevant_type(t): - kind = t.kind() - if kind == 'TensorType': - return True - if kind in ('ListType', 'OptionalType'): - return is_relevant_type(t.getElementType()) - if kind == 'TupleType': - return any([is_relevant_type(tt) for tt in t.elements()]) - return False - - for n in gr.nodes(): - only_first_ops = {'aten::expand_as'} - rel_inp_end = 1 if n.kind() in only_first_ops else None - - relevant_inputs = [i for i in list(n.inputs())[:rel_inp_end] if is_relevant_type(i.type())] - relevant_outputs = [o for o in n.outputs() if is_relevant_type(o.type())] - - if n.kind() == 'prim::CallMethod': - fq_submodule_name = '.'.join([ - nc for nc in list(n.inputs())[0].type().str().split('.') if not nc.startswith('__')]) - submodule_type = list(n.inputs())[0].type().str().split('.')[-1] - submodule_name = find_name(list(n.inputs())[0], self_input) - name = f'{prefix}.{n.output().debugName()}' - label = f'{prefix}{submodule_name} ({submodule_type})' - - if classes_found is not None: - classes_found.add(fq_submodule_name) - - if ((classes_to_visit is None and (not fq_submodule_name.startswith('torch.nn') - or fq_submodule_name.startswith('torch.nn.modules.container'))) - or (classes_to_visit is not None and (submodule_type in classes_to_visit - or fq_submodule_name in classes_to_visit))): - - # go into subgraph - sub_prefix = prefix + submodule_name + '.' - with dot.subgraph(name=f'cluster_{name}') as sub_dot: - sub_dot.attr(label=label) - submod = mod - for k in submodule_name.split('.'): - submod = getattr(submod, k) - - make_graph( - submod, - dot=sub_dot, - prefix=sub_prefix, - input_preds=[preds[i] for i in list(n.inputs())[1:]], - parent_dot=dot, - classes_to_visit=classes_to_visit, - classes_found=classes_found, - ) - - for i, o in enumerate(n.outputs()): - preds[o] = {sub_prefix + f'out_{i}'}, set() - else: - dot.node(name, label=label, shape='box') - for i in relevant_inputs: - pr, op = preds[i] - make_edges(pr, prefix + i.debugName(), name, op) - for o in n.outputs(): - preds[o] = {name}, set() - - elif n.kind() == 'prim::CallFunction': - funcname = list(n.inputs())[0].type().__repr__().split('.')[-1] - name = prefix + '.' + n.output().debugName() - label = funcname - dot.node(name, label=label, shape='box') - for i in relevant_inputs: - pr, op = preds[i] - make_edges(pr, prefix + i.debugName(), name, op) - for o in n.outputs(): - preds[o] = {name}, set() - - else: - unseen_ops = { - 'prim::ListConstruct', 'prim::TupleConstruct', 'aten::index', 'aten::size', - 'aten::slice', 'aten::unsqueeze', 'aten::squeeze', 'aten::to', - 'aten::view', 'aten::permute', 'aten::transpose', 'aten::contiguous', - 'aten::permute', 'aten::Int', 'prim::TupleUnpack', 'prim::ListUnpack', - 'aten::unbind', 'aten::select', 'aten::detach', 'aten::stack', - 'aten::reshape', 'aten::split_with_sizes', 'aten::cat', 'aten::expand', - 'aten::expand_as', 'aten::_shape_as_tensor', - } - - # probably also partially absorbing ops. :/ - absorbing_ops = ('aten::size', 'aten::_shape_as_tensor') - - label = n.kind().split('::')[-1].rstrip('_') - pr, op = set(), set() - for i in relevant_inputs: - apr, aop = preds[i] - pr |= apr - op |= aop - # if pr and n.kind() not in unseen_ops: - # print(n.kind(), n) - if n.kind() in absorbing_ops: - pr, op = set(), set() - elif len(relevant_inputs) > 0 and len(relevant_outputs) > 0 and n.kind() not in unseen_ops: - op.add(label) - for o in n.outputs(): - preds[o] = pr, op - - for i, o in enumerate(gr.outputs()): - name = prefix + f'out_{i}' - dot.node(name, shape='ellipse') - pr, op = preds[o] - make_edges(pr, f'inp_{name}', name, op) - return dot diff --git a/yolort/relaying/ir_visualizer.py b/yolort/relaying/ir_visualizer.py new file mode 100644 index 000000000..c9a3eed82 --- /dev/null +++ b/yolort/relaying/ir_visualizer.py @@ -0,0 +1,192 @@ +# Copyright (c) 2021, Zhiqiang Wang +# Copyright (c) 2020, Thomas Viehmann +""" +Visualizing JIT Modules + +Modified from https://github.com/t-vi/pytorch-tvmisc/tree/master/hacks +with license under the CC-BY-SA 4.0. + +Please link to Thomas's blog post or the original github source (linked from the +blog post) with the attribution notice. +""" +from graphviz import Digraph + + +class TorchScriptVisualizer: + def __init__(self, module): + + self.module = module + self.seen_edges = set() + self.seen_input_names = set() + + self.unseen_ops = { + 'prim::ListConstruct', 'prim::TupleConstruct', 'aten::index', 'aten::size', + 'aten::slice', 'aten::unsqueeze', 'aten::squeeze', 'aten::to', + 'aten::view', 'aten::permute', 'aten::transpose', 'aten::contiguous', + 'aten::permute', 'aten::Int', 'prim::TupleUnpack', 'prim::ListUnpack', + 'aten::unbind', 'aten::select', 'aten::detach', 'aten::stack', + 'aten::reshape', 'aten::split_with_sizes', 'aten::cat', 'aten::expand', + 'aten::expand_as', 'aten::_shape_as_tensor', + } + # probably also partially absorbing ops. :/ + self.absorbing_ops = ('aten::size', 'aten::_shape_as_tensor') + + def render(self, classes_to_visit={'YOLO', 'YOLOHead'}): + return self.make_graph(self.module, classes_to_visit=classes_to_visit) + + def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=None, + classes_to_visit=None, classes_found=None): + graph = module.graph + preds = {} + + self_input = next(graph.inputs()) + self_type = self_input.type().str().split('.')[-1] + preds[self_input] = (set(), set()) # inps, ops + + if dot is None: + dot = Digraph(format='svg', graph_attr={'label': self_type, 'labelloc': 't'}) + + for nr, i in enumerate(list(graph.inputs())[1:]): + name = f'{prefix}input_{i.debugName()}' + preds[i] = {name}, set() + dot.node(name, shape='ellipse') + if input_preds is not None: + pr, op = input_preds[nr] + self.make_edges(pr, f'input_{name}', name, op, parent_dot) + + for graph_node in graph.nodes(): + only_first_ops = {'aten::expand_as'} + rel_inp_end = 1 if graph_node.kind() in only_first_ops else None + + relevant_inputs = [i for i in list(graph_node.inputs())[:rel_inp_end] if is_relevant_type(i.type())] + relevant_outputs = [o for o in graph_node.outputs() if is_relevant_type(o.type())] + + if graph_node.kind() == 'prim::CallMethod': + fq_submodule_name = '.'.join([ + nc for nc in list(graph_node.inputs())[0].type().str().split('.') if not nc.startswith('__')]) + submodule_type = list(graph_node.inputs())[0].type().str().split('.')[-1] + submodule_name = find_name(list(graph_node.inputs())[0], self_input) + name = f'{prefix}.{graph_node.output().debugName()}' + label = f'{prefix}{submodule_name} ({submodule_type})' + + if classes_found is not None: + classes_found.add(fq_submodule_name) + + if ((classes_to_visit is None and (not fq_submodule_name.startswith('torch.nn') + or fq_submodule_name.startswith('torch.nn.modules.container'))) + or (classes_to_visit is not None and (submodule_type in classes_to_visit + or fq_submodule_name in classes_to_visit))): + + # go into subgraph + sub_prefix = prefix + submodule_name + '.' + with dot.subgraph(name=f'cluster_{name}') as sub_dot: + sub_dot.attr(label=label) + sub_module = module + for k in submodule_name.split('.'): + sub_module = getattr(sub_module, k) + + self.make_graph( + sub_module, + dot=sub_dot, + parent_dot=dot, + prefix=sub_prefix, + input_preds=[preds[i] for i in list(graph_node.inputs())[1:]], + classes_to_visit=classes_to_visit, + classes_found=classes_found, + ) + + for i, o in enumerate(graph_node.outputs()): + preds[o] = {sub_prefix + f'output_{i}'}, set() + else: + dot.node(name, label=label, shape='box') + for i in relevant_inputs: + pr, op = preds[i] + self.make_edges(pr, prefix + i.debugName(), name, op, dot) + for o in graph_node.outputs(): + preds[o] = {name}, set() + + elif graph_node.kind() == 'prim::CallFunction': + funcname = list(graph_node.inputs())[0].type().__repr__().split('.')[-1] + name = prefix + '.' + graph_node.output().debugName() + label = funcname + dot.node(name, label=label, shape='box') + for i in relevant_inputs: + pr, op = preds[i] + self.make_edges(pr, prefix + i.debugName(), name, op, dot) + for o in graph_node.outputs(): + preds[o] = {name}, set() + + else: + label = graph_node.kind().split('::')[-1].rstrip('_') + pr, op = set(), set() + for i in relevant_inputs: + apr, aop = preds[i] + pr |= apr + op |= aop + + if graph_node.kind() in self.absorbing_ops: + pr, op = set(), set() + elif len(relevant_inputs) > 0 and len(relevant_outputs) > 0 and graph_node.kind() not in self.unseen_ops: + op.add(label) + for o in graph_node.outputs(): + preds[o] = pr, op + + for i, o in enumerate(graph.outputs()): + name = prefix + f'output_{i}' + dot.node(name, shape='ellipse') + pr, op = preds[o] + self.make_edges(pr, f'input_{name}', name, op, dot) + + return dot + + def add_edge(self, dot, n1, n2): + if (n1, n2) not in self.seen_edges: + self.seen_edges.add((n1, n2)) + dot.edge(n1, n2) + + def make_edges(self, pr, input_name, name, op, edge_dot): + if op: + if input_name not in self.seen_input_names: + self.seen_input_names.add(input_name) + label_lines = [[]] + line_len = 0 + for w in op: + if line_len >= 20: + label_lines.append([]) + line_len = 0 + label_lines[-1].append(w) + line_len += len(w) + 1 + + edge_dot.node( + input_name, + label='\n'.join([' '.join(w) for w in label_lines]), + shape='box', + style='rounded', + ) + for p in pr: + self.add_edge(edge_dot, p, input_name) + self.add_edge(edge_dot, input_name, name) + else: + for p in pr: + self.add_edge(edge_dot, p, name) + + +def find_name(layer_input, self_input, suffix=None): + if layer_input == self_input: + return suffix + cur = layer_input.node().s('name') + if suffix is not None: + cur = f'{cur}.{suffix}' + of = next(layer_input.node().inputs()) + return find_name(of, self_input, suffix=cur) + + +def is_relevant_type(t): + kind = t.kind() + if kind == 'TensorType': + return True + if kind in ('ListType', 'OptionalType'): + return is_relevant_type(t.getElementType()) + if kind == 'TupleType': + return any([is_relevant_type(tt) for tt in t.elements()]) + return False diff --git a/yolort/relaying/trace_wrapper.py b/yolort/relaying/trace_wrapper.py index a8ac739b1..b65e3ea94 100644 --- a/yolort/relaying/trace_wrapper.py +++ b/yolort/relaying/trace_wrapper.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved. +# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved. from typing import Dict, Tuple, Callable import torch