Skip to content

Commit

Permalink
Added pool autopadding and simplified parsers. (apache#4672)
Browse files Browse the repository at this point in the history
  • Loading branch information
jwfromm authored and alexwong committed Feb 26, 2020
1 parent 868442b commit c481dff
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 127 deletions.
122 changes: 61 additions & 61 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"""ONNX: Open Neural Network Exchange frontend for Relay."""
from __future__ import absolute_import as _abs

from functools import partial
import numpy as np
import tvm
from ... import nd as _nd
Expand Down Expand Up @@ -81,18 +80,28 @@ def get_pad_pair(input1d, kernel1d, stride1d):
return [pad_before, pad_after]


def onnx_default_layout(dims):
if dims == 1:
return 'NCW'
if dims == 2:
return 'NCHW'

msg = "Only 1d and 2d layouts are currently supported"
raise tvm.error.OpAttributeInvalid(msg.format(op_name))


def onnx_storage_order2layout(storage_order, dims=2):
"""converter of onnx storage order parameter to tvm storage order format"""
if storage_order not in (0, 1):
raise tvm.error.OpAttributeInvalid('Mode of storage_order must be either 0 or 1')

if dims == 1:
return 'NCW' if storage_order == 0 else 'NWC'
elif dims == 2:
if dims == 2:
return 'NCHW' if storage_order == 0 else 'NHWC'
else:
msg = "Only 1d and 2d layouts are currently supported"
raise tvm.error.OpAttributeInvalid(msg.format(op_name))

msg = "Only 1d and 2d layouts are currently supported"
raise tvm.error.OpAttributeInvalid(msg.format(op_name))


def dimension_constraint():
Expand Down Expand Up @@ -135,15 +144,28 @@ def get_converter(cls, opset):
version, cls.__name__))


class Unary(OnnxOpConverter):
""" A helper class for unary op converters.
"""
name = ''

@classmethod
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 1, "Unary math op {} takes 1 input, {} given".format(
cls.name, len(inputs))
op_name = cls.name
return get_relay_op(op_name)(*inputs)


class Elemwise(OnnxOpConverter):
""" A helper class for elemwise op converters.
"""
name = ''

@classmethod
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(
len(inputs))
assert len(inputs) == 2, "Math op {} take 2 inputs, {} given".format(
cls.name, len(inputs))
op_name = cls.name
conv_ops = ["conv2d", "conv2d_transpose"]
if attr.get('broadcast', 0) and any(x in str(inputs[0]) for x in conv_ops):
Expand All @@ -160,26 +182,48 @@ class Pool(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
input_shape = infer_shape(inputs[0])
if 'auto_pad' in attr:
attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
pad_tuple = []
for axis in range(len(input_shape) - 2):
axis_shape = input_shape[2 + axis]
stride = attr['strides'][axis]
kernel = attr['kernel_shape'][axis]
pad = get_pad_pair(axis_shape, kernel, stride)
pad_tuple.append(pad)
pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair])
attr['pads'] = pad_tuple
elif attr['auto_pad'] == 'VALID':
attr['pads'] = 0
elif attr['auto_pad'] == 'NOTSET':
pass
else:
msg = 'Value {} in attribute "auto_pad" of operator {} is invalid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad'], cls.name))
attr.pop("auto_pad")

if 'storage_order' in attr:
attr['layout'] = onnx_storage_order2layout(attr['storage_order'],
dims=(len(input_shape) - 2))
else:
attr['layout'] = onnx_default_layout(dims=(len(input_shape) - 2))

return AttrCvt(
op_name=dimension_picker(cls.name),
transforms={
'kernel_shape': 'pool_size',
'pads': ('padding', (0, 0), revert_caffe2_pad)
'pads': ('padding', 0)
},
# very weird attributes here in onnx, force check
ignores=['dilations', 'auto_pad'],
# TODO(zhreshold): make sure ceil_mode in onnx, and layout?
extras={'ceil_mode': False},
ignores=['dilations'],
custom_check=dimension_constraint())(inputs, attr, params)


class Absolute(OnnxOpConverter):
class Absolute(Unary):
""" Operator converter for Absolute.
"""

@classmethod
def _impl_v1(cls, inputs, attr, params):
return _op.nn.relu(inputs[0]) + _op.nn.relu(_op.negative(inputs[0]))
name = 'abs'


class Add(Elemwise):
Expand Down Expand Up @@ -387,50 +431,6 @@ class MaxPool(Pool):
"""
name = 'max_pool'

@classmethod
def _impl_v8(cls, inputs, attr, params):
return AttrCvt(
op_name=dimension_picker(cls.name),
transforms={
'kernel_shape': 'pool_size',
'pads': ('padding', (0, 0), revert_caffe2_pad),
'storage_order': ('layout', 'NCHW', onnx_storage_order2layout),
},
# very weird attributes here in onnx, force check
ignores=['dilations', 'auto_pad'],
# TODO(higumachan): make sure ceil_mode in onnx, and layout?
extras={'ceil_mode': False},
custom_check=dimension_constraint())(inputs, attr, params)

@classmethod
def _impl_v10(cls, inputs, attr, params):
input_shape = infer_shape(inputs[0])
# 1D Convolution
if len(input_shape) == 3:
return AttrCvt(
op_name="max_pool1d",
transforms={
'kernel_shape': 'pool_size',
'pads': ('padding', (0, 0)),
'storage_order': ('layout', 'NCW', partial(onnx_storage_order2layout, dims=1)),
'ceil_mode': 'ceil_mode'
},
ignores=['dilations', 'auto_pad'])(inputs, attr, params)
#2D Convolution
if len(input_shape) == 4:
return AttrCvt(
op_name=dimension_picker(cls.name),
transforms={
'kernel_shape': 'pool_size',
'pads': ('padding', (0, 0), revert_caffe2_pad),
'storage_order': ('layout', 'NCHW', onnx_storage_order2layout),
'ceil_mode': 'ceil_mode'
},
# very weird attributes here in onnx, force check
ignores=['dilations', 'auto_pad'],
custom_check=dimension_constraint())(inputs, attr, params)

raise tvm.error.OpAttributeInvalid("Only 1D and 2D maxpooling are currently supported.")

class Mul(Elemwise):
""" Operator converter for Multiply.
Expand Down
128 changes: 62 additions & 66 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1827,7 +1827,7 @@ def forward(self, input):
relay.frontend.from_onnx(onnx_model, {'0': input_size})


