From df88b3f9c4ba4e1a846d7e3554b0abefe2e7ee05 Mon Sep 17 00:00:00 2001
From: robin Han
Date: Tue, 21 Jul 2020 17:55:47 +0800
Subject: [PATCH] Add pytorch2onnx support for DIM and GCA models (#105)

* add pytorch2onnx for DIM

* support converting to ONNX for GCA

* use forward_dummy for onnx export

* remove useless comment

* revert useless modification

* update according to the latest mmcv

* modify the comment

* add docstring

* Fix docstring dot

Co-authored-by: Jiamin <1052020748@qq.com>
---
 .../decoders/plain_decoder.py                 |  79 +++++++++-
 mmedit/models/common/gca_module.py            |   4 +-
 mmedit/models/mattors/dim.py                  |   7 +-
 mmedit/models/mattors/gca.py                  |   3 +-
 setup.cfg                                     |   2 +-
 tools/pytorch2onnx.py                         | 137 ++++++++++++++++++
 6 files changed, 224 insertions(+), 8 deletions(-)
 create mode 100644 tools/pytorch2onnx.py

diff --git a/mmedit/models/backbones/encoder_decoders/decoders/plain_decoder.py b/mmedit/models/backbones/encoder_decoders/decoders/plain_decoder.py
index 181c1182f8..28d1e70f9d 100644
--- a/mmedit/models/backbones/encoder_decoders/decoders/plain_decoder.py
+++ b/mmedit/models/backbones/encoder_decoders/decoders/plain_decoder.py
@@ -1,9 +1,86 @@
+import warnings
+
 import torch.nn as nn
+import torch.nn.functional as F
 from mmcv.cnn.utils.weight_init import xavier_init
+from torch.autograd import Function
+from torch.nn.modules.pooling import _MaxUnpoolNd
+from torch.nn.modules.utils import _pair
 
 from mmedit.models.registry import COMPONENTS
 
 
+class MaxUnpool2dop(Function):
+    """We wrap `torch.nn.functional.max_unpool2d` with an extra `symbolic`
+    method, which is needed when exporting to ONNX.
+
+    Users should not call this function directly.
+    """
+
+    @staticmethod
+    def forward(ctx, input, indices, kernel_size, stride, padding,
+                output_size):
+        """Forward function of MaxUnpool2dop.
+
+        Args:
+            input (Tensor): Tensor needed to upsample.
+            indices (Tensor): Indices output of the previous MaxPool.
+            kernel_size (Tuple): Size of the max pooling window.
+            stride (Tuple): Stride of the max pooling window.
+            padding (Tuple): Padding that was added to the input.
+            output_size (List or Tuple): The shape of output tensor.
+
+        Returns:
+            Tensor: Output tensor.
+        """
+        return F.max_unpool2d(input, indices, kernel_size, stride, padding,
+                              output_size)
+
+    @staticmethod
+    def symbolic(g, input, indices, kernel_size, stride, padding,
+                 output_size):
+        warnings.warn(
+            'The definitions of indices are different between PyTorch and '
+            'ONNX, so the outputs between PyTorch and ONNX may be different')
+        return g.op(
+            'MaxUnpool',
+            input,
+            indices,
+            kernel_shape_i=kernel_size,
+            strides_i=stride)
+
+
+class MaxUnpool2d(_MaxUnpoolNd):
+    """This module is modified from the PyTorch `MaxUnpool2d` module.
+
+    Args:
+        kernel_size (int or tuple): Size of the max pooling window.
+        stride (int or tuple): Stride of the max pooling window.
+            Default: None (set to `kernel_size` by default).
+        padding (int or tuple): Padding that is added to the input.
+            Default: 0.
+    """
+
+    def __init__(self, kernel_size, stride=None, padding=0):
+        super(MaxUnpool2d, self).__init__()
+        self.kernel_size = _pair(kernel_size)
+        self.stride = _pair(stride or kernel_size)
+        self.padding = _pair(padding)
+
+    def forward(self, input, indices, output_size=None):
+        """Forward function of MaxUnpool2d.
+
+        Args:
+            input (Tensor): Tensor needed to upsample.
+            indices (Tensor): Indices output of the previous MaxPool.
+            output_size (List or Tuple): The shape of output tensor.
+                Default: None.
+
+        Returns:
+            Tensor: Output tensor.
+ """ + return MaxUnpool2dop.apply(input, indices, self.kernel_size, + self.stride, self.padding, output_size) + + @COMPONENTS.register_module() class PlainDecoder(nn.Module): """Simple decoder from Deep Image Matting. @@ -25,7 +102,7 @@ def __init__(self, in_channels): self.deconv1 = nn.Conv2d(64, 1, kernel_size=5, padding=2) self.relu = nn.ReLU(inplace=True) - self.max_unpool2d = nn.MaxUnpool2d(kernel_size=2, stride=2) + self.max_unpool2d = MaxUnpool2d(kernel_size=2, stride=2) def init_weights(self): """Init weights for the module. diff --git a/mmedit/models/common/gca_module.py b/mmedit/models/common/gca_module.py index 6d6f145d8b..2c8ee69903 100644 --- a/mmedit/models/common/gca_module.py +++ b/mmedit/models/common/gca_module.py @@ -341,8 +341,10 @@ def pad(self, x, kernel_size, stride): def get_self_correlation_mask(self, img_feat): _, _, h, w = img_feat.shape + # As ONNX does not support dynamic num_classes, we have to convert it + # into an integer self_mask = F.one_hot( - torch.arange(h * w).view(h, w), num_classes=h * w) + torch.arange(h * w).view(h, w), num_classes=int(h * w)) self_mask = self_mask.permute(2, 0, 1).view(1, h * w, h, w) # use large negative value to mask out self-correlation before softmax self_mask = self_mask * self.penalty diff --git a/mmedit/models/mattors/dim.py b/mmedit/models/mattors/dim.py index a8f298ffe6..8cbb574f02 100644 --- a/mmedit/models/mattors/dim.py +++ b/mmedit/models/mattors/dim.py @@ -66,8 +66,9 @@ def _forward(self, x, refine): refine_input = torch.cat((x[:, :3, :, :], pred_alpha), 1) pred_refine = self.refiner(refine_input, raw_alpha) else: - pred_refine = None - + # As ONNX does not support NoneType for output, + # we choose to use zero tensor to represent None + pred_refine = torch.zeros([]) return pred_alpha, pred_refine def forward_dummy(self, inputs): @@ -143,7 +144,7 @@ def forward_test(self, if self.test_cfg.refine: pred_alpha = pred_refine - pred_alpha = pred_alpha.cpu().numpy().squeeze() + pred_alpha = pred_alpha.detach().cpu().numpy().squeeze() pred_alpha = self.restore_shape(pred_alpha, meta) eval_result = self.evaluate(pred_alpha, meta) diff --git a/mmedit/models/mattors/gca.py b/mmedit/models/mattors/gca.py index 0b52648d3b..d0c6cf137c 100644 --- a/mmedit/models/mattors/gca.py +++ b/mmedit/models/mattors/gca.py @@ -91,8 +91,7 @@ def forward_test(self, dict: Contains the predicted alpha and evaluation result. 
""" pred_alpha = self._forward(torch.cat((merged, trimap), 1)) - - pred_alpha = pred_alpha.cpu().numpy().squeeze() + pred_alpha = pred_alpha.detach().cpu().numpy().squeeze() pred_alpha = self.restore_shape(pred_alpha, meta) eval_result = self.evaluate(pred_alpha, meta) diff --git a/setup.cfg b/setup.cfg index fc1b712740..a214a1d264 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,6 +17,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = pkg_resources,setuptools known_first_party = mmedit -known_third_party =PIL,cv2,lmdb,mmcv,numpy,pytest,scipy,torch,torchvision +known_third_party =PIL,cv2,lmdb,mmcv,numpy,onnx,onnxruntime,pytest,scipy,torch,torchvision no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py new file mode 100644 index 0000000000..cd3adaf331 --- /dev/null +++ b/tools/pytorch2onnx.py @@ -0,0 +1,137 @@ +import argparse + +import mmcv +import numpy as np +import onnx +import onnxruntime as rt +import torch +from mmcv.onnx import register_extra_symbolics +from mmcv.runner import load_checkpoint + +from mmedit.datasets.pipelines import Compose +from mmedit.models import build_model + + +def pytorch2onnx(model, + input, + opset_version=11, + show=False, + output_file='tmp.onnx', + verify=False): + """Export Pytorch model to ONNX model and verify the outputs are same + between Pytorch and ONNX. + + Args: + model (nn.Module): Pytorch model we want to export. + input (dict): We need to use this input to execute the model. + opset_version (int): The onnx op version. Default: 11. + show (bool): Whether print the computation graph. Default: False. + output_file (string): The path to where we store the output ONNX model. + Default: `tmp.onnx`. + verify (bool): Whether compare the outputs between Pytorch and ONNX. + Default: False. 
+ """ + model.cpu().eval() + merged = input['merged'].unsqueeze(0) + trimap = input['trimap'].unsqueeze(0) + input = torch.cat((merged, trimap), 1) + model.forward = model.forward_dummy + # pytorch has some bug in pytorch1.3, we have to fix it + # by replacing these existing op + register_extra_symbolics(opset_version) + with torch.no_grad(): + torch.onnx.export( + model, + input, + output_file, + input_names=['cat_input'], + export_params=True, + keep_initializers_as_inputs=True, + verbose=show, + opset_version=opset_version) + print(f'Successfully exported ONNX model: {output_file}') + if verify: + # check by onnx + onnx_model = onnx.load(output_file) + onnx.checker.check_model(onnx_model) + + # get pytorch output, only concern pred_alpha + pytorch_result = model(input) + if isinstance(pytorch_result, (tuple, list)): + pytorch_result = pytorch_result[0] + pytorch_result = pytorch_result.detach().numpy() + # get onnx output + sess = rt.InferenceSession(output_file) + onnx_result = sess.run(None, { + 'cat_input': input.detach().numpy(), + }) + # only concern pred_alpha value + if isinstance(onnx_result, (tuple, list)): + onnx_result = onnx_result[0] + # check the numerical value + assert np.allclose( + pytorch_result, + onnx_result), 'The outputs are different between Pytorch and ONNX' + print('The numerical values are same between Pytorch and ONNX') + + +def parse_args(): + parser = argparse.ArgumentParser(description='Convert MMediting to ONNX') + parser.add_argument('config', help='test config file path') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument('img_path', help='path to input image file') + parser.add_argument('trimap_path', help='path to input trimap file') + parser.add_argument('--show', action='store_true', help='show onnx graph') + parser.add_argument('--output-file', type=str, default='tmp.onnx') + parser.add_argument('--opset-version', type=int, default=11) + parser.add_argument( + '--verify', + action='store_true', + help='verify the onnx model output against pytorch output') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + + assert args.opset_version == 11, 'MMEditing only support opset 11 now' + + config = mmcv.Config.fromfile(args.config) + config.model.pretrained = None + # ONNX does not support spectral norm + if hasattr(config.model.backbone.encoder, 'with_spectral_norm'): + config.model.backbone.encoder.with_spectral_norm = False + config.model.backbone.decoder.with_spectral_norm = False + config.test_cfg.metrics = None + + # build the model + model = build_model(config.model, test_cfg=config.test_cfg) + checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') + + # remove alpha from test_pipeline + keys_to_remove = ['alpha', 'ori_alpha'] + for key in keys_to_remove: + for pipeline in list(config.test_pipeline): + if 'key' in pipeline and key == pipeline['key']: + config.test_pipeline.remove(pipeline) + if 'keys' in pipeline and key in pipeline['keys']: + pipeline['keys'].remove(key) + if len(pipeline['keys']) == 0: + config.test_pipeline.remove(pipeline) + if 'meta_keys' in pipeline and key in pipeline['meta_keys']: + pipeline['meta_keys'].remove(key) + # build the data pipeline + test_pipeline = Compose(config.test_pipeline) + # prepare data + data = dict(merged_path=args.img_path, trimap_path=args.trimap_path) + data = test_pipeline(data) + + # conver model to onnx file + pytorch2onnx( + model, + data, + opset_version=args.opset_version, + show=args.show, + 
+        output_file=args.output_file,
+        verify=args.verify)
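
Note (editorial sketch, not part of the patch): since `MaxUnpool2dop.forward`
simply delegates to `F.max_unpool2d`, the custom `MaxUnpool2d` should behave
exactly like `nn.MaxUnpool2d` in eager mode; only the ONNX `symbolic` differs.
A quick eager-mode sanity check, assuming the import path added by this patch:

    import torch
    import torch.nn as nn

    from mmedit.models.backbones.encoder_decoders.decoders.plain_decoder \
        import MaxUnpool2d

    pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
    feat, indices = pool(torch.randn(1, 3, 8, 8))
    # outside of ONNX export the two unpool variants should agree exactly
    reference = nn.MaxUnpool2d(kernel_size=2, stride=2)(feat, indices)
    exportable = MaxUnpool2d(kernel_size=2, stride=2)(feat, indices)
    assert torch.equal(reference, exportable)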
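A second sketch (also editorial): running the exported model with
onnxruntime, mirroring the `verify` branch of `pytorch2onnx()` above. The
`${...}` placeholders are illustrative; the input's static shape is fixed by
the test pipeline at export time, and its 4 channels are the concatenated
merged RGB image and trimap, so we read the shape back from the session:

    # hypothetical invocation of the new script (all paths are placeholders):
    #   python tools/pytorch2onnx.py ${CONFIG_FILE} ${CHECKPOINT} \
    #       ${IMAGE_PATH} ${TRIMAP_PATH} --output-file tmp.onnx --verify
    import numpy as np
    import onnxruntime as rt

    sess = rt.InferenceSession('tmp.onnx')
    # the single input is named 'cat_input', matching input_names at export
    inp = sess.get_inputs()[0]
    dummy = np.random.rand(*inp.shape).astype(np.float32)
    pred_alpha = sess.run(None, {inp.name: dummy})[0]
    print(inp.name, inp.shape, '->', pred_alpha.shape)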