Skip to content

Commit

Permalink
Pytorch2onnx which can be used for DIM and GCA model (#105)
Browse files Browse the repository at this point in the history
* add pytorch2onnx for DIM

* support convert to ONNX for GCA

* use forward_dummy for onnx export

* remove useless comment

* retrieve useless modify

* Updating according to the latest mmcv

* modify the comment

* add docstring

* Fix docstring dot

Co-authored-by: Jiamin <[email protected]>
  • Loading branch information
drcut and hejm37 authored Jul 21, 2020
1 parent 68b7d60 commit df88b3f
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,86 @@
import warnings

import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.utils.weight_init import xavier_init
from torch.autograd import Function
from torch.nn.modules.pooling import _MaxUnpoolNd
from torch.nn.modules.utils import _pair

from mmedit.models.registry import COMPONENTS


class MaxUnpool2dop(Function):
"""We warp the `torch.nn.functional.max_unpool2d`
with an extra `symbolic` method, which is needed while exporting to ONNX.
Users should not call this function directly.
"""

@staticmethod
def forward(ctx, input, indices, kernel_size, stride, padding,
output_size):
"""Forward function of MaxUnpool2dop.
Args:
input (Tensor): Tensor needed to upsample.
indices (Tensor): Indices output of the previous MaxPool.
kernel_size (Tuple): Size of the max pooling window.
stride (Tuple): Stride of the max pooling window.
padding (Tuple): Padding that was added to the input.
output_size (List or Tuple): The shape of output tensor.
Returns:
Tensor: Output tensor.
"""
return F.max_unpool2d(input, indices, kernel_size, stride, padding,
output_size)

@staticmethod
def symbolic(g, input, indices, kernel_size, stride, padding, output_size):
warnings.warn(
'The definitions of indices are different between Pytorch and ONNX'
', so the outputs between Pytorch and ONNX maybe different')
return g.op(
'MaxUnpool',
input,
indices,
kernel_shape_i=kernel_size,
strides_i=stride)


class MaxUnpool2d(_MaxUnpoolNd):
"""This module is modified from Pytorch `MaxUnpool2d` module.
Args:
kernel_size (int or tuple): Size of the max pooling window.
stride (int or tuple): Stride of the max pooling window.
Default: None (It is set to `kernel_size` by default).
padding (int or tuple): Padding that is added to the input.
Default: 0.
"""

def __init__(self, kernel_size, stride=None, padding=0):
super(MaxUnpool2d, self).__init__()
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride or kernel_size)
self.padding = _pair(padding)

def forward(self, input, indices, output_size=None):
"""Forward function of MaxUnpool2d.
Args:
input (Tensor): Tensor needed to upsample.
indices (Tensor): Indices output of the previous MaxPool.
output_size (List or Tuple): The shape of output tensor.
Default: None.
Returns:
Tensor: Output tensor.
"""
return MaxUnpool2dop.apply(input, indices, self.kernel_size,
self.stride, self.padding, output_size)


@COMPONENTS.register_module()
class PlainDecoder(nn.Module):
"""Simple decoder from Deep Image Matting.
Expand All @@ -25,7 +102,7 @@ def __init__(self, in_channels):
self.deconv1 = nn.Conv2d(64, 1, kernel_size=5, padding=2)

self.relu = nn.ReLU(inplace=True)
self.max_unpool2d = nn.MaxUnpool2d(kernel_size=2, stride=2)
self.max_unpool2d = MaxUnpool2d(kernel_size=2, stride=2)

def init_weights(self):
"""Init weights for the module.
Expand Down
4 changes: 3 additions & 1 deletion mmedit/models/common/gca_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,10 @@ def pad(self, x, kernel_size, stride):

def get_self_correlation_mask(self, img_feat):
_, _, h, w = img_feat.shape
# As ONNX does not support dynamic num_classes, we have to convert it
# into an integer
self_mask = F.one_hot(
torch.arange(h * w).view(h, w), num_classes=h * w)
torch.arange(h * w).view(h, w), num_classes=int(h * w))
self_mask = self_mask.permute(2, 0, 1).view(1, h * w, h, w)
# use large negative value to mask out self-correlation before softmax
self_mask = self_mask * self.penalty
Expand Down
7 changes: 4 additions & 3 deletions mmedit/models/mattors/dim.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ def _forward(self, x, refine):
refine_input = torch.cat((x[:, :3, :, :], pred_alpha), 1)
pred_refine = self.refiner(refine_input, raw_alpha)
else:
pred_refine = None

# As ONNX does not support NoneType for output,
# we choose to use zero tensor to represent None
pred_refine = torch.zeros([])
return pred_alpha, pred_refine

