Skip to content

Commit

Permalink
Apparently, ONNX Conv with no 'pads' defaults to zero padding (#5548)
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart authored May 9, 2020
1 parent 47ea99c commit 0c43fa0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
6 changes: 4 additions & 2 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,6 @@ class Conv(OnnxOpConverter):
def _impl_v1(cls, inputs, attr, params):
# Use shape of input to determine convolution type.
input_shape = infer_shape(inputs[0])

if 'auto_pad' in attr:
attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
Expand All @@ -350,7 +349,10 @@ def _impl_v1(cls, inputs, attr, params):
attr.pop('auto_pad')
elif len(attr['kernel_shape']) == 2:
sym_pad = True
padding = attr['pads']
if 'pads' in attr:
padding = attr['pads']
else:
padding = [0, 0, 0, 0]
for i in range(0, len(padding), 2):
sym_pad = sym_pad and padding[i] == padding[i + 1]

Expand Down
23 changes: 21 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2028,8 +2028,18 @@ def test_or():
verify_or(indata=[x, y], dtype=bool)


def verify_conv(x_shape, w_shape, y_shape, padding, kernel_shape, strides, dilations, auto_pad="NOTSET"):
if padding is None:
def verify_conv(x_shape, w_shape, y_shape, padding, kernel_shape, strides, dilations, auto_pad="NOTSET", unset_pad=False):
if unset_pad:
node = helper.make_node('Conv',
inputs=['x', 'W'],
outputs=['y'],
kernel_shape=kernel_shape,
# Default values for other attributes:
strides=strides,
dilations=dilations,
# groups=1
)
elif padding is None:
node = helper.make_node('Conv',
inputs=['x', 'W'],
outputs=['y'],
Expand Down Expand Up @@ -2095,6 +2105,15 @@ def repeat(N, D):
repeat(1, D),
repeat(1, D),
auto_pad="SAME_UPPER")
# Convolution with unset padding
verify_conv((1, 1) + repeat(5, D),
(1, 1) + repeat(3, D),
(1, 1) + repeat(3, D),
2 * repeat(0, D),
repeat(3, D),
repeat(1, D),
repeat(1, D),
True)
# Convolution with non uniform stride
verify_conv((1, 1) + repeat(5, D),
(1, 1) + repeat(3, D),
Expand Down

0 comments on commit 0c43fa0

Please sign in to comment.