Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add dynamic export and visualize to pytorch2onnx #463

Merged
merged 9 commits into from
Apr 12, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions docs/useful_tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,32 @@ The final output filename will be `psp_r50_512x1024_40ki_cityscapes-{hash id}.pt

We provide a script to convert model to [ONNX](https://github.com/onnx/onnx) format. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron). Besides, we also support comparing the output results between Pytorch and ONNX model.

```shell
python tools/pytorch2onnx.py ${CONFIG_FILE} --checkpoint ${CHECKPOINT_FILE} --output-file ${ONNX_FILE} [--shape ${INPUT_SHAPE} --verify]
```bash
python tools/pytorch2onnx.py \
${CONFIG_FILE} \
--checkpoint ${CHECKPOINT_FILE} \
--output-file ${ONNX_FILE} \
--input-img ${INPUT_IMG} \
--shape ${INPUT_SHAPE} \
--show \
--verify \
--dynamic-export \
--cfg-options \
model.test_cfg.mode="whole"
```

Description of arguments:

- `config` : The path of a model config file.
- `--checkpoint` : The path of a model checkpoint file.
- `--output-file`: The path of output ONNX model. If not specified, it will be set to `tmp.onnx`.
- `--input-img` : The path of an input image for conversion and visualize.
- `--shape`: The height and width of input tensor to the model. If not specified, it will be set to `256 256`.
- `--show`: Determines whether to print the architecture of the exported model. If not specified, it will be set to `False`.
- `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`.
- `dynamic-export`: Determines whether to export ONNX model with dynamic input and output shapes. If not specified, it will be set to `False`.
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
- `cfg-options`:Update config options.
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved

**Note**: This tool is still experimental. Some customized operators are not supported for now.

## Miscellaneous
Expand Down
12 changes: 10 additions & 2 deletions mmseg/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def show_result_pyplot(model,
result,
palette=None,
fig_size=(15, 10),
opacity=0.5):
opacity=0.5,
title='',
block=True):
"""Visualize the segmentation results on the image.

Args:
Expand All @@ -117,11 +119,17 @@ def show_result_pyplot(model,
opacity(float): Opacity of painted segmentation map.
Default 0.5.
Must be in (0, 1] range.
title (str): The title of pyplot figure.
Default is ''.
block (bool): Whether to block the pyplot figure.
Default is False.
"""
if hasattr(model, 'module'):
model = model.module
img = model.show_result(
img, result, palette=palette, show=False, opacity=opacity)
plt.figure(figsize=fig_size)
plt.imshow(mmcv.bgr2rgb(img))
plt.show()
plt.title(title)
plt.tight_layout()
plt.show(block=block)
7 changes: 6 additions & 1 deletion mmseg/models/segmentors/encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,14 @@ def whole_inference(self, img, img_meta, rescale):

seg_logit = self.encode_decode(img, img_meta)
if rescale:
# support dynamic shape for onnx
if torch.onnx.is_in_onnx_export():
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
new_size = img.shape[2:]
else:
new_size = img_meta[0]['ori_shape'][:2]
seg_logit = resize(
seg_logit,
size=img_meta[0]['ori_shape'][:2],
size=new_size,
mode='bilinear',
align_corners=self.align_corners,
warning=False)
Expand Down
3 changes: 0 additions & 3 deletions mmseg/ops/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

Expand All @@ -24,8 +23,6 @@ def resize(input,
'the output would more aligned if '
f'input size {(input_h, input_w)} is `x+1` and '
f'out size {(output_h, output_w)} is `nx+1`')
if isinstance(size, torch.Size):
size = tuple(int(x) for x in size)
return F.interpolate(input, size, scale_factor, mode, align_corners)


Expand Down
172 changes: 156 additions & 16 deletions tools/pytorch2onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@
import torch
import torch._C
import torch.serialization
from mmcv import DictAction
from mmcv.onnx import register_extra_symbolics
from mmcv.runner import load_checkpoint
from torch import nn

