Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor graph visualization #165

Merged
merged 16 commits into from
Sep 16, 2021
Merged
164 changes: 31 additions & 133 deletions notebooks/export-relay-inference-tvm.ipynb

Large diffs are not rendered by default.

776 changes: 606 additions & 170 deletions notebooks/model-graph-visualization.ipynb

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions test/test_relaying.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from torch.jit._trace import TopLevelTracedModule

from yolort.models import yolov5s
from yolort.relaying import get_trace_module


def test_get_trace_module():
model_func = yolov5s(pretrained=True)
script_module = get_trace_module(model_func, input_shape=(416, 320))
assert isinstance(script_module, TopLevelTracedModule)
assert script_module.code is not None
2 changes: 2 additions & 0 deletions yolort/relaying/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from .trace_wrapper import get_trace_module
222 changes: 222 additions & 0 deletions yolort/relaying/ir_visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
# Copyright (c) 2021, Zhiqiang Wang
# Copyright (c) 2020, Thomas Viehmann
"""
Visualizing JIT Modules

Modified from https://github.com/t-vi/pytorch-tvmisc/tree/master/hacks
with license under the CC-BY-SA 4.0.

Please link to Thomas's blog post or the original github source (linked from the
blog post) with the attribution notice.
"""
from collections import OrderedDict
from graphviz import Digraph


class TorchScriptVisualizer:
def __init__(self, module):

self.module = module

self.unseen_ops = {
'prim::ListConstruct', 'prim::ListUnpack',
'prim::TupleConstruct', 'prim::TupleUnpack',
'aten::Int',
'aten::unbind', 'aten::detach',
'aten::contiguous', 'aten::to',
'aten::unsqueeze', 'aten::squeeze',
'aten::index', 'aten::slice', 'aten::select',
'aten::constant_pad_nd',
'aten::size', 'aten::split_with_sizes',
'aten::expand_as', 'aten::expand',
'aten::_shape_as_tensor',
}
# probably also partially absorbing ops. :/
self.absorbing_ops = ('aten::size', 'aten::_shape_as_tensor')

def render(
self,
classes_to_visit={'YOLO', 'YOLOHead'},
format='svg',
labelloc='t',
attr_size='8,7',
):
self.clean_status()

model_input = next(self.module.graph.inputs())
model_type = self.get_node_names(model_input)[-1]
dot = Digraph(
format=format,
graph_attr={'label': model_type, 'labelloc': labelloc},
)
self.make_graph(self.module, dot=dot, classes_to_visit=classes_to_visit)

dot.attr(size=attr_size)
return dot

def clean_status(self):
self._seen_edges = set()
self._seen_input_names = set()
self._predictions = OrderedDict()

@staticmethod
def get_node_names(node):
return node.type().str().split('.')

@staticmethod
def get_function_name(node):
return node.type().__repr__().split('.')[-1]

def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=None,
classes_to_visit=None, classes_found=None):
graph = module.graph

self_input = next(graph.inputs())
self._predictions[self_input] = (set(), set()) # Stand for `input` and `op` respectively

for nr, i in enumerate(list(graph.inputs())[1:]):
name = f'{prefix}input_{i.debugName()}'
self._predictions[i] = {name}, set()
dot.node(name, shape='ellipse')
if input_preds is not None:
pred, op = input_preds[nr]
self.make_edges(pred, f'input_{name}', name, op, parent_dot)

for node in graph.nodes():
node_inputs = list(node.inputs())
only_first_ops = {'aten::expand_as'}
rel_inp_end = 1 if node.kind() in only_first_ops else None

relevant_inputs = [i for i in node_inputs[:rel_inp_end] if is_relevant_type(i.type())]
relevant_outputs = [o for o in node.outputs() if is_relevant_type(o.type())]

if node.kind() == 'prim::CallMethod':
node_names = self.get_node_names(node_inputs[0])
fq_submodule_name = '.'.join([nc for nc in node_names if not nc.startswith('__')])
submodule_type = node_names[-1]
submodule_name = find_name(node_inputs[0], self_input)
name = f'{prefix}.{node.output().debugName()}'
label = f'{prefix}{submodule_name} ({submodule_type})'

if classes_found is not None:
classes_found.add(fq_submodule_name)

if ((classes_to_visit is None and (not fq_submodule_name.startswith('torch.nn')
or fq_submodule_name.startswith('torch.nn.modules.container')))
or (classes_to_visit is not None and (submodule_type in classes_to_visit
or fq_submodule_name in classes_to_visit))):

# go into subgraph
sub_prefix = f'{prefix}{submodule_name}.'

for i, o in enumerate(node.outputs()):
self._predictions[o] = {f'{sub_prefix}output_{i}'}, set()