def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode):
def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_pad="NOTSET"):
x_np = np.random.uniform(size=x_shape).astype('float32')

if mode == 'max':
Expand All @@ -1837,12 +1837,20 @@ def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode):
else:
raise ValueError("Pool method {} is not supported.".format(mode))

pool_node = helper.make_node(node_type,
inputs=["x"],
outputs=["y"],
kernel_shape=kernel_shape,
pads=pads,
strides=strides)
if pads is None:
pool_node = helper.make_node(node_type,
inputs=["x"],
outputs=["y"],
kernel_shape=kernel_shape,
auto_pad=auto_pad,
strides=strides)
else:
pool_node = helper.make_node(node_type,
inputs=["x"],
outputs=["y"],
kernel_shape=kernel_shape,
pads=pads,
strides=strides)

graph = helper.make_graph([pool_node],
"pooling_test",
Expand All @@ -1860,65 +1868,53 @@ def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode):
tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)

def test_pooling():
# MaxPool1D
verify_pooling(x_shape=[1, 1, 32],
kernel_shape=[3],
strides=[1],
pads=[1, 1],
out_shape=[1, 1, 32],
mode='max')
# MaxPool2D
verify_pooling(x_shape=[1, 1, 32, 32],
kernel_shape=[3, 3],
strides=[1, 1],
pads=[1, 1, 1, 1],
out_shape=[1, 1, 32, 32],
mode='max')

#AveragePool1D
verify_pooling(x_shape=[1, 1, 32],
kernel_shape=[3],
strides=[1],
pads=[1, 1],
out_shape=[1, 1, 32],
mode='average')
#AveragePool2D
verify_pooling(x_shape=[1, 1, 32, 32],
kernel_shape=[3, 3],
strides=[1, 1],
pads=[1, 1, 1, 1],
out_shape=[1, 1, 32, 32],
mode='average')

# MaxPool1D with stride
verify_pooling(x_shape=[1, 1, 32],
kernel_shape=[3],
strides=[2],
pads=[1, 1],
out_shape=[1, 1, 16],
mode='max')
# MaxPool2D with stride
verify_pooling(x_shape=[1, 1, 32, 32],
kernel_shape=[3, 3],
strides=[2, 2],
pads=[1, 1, 1, 1],
out_shape=[1, 1, 16, 16],
mode='max')

#AveragePool1D with stride
verify_pooling(x_shape=[1, 1, 32],
kernel_shape=[3],
strides=[2],
pads=[1, 1],
out_shape=[1, 1, 16],
mode='average')
#AveragePool2D with stride
verify_pooling(x_shape=[1, 1, 32, 32],
kernel_shape=[3, 3],
strides=[2, 2],
pads=[1, 1, 1, 1],
out_shape=[1, 1, 16, 16],
mode='average')
for mode in ['max', 'average']:
# Pool1D
verify_pooling(x_shape=[1, 1, 32],
kernel_shape=[3],
strides=[1],
pads=[1, 1],
out_shape=[1, 1, 32],
mode=mode)
# Pool2D
verify_pooling(x_shape=[1, 1, 32, 32],
kernel_shape=[3, 3],
strides=[1, 1],
pads=[1, 1, 1, 1],
out_shape=[1, 1, 32, 32],
mode=mode)

# Pool1D with stride
verify_pooling(x_shape=[1, 1, 32],
kernel_shape=[3],
strides=[2],
pads=[1, 1],
out_shape=[1, 1, 16],
mode=mode)
# Pool2D with stride
verify_pooling(x_shape=[1, 1, 32, 32],
kernel_shape=[3, 3],
strides=[2, 2],
pads=[1, 1, 1, 1],
out_shape=[1, 1, 16, 16],
mode=mode)

# Pool1D with stride and autopadding
verify_pooling(x_shape=[1, 1, 32],
kernel_shape=[3],
strides=[2],
pads=None,
out_shape=[1, 1, 16],
mode=mode,
auto_pad='SAME_UPPER')
# Pool2D with stride and autopadding
verify_pooling(x_shape=[1, 1, 32, 32],
kernel_shape=[3, 3],
strides=[2, 2],
pads=None,
out_shape=[1, 1, 16, 16],
mode=mode,
auto_pad='SAME_UPPER')


if __name__ == '__main__':
Expand Down

0 comments on commit c481dff

Please sign in to comment.