from mmseg.apis import show_result_pyplot
from mmseg.apis.inference import LoadImage
from mmseg.datasets.pipelines import Compose
from mmseg.models import build_segmentor

torch.manual_seed(3)
Expand Down Expand Up @@ -67,25 +71,60 @@ def _demo_mm_inputs(input_shape, num_classes):
return mm_inputs


def _prepare_input_img(img_path, test_pipeline, shape=None):
# build the data pipeline
if shape is not None:
test_pipeline[1]['img_scale'] = shape
test_pipeline[1]['transforms'][0]['keep_ratio'] = False
test_pipeline = [LoadImage()] + test_pipeline[1:]
test_pipeline = Compose(test_pipeline)
# prepare data
data = dict(img=img_path)
data = test_pipeline(data)
imgs = data['img']
img_metas = [i.data for i in data['img_metas']]

mm_inputs = {'imgs': imgs, 'img_metas': img_metas}

return mm_inputs


def _update_input_img(img_list, img_meta_list):
N, C, H, W = img_list[0].shape
img_meta = img_meta_list[0][0]
new_img_meta_list = [[{
'img_shape': (H, W, C),
'ori_shape': (H, W, C),
'pad_shape': (H, W, C),
'filename': img_meta['filename'],
'scale_factor': 1.,
'flip': False,
} for _ in range(N)]]

return img_list, new_img_meta_list


def pytorch2onnx(model,
input_shape,
mm_inputs,
opset_version=11,
show=False,
output_file='tmp.onnx',
verify=False):
verify=False,
dynamic_export=False):
"""Export Pytorch model to ONNX model and verify the outputs are same
between Pytorch and ONNX.

Args:
model (nn.Module): Pytorch model we want to export.
input_shape (tuple): Use this input shape to construct
the corresponding dummy input and execute the model.
mm_inputs (dict): Contain the input tensors and img_metas infomation.
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
opset_version (int): The onnx op version. Default: 11.
show (bool): Whether print the computation graph. Default: False.
output_file (string): The path to where we store the output ONNX model.
Default: `tmp.onnx`.
verify (bool): Whether compare the outputs between Pytorch and ONNX.
Default: False.
dynamic_export (bool): Whether to export ONNX with dynamic axis.
Default: False.
"""
model.cpu().eval()

Expand All @@ -94,28 +133,45 @@ def pytorch2onnx(model,
else:
num_classes = model.decode_head.num_classes

mm_inputs = _demo_mm_inputs(input_shape, num_classes)

imgs = mm_inputs.pop('imgs')
img_metas = mm_inputs.pop('img_metas')
ori_shape = img_metas[0]['ori_shape']

img_list = [img[None, :] for img in imgs]
img_meta_list = [[img_meta] for img_meta in img_metas]
img_list, img_meta_list = _update_input_img(img_list, img_meta_list)

# replace original forward function
origin_forward = model.forward
model.forward = partial(
model.forward, img_metas=img_meta_list, return_loss=False)
dynamic_axes = None
if dynamic_export:
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
dynamic_axes = {
'input': {
0: 'batch',
2: 'height',
3: 'width'
},
'output': {
1: 'batch',
2: 'height',
3: 'width'
}
}

register_extra_symbolics(opset_version)
with torch.no_grad():
torch.onnx.export(
model, (img_list, ),
output_file,
input_names=['input'],
output_names=['output'],
export_params=True,
keep_initializers_as_inputs=True,
keep_initializers_as_inputs=False,
verbose=show,
opset_version=opset_version)
opset_version=opset_version,
dynamic_axes=dynamic_axes)
print(f'Successfully exported ONNX model: {output_file}')
model.forward = origin_forward

Expand All @@ -125,9 +181,28 @@ def pytorch2onnx(model,
onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model)

if dynamic_export:
# scale image for dynamic shape test
img_list = [
nn.functional.interpolate(_, scale_factor=1.5)
for _ in img_list
]
# concate flip image for batch test
flip_img_list = [_.flip(-1) for _ in img_list]
img_list = [
torch.cat((ori_img, flip_img), 0)
for ori_img, flip_img in zip(img_list, flip_img_list)
]