with dot.subgraph(name=f'cluster_{name}') as sub_dot:
sub_dot.attr(label=label)
sub_module = module
for k in submodule_name.split('.'):
sub_module = getattr(sub_module, k)

self.make_graph(
sub_module,
dot=sub_dot,
parent_dot=dot,
prefix=sub_prefix,
input_preds=[self._predictions[i] for i in node_inputs[1:]],
classes_to_visit=classes_to_visit,
classes_found=classes_found,
)

else:
dot.node(name, label=label, shape='box')
for i in relevant_inputs:
pred, op = self._predictions[i]
self.make_edges(pred, prefix + i.debugName(), name, op, dot)
for o in node.outputs():
self._predictions[o] = {name}, set()

elif node.kind() == 'prim::CallFunction':
name = f'{prefix}.{node.output().debugName()}'
fun_name = self.get_function_name(node_inputs[0])
dot.node(name, label=fun_name, shape='box')
for i in relevant_inputs:
pred, op = self._predictions[i]
self.make_edges(pred, prefix + i.debugName(), name, op, dot)
for o in node.outputs():
self._predictions[o] = {name}, set()

else:
label = node.kind().split('::')[-1].rstrip('_')
pred, op = set(), set()
for i in relevant_inputs:
apred, aop = self._predictions[i]
pred |= apred
op |= aop

if node.kind() in self.absorbing_ops:
pred, op = set(), set()
elif (len(relevant_inputs) > 0
and len(relevant_outputs) > 0
and node.kind() not in self.unseen_ops):
op.add(label)
for o in node.outputs():
self._predictions[o] = pred, op

for i, o in enumerate(graph.outputs()):
name = f'{prefix}output_{i}'
dot.node(name, shape='ellipse')
pred, op = self._predictions[o]
self.make_edges(pred, f'input_{name}', name, op, dot)

def add_edge(self, dot, n1, n2):
if (n1, n2) not in self._seen_edges:
self._seen_edges.add((n1, n2))
dot.edge(n1, n2)

def make_edges(self, preds, input_name, name, op, edge_dot):
if len(op) > 0:
if input_name not in self._seen_input_names:
self._seen_input_names.add(input_name)
label_lines = [[]]
line_len = 0
for w in op:
if line_len >= 20:
label_lines.append([])
line_len = 0
label_lines[-1].append(w)
line_len += len(w) + 1

edge_dot.node(
input_name,
label='\n'.join([' '.join(w) for w in label_lines]),
shape='box',
style='rounded',
)
for p in preds:
self.add_edge(edge_dot, p, input_name)
self.add_edge(edge_dot, input_name, name)
else:
for p in preds:
self.add_edge(edge_dot, p, name)


def find_name(layer_input, self_input, suffix=None):
if layer_input == self_input:
return suffix
cur = layer_input.node().s('name')
if suffix is not None:
cur = f'{cur}.{suffix}'
of = next(layer_input.node().inputs())
return find_name(of, self_input, suffix=cur)


def is_relevant_type(t):
kind = t.kind()
if kind == 'TensorType':
return True
if kind in ('ListType', 'OptionalType'):
return is_relevant_type(t.getElementType())
if kind == 'TupleType':
return any([is_relevant_type(tt) for tt in t.elements()])
return False
63 changes: 63 additions & 0 deletions yolort/relaying/trace_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from typing import Dict, Tuple, Callable

import torch
from torch import nn, Tensor


def dict_to_tuple(out_dict: Dict[str, Tensor]) -> Tuple:
"""
Convert the model output dictionary to tuple format.
"""
if "masks" in out_dict.keys():
return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"]
return out_dict["boxes"], out_dict["scores"], out_dict["labels"]


class TraceWrapper(nn.Module):
"""
This is a wrapper for `torch.jit.trace`, as there are some scenarios
where `torch.jit.script` support is limited.
"""
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, x):
out = self.model(x)
return dict_to_tuple(out[0])


@torch.no_grad()
def get_trace_module(
model_func: Callable[..., nn.Module],
input_shape: Tuple[int, int] = (416, 416),
):
"""
Get the tarcing of a given model function.

Example:

>>> from yolort.models import yolov5s
>>> from yolort.relaying.trace_wrapper import get_trace_module
>>>
>>> model = yolov5s(pretrained=True)
>>> tracing_module = get_trace_module(model)
>>> print(tracing_module.code)
def forward(self,
x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
_0, _1, _2, = (self.model).forward(x, )
return (_0, _1, _2)

Args:
model_func (Callable): The model function to be traced.
input_shape (Tuple[int, int]): Shape size of the input image.
"""
model = TraceWrapper(model_func)
model.eval()

dummy_input = torch.rand(1, 3, *input_shape)
trace_module = torch.jit.trace(model, dummy_input)
trace_module.eval()

return trace_module
Loading