From 74171646733dd5871a29832068898a1f4e4ccde4 Mon Sep 17 00:00:00 2001 From: hanruobing Date: Tue, 21 Jul 2020 13:58:35 +0800 Subject: [PATCH] add docstring --- .../decoders/plain_decoder.py | 37 +++++++++++++++++++ tools/pytorch2onnx.py | 17 ++++++++- 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/mmedit/models/backbones/encoder_decoders/decoders/plain_decoder.py b/mmedit/models/backbones/encoder_decoders/decoders/plain_decoder.py index 87c604ece3..b16ca0c38b 100644 --- a/mmedit/models/backbones/encoder_decoders/decoders/plain_decoder.py +++ b/mmedit/models/backbones/encoder_decoders/decoders/plain_decoder.py @@ -11,10 +11,27 @@ class MaxUnpool2dop(Function): + """We warp the `torch.nn.functional.max_unpool2d` + with an extra `symbolic` method, which is needed while exporting to ONNX. + Users should not call this function directly. + """ @staticmethod def forward(ctx, input, indices, kernel_size, stride, padding, output_size): + """Forward function of MaxUnpool2dop. + + Args: + input (Tensor): Tensor needed to upsample. + indices (Tensor): Indices output of the previous MaxPool. + kernel_size (Tuple): Size of the max pooling window. + stride (Tuple): Stride of the max pooling window. + padding (Tuple): Padding that was added to the input. + output_size (List or Tuple): The shape of output tensor. + + Returns: + Tensor: Output tensor. + """ return F.max_unpool2d(input, indices, kernel_size, stride, padding, output_size) @@ -32,6 +49,15 @@ def symbolic(g, input, indices, kernel_size, stride, padding, output_size): class MaxUnpool2d(_MaxUnpoolNd): + """This module is modified from Pytorch `MaxUnpool2d` module. + + Args: + kernel_size (int or tuple): Size of the max pooling window. + stride (int or tuple): Stride of the max pooling window. + Default: None (It is set to `kernel_size` by default). + padding (int or tuple): Padding that is added to the input. + Default: 0 + """ def __init__(self, kernel_size, stride=None, padding=0): super(MaxUnpool2d, self).__init__() @@ -40,6 +66,17 @@ def __init__(self, kernel_size, stride=None, padding=0): self.padding = _pair(padding) def forward(self, input, indices, output_size=None): + """Forward function of MaxUnpool2d. + + Args: + input (Tensor): Tensor needed to upsample. + indices (Tensor): Indices output of the previous MaxPool. + output_size (List or Tuple): The shape of output tensor. + Default: None. + + Returns: + Tensor: Output tensor. + """ return MaxUnpool2dop.apply(input, indices, self.kernel_size, self.stride, self.padding, output_size) diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index 01cc007667..f2a9ce8daa 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -17,6 +17,19 @@ def pytorch2onnx(model, show=False, output_file='tmp.onnx', verify=False): + """Export Pytorch model to ONNX model and verify the outputs are same + between Pytorch and ONNX. + + Args: + model (nn.Module): Pytorch model we want to export. + input (dict): We need to use this input to execute the model. + opset_version (int): The onnx op version. Default: 11. + show (bool): Whether print the computation graph. Default: False. + output_file (string): The path to where we store the output ONNX model. + Default: `tmp.onnx`. + verify (bool): Whether compare the outputs between Pytorch and ONNX. + Default: False. + """ model.cpu().eval() merged = input['merged'].unsqueeze(0) trimap = input['trimap'].unsqueeze(0) @@ -68,8 +81,8 @@ def parse_args(): parser.add_argument('img_path', help='path to input image file') parser.add_argument('trimap_path', help='path to input trimap file') parser.add_argument('--show', action='store_true', help='show onnx graph') - parser.add_argument('--output_file', type=str, default='tmp.onnx') - parser.add_argument('--opset_version', type=int, default=11) + parser.add_argument('--output-file', type=str, default='tmp.onnx') + parser.add_argument('--opset-version', type=int, default=11) parser.add_argument( '--verify', action='store_true',