Skip to content

Commit

Permalink
[ONNX] Fix onnx convtranspose error (apache#9938)
Browse files Browse the repository at this point in the history
* fix mix up of channels with conv2d-transpose

* add grouped convtranspose tests

* turn off groups for non-llvm test
  • Loading branch information
AndrewZhaoLuo authored and crazydemo committed Jan 27, 2022
1 parent 469b25b commit 7a42530
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
18 changes: 9 additions & 9 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@
from .. import ty as _ty
from .. import vision as _vision
from .common import (
autopad,
AttrCvt,
Renamer,
autopad,
ensure_scalar_shape,
fold_constant,
get_name,
Expand Down Expand Up @@ -557,13 +557,13 @@ class ConvTranspose(OnnxOpConverter):
def _impl_v1(cls, inputs, attr, params):
# get number of channels
out_type = infer_type(inputs[1])
out_shapes = [get_const_tuple(out_type.checked_type.shape)]
channels = out_shapes[0][1]
attr["channels"] = channels
kernel_shape = [get_const_tuple(out_type.checked_type.shape)]
out_channels = kernel_shape[0][1] * attr.get("group", 1)
attr["channels"] = out_channels
groups = attr.get("group", 1)

if "kernel_shape" not in attr:
attr["kernel_shape"] = out_shapes[0][2:]
attr["kernel_shape"] = kernel_shape[0][2:]

attr["groups"] = groups
# infer pads for auto_pad
Expand Down Expand Up @@ -612,13 +612,13 @@ def _impl_v1(cls, inputs, attr, params):
def _impl_v11(cls, inputs, attr, params):
# get number of channels
out_type = infer_type(inputs[1])
out_shapes = [get_const_tuple(out_type.checked_type.shape)]
channels = out_shapes[0][1]
attr["channels"] = channels
kernel_shape = [get_const_tuple(out_type.checked_type.shape)]
out_channels = kernel_shape[0][1] * attr.get("group", 1)
attr["channels"] = out_channels
groups = attr.get("group", 1)

if "kernel_shape" not in attr:
attr["kernel_shape"] = out_shapes[0][2:]
attr["kernel_shape"] = kernel_shape[0][2:]

attr["groups"] = groups
# infer pads for auto_pad
Expand Down
9 changes: 9 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def verify_with_ort_with_inputs(
opt_level=opt_level,
convert_config=convert_config,
)

if not isinstance(tvm_out, list):
tvm_out = [tvm_out]
if not isinstance(ort_out, list):
Expand Down Expand Up @@ -2892,6 +2893,14 @@ def verify_convtranspose(x_shape, w_shape, y_shape, p, group=1):
# Test undefined groups.
verify_convtranspose((1, 1, 3, 3), (1, 2, 3, 3), (1, 2, 7, 3), [1, 2, 1, 2], group=None)

if "llvm" in target:
# GPU does not support groups != 1 for convtranspose, so only test llvm
# Test depthwise-convolution
verify_convtranspose((1, 10, 3, 3), (10, 1, 3, 3), (1, 10, 7, 3), [1, 2, 1, 2], group=10)

# Test grouped-convolution
verify_convtranspose((1, 10, 3, 3), (10, 1, 3, 3), (1, 5, 7, 3), [1, 2, 1, 2], group=5)

def repeat(N, D):
return tuple([N for _ in range(D)])

Expand Down

0 comments on commit 7a42530

Please sign in to comment.