Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pytorch2onnx which can be used for DIM and GCA model #105

Merged
merged 9 commits into from
Jul 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,86 @@
import warnings

import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.utils.weight_init import xavier_init
from torch.autograd import Function
from torch.nn.modules.pooling import _MaxUnpoolNd
from torch.nn.modules.utils import _pair

from mmedit.models.registry import COMPONENTS


class MaxUnpool2dop(Function):
"""We warp the `torch.nn.functional.max_unpool2d`
with an extra `symbolic` method, which is needed while exporting to ONNX.
Users should not call this function directly.
"""

hejm37 marked this conversation as resolved.
Show resolved Hide resolved
@staticmethod
def forward(ctx, input, indices, kernel_size, stride, padding,
output_size):
hejm37 marked this conversation as resolved.
Show resolved Hide resolved
"""Forward function of MaxUnpool2dop.

Args:
input (Tensor): Tensor needed to upsample.
indices (Tensor): Indices output of the previous MaxPool.
kernel_size (Tuple): Size of the max pooling window.
stride (Tuple): Stride of the max pooling window.
padding (Tuple): Padding that was added to the input.
output_size (List or Tuple): The shape of output tensor.

Returns:
Tensor: Output tensor.
"""
return F.max_unpool2d(input, indices, kernel_size, stride, padding,
output_size)

@staticmethod
def symbolic(g, input, indices, kernel_size, stride, padding, output_size):
warnings.warn(
hejm37 marked this conversation as resolved.
Show resolved Hide resolved
'The definitions of indices are different between Pytorch and ONNX'
', so the outputs between Pytorch and ONNX maybe different')
return g.op(
'MaxUnpool',
input,
indices,
kernel_shape_i=kernel_size,
strides_i=stride)


class MaxUnpool2d(_MaxUnpoolNd):
"""This module is modified from Pytorch `MaxUnpool2d` module.

Args:
kernel_size (int or tuple): Size of the max pooling window.
stride (int or tuple): Stride of the max pooling window.
Default: None (It is set to `kernel_size` by default).
padding (int or tuple): Padding that is added to the input.
hejm37 marked this conversation as resolved.
Show resolved Hide resolved
Default: 0.
"""

def __init__(self, kernel_size, stride=None, padding=0):
super(MaxUnpool2d, self).__init__()
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride or kernel_size)
self.padding = _pair(padding)

def forward(self, input, indices, output_size=None):
"""Forward function of MaxUnpool2d.

Args:
input (Tensor): Tensor needed to upsample.
indices (Tensor): Indices output of the previous MaxPool.
output_size (List or Tuple): The shape of output tensor.
Default: None.

Returns:
Tensor: Output tensor.
"""
return MaxUnpool2dop.apply(input, indices, self.kernel_size,
self.stride, self.padding, output_size)


@COMPONENTS.register_module()
class PlainDecoder(nn.Module):
"""Simple decoder from Deep Image Matting.
Expand All @@ -25,7 +102,7 @@ def __init__(self, in_channels):
self.deconv1 = nn.Conv2d(64, 1, kernel_size=5, padding=2)

self.relu = nn.ReLU(inplace=True)
self.max_unpool2d = nn.MaxUnpool2d(kernel_size=2, stride=2)
self.max_unpool2d = MaxUnpool2d(kernel_size=2, stride=2)

def init_weights(self):
"""Init weights for the module.
Expand Down
4 changes: 3 additions & 1 deletion mmedit/models/common/gca_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,10 @@ def pad(self, x, kernel_size, stride):

def get_self_correlation_mask(self, img_feat):
_, _, h, w = img_feat.shape
# As ONNX does not support dynamic num_classes, we have to convert it
# into an integer
self_mask = F.one_hot(
torch.arange(h * w).view(h, w), num_classes=h * w)
torch.arange(h * w).view(h, w), num_classes=int(h * w))
self_mask = self_mask.permute(2, 0, 1).view(1, h * w, h, w)
# use large negative value to mask out self-correlation before softmax
self_mask = self_mask * self.penalty
Expand Down
7 changes: 4 additions & 3 deletions mmedit/models/mattors/dim.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ def _forward(self, x, refine):
refine_input = torch.cat((x[:, :3, :, :], pred_alpha), 1)
pred_refine = self.refiner(refine_input, raw_alpha)
else:
pred_refine = None

hejm37 marked this conversation as resolved.
Show resolved Hide resolved
# As ONNX does not support NoneType for output,
# we choose to use zero tensor to represent None
pred_refine = torch.zeros([])
return pred_alpha, pred_refine

