Skip to content

Commit

Permalink
add docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
hanruobing committed Jul 21, 2020
1 parent 173c954 commit 7417164
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 2 deletions.
37 changes: 37 additions & 0 deletions mmedit/models/backbones/encoder_decoders/decoders/plain_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,27 @@


class MaxUnpool2dop(Function):
"""We warp the `torch.nn.functional.max_unpool2d`
with an extra `symbolic` method, which is needed while exporting to ONNX.
Users should not call this function directly.
"""

@staticmethod
def forward(ctx, input, indices, kernel_size, stride, padding,
output_size):
"""Forward function of MaxUnpool2dop.
Args:
input (Tensor): Tensor needed to upsample.
indices (Tensor): Indices output of the previous MaxPool.
kernel_size (Tuple): Size of the max pooling window.
stride (Tuple): Stride of the max pooling window.
padding (Tuple): Padding that was added to the input.
output_size (List or Tuple): The shape of output tensor.
Returns:
Tensor: Output tensor.
"""
return F.max_unpool2d(input, indices, kernel_size, stride, padding,
output_size)

Expand All @@ -32,6 +49,15 @@ def symbolic(g, input, indices, kernel_size, stride, padding, output_size):


class MaxUnpool2d(_MaxUnpoolNd):
"""This module is modified from Pytorch `MaxUnpool2d` module.
Args:
kernel_size (int or tuple): Size of the max pooling window.
stride (int or tuple): Stride of the max pooling window.
Default: None (It is set to `kernel_size` by default).
padding (int or tuple): Padding that is added to the input.
Default: 0
"""

def __init__(self, kernel_size, stride=None, padding=0):
super(MaxUnpool2d, self).__init__()
Expand All @@ -40,6 +66,17 @@ def __init__(self, kernel_size, stride=None, padding=0):
self.padding = _pair(padding)

def forward(self, input, indices, output_size=None):
"""Forward function of MaxUnpool2d.
Args:
input (Tensor): Tensor needed to upsample.
indices (Tensor): Indices output of the previous MaxPool.
output_size (List or Tuple): The shape of output tensor.
Default: None.
Returns:
Tensor: Output tensor.
"""
return MaxUnpool2dop.apply(input, indices, self.kernel_size,
self.stride, self.padding, output_size)

Expand Down
17 changes: 15 additions & 2 deletions tools/pytorch2onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@ def pytorch2onnx(model,
show=False,
output_file='tmp.onnx',
verify=False):
"""Export Pytorch model to ONNX model and verify the outputs are same
between Pytorch and ONNX.
Args:
model (nn.Module): Pytorch model we want to export.
input (dict): We need to use this input to execute the model.
opset_version (int): The onnx op version. Default: 11.
show (bool): Whether print the computation graph. Default: False.
output_file (string): The path to where we store the output ONNX model.
Default: `tmp.onnx`.
verify (bool): Whether compare the outputs between Pytorch and ONNX.
Default: False.
"""
model.cpu().eval()
merged = input['merged'].unsqueeze(0)
trimap = input['trimap'].unsqueeze(0)
Expand Down Expand Up @@ -68,8 +81,8 @@ def parse_args():
parser.add_argument('img_path', help='path to input image file')
parser.add_argument('trimap_path', help='path to input trimap file')
parser.add_argument('--show', action='store_true', help='show onnx graph')
parser.add_argument('--output_file', type=str, default='tmp.onnx')
parser.add_argument('--opset_version', type=int, default=11)
parser.add_argument('--output-file', type=str, default='tmp.onnx')
parser.add_argument('--opset-version', type=int, default=11)
parser.add_argument(
'--verify',
action='store_true',
Expand Down

0 comments on commit 7417164

Please sign in to comment.