# update img_meta
img_list, img_meta_list = _update_input_img(
img_list, img_meta_list)

# check the numerical value
# get pytorch output
pytorch_result = model(img_list, img_meta_list, return_loss=False)[0]
with torch.no_grad():
pytorch_result = model(img_list, img_meta_list, return_loss=False)
pytorch_result = np.stack(pytorch_result, 0)

# get onnx output
input_all = [node.name for node in onnx_model.graph.input]
Expand All @@ -138,17 +213,51 @@ def pytorch2onnx(model,
assert (len(net_feed_input) == 1)
sess = rt.InferenceSession(output_file)
onnx_result = sess.run(
None, {net_feed_input[0]: img_list[0].detach().numpy()})[0]
if not np.allclose(pytorch_result, onnx_result):
raise ValueError(
'The outputs are different between Pytorch and ONNX')
None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0]
# show segmentation results
if show:
import cv2
import os.path as osp
img = img_meta_list[0][0]['filename']
if not osp.exists(img):
img = imgs[0][:3, ...].permute(1, 2, 0) * 255
img = img.detach().numpy().astype(np.uint8)
# resize onnx_result to ori_shape
onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
(ori_shape[1], ori_shape[0]))
show_result_pyplot(
model,
img, (onnx_result_, ),
palette=model.PALETTE,
block=False,
title='ONNXRuntime',
opacity=0.5)

# resize pytorch_result to ori_shape
pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8),
(ori_shape[1], ori_shape[0]))
show_result_pyplot(
model,
img, (pytorch_result_, ),
title='PyTorch',
palette=model.PALETTE,
opacity=0.5)
# compare results
np.testing.assert_allclose(
pytorch_result.astype(np.float32) / num_classes,
onnx_result.astype(np.float32) / num_classes,
rtol=1e-5,
atol=1e-5,
err_msg='The outputs are different between Pytorch and ONNX')
print('The outputs are same between Pytorch and ONNX')


def parse_args():
parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
parser.add_argument('config', help='test config file path')
parser.add_argument('--checkpoint', help='checkpoint file', default=None)
parser.add_argument(
'--input-img', type=str, help='Images for input', default=None)
parser.add_argument('--show', action='store_true', help='show onnx graph')
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
parser.add_argument(
'--verify', action='store_true', help='verify the onnx model')
Expand All @@ -160,6 +269,20 @@ def parse_args():
nargs='+',
default=[256, 256],
help='input image size')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
'--dynamic-export',
action='store_true',
help='Wether to export onnx with dynamic axis.')
args = parser.parse_args()
return args

Expand All @@ -178,6 +301,8 @@ def parse_args():
raise ValueError('invalid input shape')

cfg = mmcv.Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
cfg.model.pretrained = None

# build the model and load checkpoint
Expand All @@ -188,13 +313,28 @@ def parse_args():
segmentor = _convert_batchnorm(segmentor)

if args.checkpoint:
load_checkpoint(segmentor, args.checkpoint, map_location='cpu')
checkpoint = load_checkpoint(
segmentor, args.checkpoint, map_location='cpu')
segmentor.CLASSES = checkpoint['meta']['CLASSES']
segmentor.PALETTE = checkpoint['meta']['PALETTE']

# read input or create dummpy input
if args.input_img is not None:
mm_inputs = _prepare_input_img(args.input_img, cfg.data.test.pipeline,
(input_shape[3], input_shape[2]))
else:
if isinstance(segmentor.decode_head, nn.ModuleList):
num_classes = segmentor.decode_head[-1].num_classes
else:
num_classes = segmentor.decode_head.num_classes
mm_inputs = _demo_mm_inputs(input_shape, num_classes)

# conver model to onnx file
pytorch2onnx(
segmentor,
input_shape,
mm_inputs,
opset_version=args.opset_version,
show=args.show,
output_file=args.output_file,
verify=args.verify)
verify=args.verify,
dynamic_export=args.dynamic_export)