def forward_dummy(self, inputs):
Expand Down Expand Up @@ -143,7 +144,7 @@ def forward_test(self,
if self.test_cfg.refine:
pred_alpha = pred_refine

pred_alpha = pred_alpha.cpu().numpy().squeeze()
pred_alpha = pred_alpha.detach().cpu().numpy().squeeze()
pred_alpha = self.restore_shape(pred_alpha, meta)
eval_result = self.evaluate(pred_alpha, meta)

Expand Down
3 changes: 1 addition & 2 deletions mmedit/models/mattors/gca.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ def forward_test(self,
dict: Contains the predicted alpha and evaluation result.
"""
pred_alpha = self._forward(torch.cat((merged, trimap), 1))

pred_alpha = pred_alpha.cpu().numpy().squeeze()
pred_alpha = pred_alpha.detach().cpu().numpy().squeeze()
pred_alpha = self.restore_shape(pred_alpha, meta)
eval_result = self.evaluate(pred_alpha, meta)

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = mmedit
known_third_party =PIL,cv2,lmdb,mmcv,numpy,pytest,scipy,torch,torchvision
known_third_party =PIL,cv2,lmdb,mmcv,numpy,onnx,onnxruntime,pytest,scipy,torch,torchvision
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
137 changes: 137 additions & 0 deletions tools/pytorch2onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import argparse

import mmcv
import numpy as np
import onnx
import onnxruntime as rt
import torch
from mmcv.onnx import register_extra_symbolics
from mmcv.runner import load_checkpoint

from mmedit.datasets.pipelines import Compose
from mmedit.models import build_model


def pytorch2onnx(model,
input,
opset_version=11,
show=False,
output_file='tmp.onnx',
verify=False):
"""Export Pytorch model to ONNX model and verify the outputs are same
between Pytorch and ONNX.

Args:
model (nn.Module): Pytorch model we want to export.
input (dict): We need to use this input to execute the model.
opset_version (int): The onnx op version. Default: 11.
show (bool): Whether print the computation graph. Default: False.
output_file (string): The path to where we store the output ONNX model.
Default: `tmp.onnx`.
verify (bool): Whether compare the outputs between Pytorch and ONNX.
Default: False.
"""
model.cpu().eval()
hejm37 marked this conversation as resolved.
Show resolved Hide resolved
merged = input['merged'].unsqueeze(0)
trimap = input['trimap'].unsqueeze(0)
input = torch.cat((merged, trimap), 1)
model.forward = model.forward_dummy
# pytorch has some bug in pytorch1.3, we have to fix it
# by replacing these existing op
register_extra_symbolics(opset_version)
with torch.no_grad():
torch.onnx.export(
model,
input,
output_file,
input_names=['cat_input'],
export_params=True,
keep_initializers_as_inputs=True,
verbose=show,
opset_version=opset_version)
print(f'Successfully exported ONNX model: {output_file}')
if verify:
# check by onnx
onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model)

# get pytorch output, only concern pred_alpha
pytorch_result = model(input)
if isinstance(pytorch_result, (tuple, list)):
pytorch_result = pytorch_result[0]
pytorch_result = pytorch_result.detach().numpy()
# get onnx output
sess = rt.InferenceSession(output_file)
onnx_result = sess.run(None, {
'cat_input': input.detach().numpy(),
})
# only concern pred_alpha value
if isinstance(onnx_result, (tuple, list)):
onnx_result = onnx_result[0]
# check the numerical value
assert np.allclose(
pytorch_result,
onnx_result), 'The outputs are different between Pytorch and ONNX'
print('The numerical values are same between Pytorch and ONNX')


def parse_args():
parser = argparse.ArgumentParser(description='Convert MMediting to ONNX')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument('img_path', help='path to input image file')
parser.add_argument('trimap_path', help='path to input trimap file')
parser.add_argument('--show', action='store_true', help='show onnx graph')
parser.add_argument('--output-file', type=str, default='tmp.onnx')
parser.add_argument('--opset-version', type=int, default=11)
parser.add_argument(
'--verify',
action='store_true',
help='verify the onnx model output against pytorch output')
args = parser.parse_args()
return args


if __name__ == '__main__':
args = parse_args()

assert args.opset_version == 11, 'MMEditing only support opset 11 now'

config = mmcv.Config.fromfile(args.config)
config.model.pretrained = None
# ONNX does not support spectral norm
if hasattr(config.model.backbone.encoder, 'with_spectral_norm'):
config.model.backbone.encoder.with_spectral_norm = False
config.model.backbone.decoder.with_spectral_norm = False
config.test_cfg.metrics = None

# build the model
model = build_model(config.model, test_cfg=config.test_cfg)
checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')

# remove alpha from test_pipeline
keys_to_remove = ['alpha', 'ori_alpha']
for key in keys_to_remove:
for pipeline in list(config.test_pipeline):
if 'key' in pipeline and key == pipeline['key']:
config.test_pipeline.remove(pipeline)
if 'keys' in pipeline and key in pipeline['keys']:
pipeline['keys'].remove(key)
if len(pipeline['keys']) == 0:
config.test_pipeline.remove(pipeline)
if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
pipeline['meta_keys'].remove(key)
# build the data pipeline
test_pipeline = Compose(config.test_pipeline)
# prepare data
data = dict(merged_path=args.img_path, trimap_path=args.trimap_path)
data = test_pipeline(data)

# conver model to onnx file
pytorch2onnx(
model,
data,
opset_version=args.opset_version,
show=args.show,
output_file=args.output_file,
verify=args.verify)