def forward_dummy(self, inputs):
Expand Down Expand Up @@ -143,7 +144,7 @@ def forward_test(self,
if self.test_cfg.refine:
pred_alpha = pred_refine

pred_alpha = pred_alpha.cpu().numpy().squeeze()
pred_alpha = pred_alpha.detach().cpu().numpy().squeeze()
pred_alpha = self.restore_shape(pred_alpha, meta)
eval_result = self.evaluate(pred_alpha, meta)

Expand Down
3 changes: 1 addition & 2 deletions mmedit/models/mattors/gca.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ def forward_test(self,
dict: Contains the predicted alpha and evaluation result.
"""
pred_alpha = self._forward(torch.cat((merged, trimap), 1))

pred_alpha = pred_alpha.cpu().numpy().squeeze()
pred_alpha = pred_alpha.detach().cpu().numpy().squeeze()
pred_alpha = self.restore_shape(pred_alpha, meta)
eval_result = self.evaluate(pred_alpha, meta)

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = mmedit
known_third_party =PIL,cv2,lmdb,mmcv,numpy,pytest,scipy,torch,torchvision
known_third_party =PIL,cv2,lmdb,mmcv,numpy,onnx,onnxruntime,pytest,scipy,torch,torchvision
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
137 changes: 137 additions & 0 deletions tools/pytorch2onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import argparse

import mmcv
import numpy as np
import onnx
import onnxruntime as rt
import torch
from mmcv.onnx import register_extra_symbolics
from mmcv.runner import load_checkpoint

from mmedit.datasets.pipelines import Compose
from mmedit.models import build_model


def pytorch2onnx(model,
input,
opset_version=11,
show=False,
output_file='tmp.onnx',
verify=False):
"""Export Pytorch model to ONNX model and verify the outputs are same
between Pytorch and ONNX.
Args:
model (nn.Module): Pytorch model we want to export.
input (dict): We need to use this input to execute the model.
opset_version (int): The onnx op version. Default: 11.
show (bool): Whether print the computation graph. Default: False.
output_file (string): The path to where we store the output ONNX model.
Default: `tmp.onnx`.
verify (bool): Whether compare the outputs between Pytorch and ONNX.
Default: False.
"""
model.cpu().eval()
merged = input['merged'].unsqueeze(0)
trimap = input['trimap'].unsqueeze(0)
input = torch.cat((merged, trimap), 1)
model.forward = model.forward_dummy
# pytorch has some bug in pytorch1.3, we have to fix it
# by replacing these existing op
register_extra_symbolics(opset_version)
with torch.no_grad():
torch.onnx.export(
model,
input,
output_file,
input_names=['cat_input'],
export_params=True,
keep_initializers_as_inputs=True,
verbose=show,
opset_version=opset_version)
print(f'Successfully exported ONNX model: {output_file}')
if verify:
# check by onnx
onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model)

# get pytorch output, only concern pred_alpha
pytorch_result = model(input)
if isinstance(pytorch_result, (tuple, list)):
pytorch_result = pytorch_result[0]
pytorch_result = pytorch_result.detach().numpy()
# get onnx output
sess = rt.InferenceSession(output_file)
onnx_result = sess.run(None, {
'cat_input': input.detach().numpy(),
})
# only concern pred_alpha value
if isinstance(onnx_result, (tuple, list)):
onnx_result = onnx_result[0]
# check the numerical value
assert np.allclose(
pytorch_result,
onnx_result), 'The outputs are different between Pytorch and ONNX'
print('The numerical values are same between Pytorch and ONNX')


def parse_args():
parser = argparse.ArgumentParser(description='Convert MMediting to ONNX')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument('img_path', help='path to input image file')
parser.add_argument('trimap_path', help='path to input trimap file')
parser.add_argument('--show', action='store_true', help='show onnx graph')
parser.add_argument('--output-file', type=str, default='tmp.onnx')
parser.add_argument('--opset-version', type=int, default=11)
parser.add_argument(
'--verify',
action='store_true',
help='verify the onnx model output against pytorch output')
args = parser.parse_args()
return args


if __name__ == '__main__':
args = parse_args()

assert args.opset_version == 11, 'MMEditing only support opset 11 now'

config = mmcv.Config.fromfile(args.config)
config.model.pretrained = None
# ONNX does not support spectral norm
if hasattr(config.model.backbone.encoder, 'with_spectral_norm'):
config.model.backbone.encoder.with_spectral_norm = False
config.model.backbone.decoder.with_spectral_norm = False
config.test_cfg.metrics = None

# build the model
model = build_model(config.model, test_cfg=config.test_cfg)
checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')

# remove alpha from test_pipeline
keys_to_remove = ['alpha', 'ori_alpha']
for key in keys_to_remove:
for pipeline in list(config.test_pipeline):
if 'key' in pipeline and key == pipeline['key']:
config.test_pipeline.remove(pipeline)
if 'keys' in pipeline and key in pipeline['keys']:
pipeline['keys'].remove(key)
if len(pipeline['keys']) == 0:
config.test_pipeline.remove(pipeline)
if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
pipeline['meta_keys'].remove(key)
# build the data pipeline
test_pipeline = Compose(config.test_pipeline)
# prepare data
data = dict(merged_path=args.img_path, trimap_path=args.trimap_path)
data = test_pipeline(data)

# conver model to onnx file
pytorch2onnx(
model,
data,
opset_version=args.opset_version,
show=args.show,
output_file=args.output_file,
verify=args.verify)

0 comments on commit df88b3f

Please sign in to comment.