[v1.x] ONNX Export Support for Pooling & Convolution (#19831)
* max and lp

* pooling global

* remove old implementation

* Update _op_translations.py

* fix

* conv

* Update _op_translations.py
Zha0q1 authored Feb 4, 2021
1 parent efa3eb2 commit 9aff832
Showing 2 changed files with 230 additions and 98 deletions.
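For context, these converters are invoked through MXNet's ONNX export entry point. A minimal round trip over the two operators touched by this commit might look like the sketch below; this is a hedged illustration assuming the MXNet 1.x mxnet.contrib.onnx API, and the model and file names are made up.

import mxnet as mx
import numpy as np
from mxnet.contrib import onnx as onnx_mxnet

# Illustrative toy network using the two operators this commit covers.
data = mx.sym.Variable('data')
conv = mx.sym.Convolution(data, num_filter=8, kernel=(3, 3), pad=(1, 1), name='conv')
pool = mx.sym.Pooling(conv, pool_type='max', kernel=(2, 2), stride=(2, 2), name='pool')

mod = mx.mod.Module(pool, data_names=['data'], label_names=None)
mod.bind(data_shapes=[('data', (1, 3, 32, 32))])
mod.init_params()
mod.save_checkpoint('model', 0)  # writes model-symbol.json / model-0000.params

# export_model walks the graph and calls the registered per-op translators
# (convert_convolution, convert_pooling below) for each node.
onnx_mxnet.export_model('model-symbol.json', 'model-0000.params',
                        [(1, 3, 32, 32)], np.float32, 'model.onnx')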
186 changes: 91 additions & 95 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -229,34 +229,45 @@ def convert_weights_and_inputs(node, **kwargs):
     return [tval_node]
 
 
-@mx_op.register("Convolution")
+@mx_op.register('Convolution')
 def convert_convolution(node, **kwargs):
     """Map MXNet's convolution operator attributes to onnx's Conv operator
     and return the created node.
     """
+    from onnx.helper import make_node
     name, input_nodes, attrs = get_inputs(node, kwargs)
 
-    kernel_dims = list(parse_helper(attrs, "kernel"))
-    stride_dims = list(parse_helper(attrs, "stride", [1, 1]))
-    pad_dims = list(parse_helper(attrs, "pad", [0, 0]))
-    num_group = int(attrs.get("num_group", 1))
-    dilations = list(parse_helper(attrs, "dilate", [1, 1]))
+    kernel = convert_string_to_list(attrs.get('kernel', '()'))
+    stride = convert_string_to_list(attrs.get('stride', '(1, 1)'))
+    dilate = convert_string_to_list(attrs.get('dilate', '(1, 1)'))
+    pad = convert_string_to_list(attrs.get('pad', '(0, 0)'))
+    num_group = int(attrs.get('num_group', 1))
+    no_bias = attrs.get('no_bias', 'False')
+    layout = attrs.get('layout', 'NCHW')
 
-    pad_dims = pad_dims + pad_dims
+    if layout != 'NCHW':
+        raise NotImplementedError('Convolution currently does not support layout!=\'NCHW\'')
 
-    conv_node = onnx.helper.make_node(
-        "Conv",
-        inputs=input_nodes,
-        outputs=[name],
-        kernel_shape=kernel_dims,
-        strides=stride_dims,
-        dilations=dilations,
-        pads=pad_dims,
-        group=num_group,
-        name=name
-    )
+    if no_bias == 'True':
+        assert len(input_nodes) == 2, 'Convolution takes 2 inputs if no_bias==True'
+    else:
+        assert len(input_nodes) == 3, 'Convolution takes 3 inputs if no_bias==False'
+
+    kwargs_ = {}
+    if kernel:
+        kwargs_['kernel_shape'] = tuple(kernel)
+    if pad:
+        kwargs_['pads'] = tuple(pad) + tuple(pad)
+    if stride:
+        kwargs_['strides'] = stride
+    if dilate:
+        kwargs_['dilations'] = dilate
 
-    return [conv_node]
+    nodes = [
+        make_node('Conv', input_nodes, [name], group=num_group, **kwargs_)
+    ]
+
+    return nodes
 
 
 @mx_op.register("Deconvolution")
@@ -679,92 +690,77 @@ def convert_linalg_gemm2(node, **kwargs):
     return [trans_a_node, trans_b_node, matmul_node]
 
 
-@mx_op.register("Pooling")
+@mx_op.register('Pooling')
 def convert_pooling(node, **kwargs):
     """Map MXNet's Pooling operator attributes to onnx's
     MaxPool/AveragePool/GlobalMaxPool/GlobalAveragePool operators
     based on the input node's attributes and return the created node.
     """
-    opset_version = kwargs["opset_version"]
+    from onnx.helper import make_node
     name, input_nodes, attrs = get_inputs(node, kwargs)
 
-    kernel = eval(attrs["kernel"])
-    pool_type = attrs["pool_type"] if attrs.get("pool_type") else "max"
-    stride = eval(attrs["stride"]) if attrs.get("stride") else (1, 1)
-    global_pool = get_boolean_attribute_value(attrs, "global_pool")
-    p_value = attrs.get('p_value', 'None')
-
+    kernel = convert_string_to_list(attrs.get('kernel', '()'))
+    pool_type = attrs.get('pool_type', 'max')
+    global_pool = attrs.get('global_pool', 'False')
+    _ = attrs.get('cudnn_off', 'False')
     pooling_convention = attrs.get('pooling_convention', 'valid')
-    ceil_mode = False
-    if pooling_convention == 'full':
-        if opset_version < 10:
-            pooling_warning = "Pooling: ONNX lower than 1.5.0 doesn't support pooling_convention. " \
-                              "This might lead to shape or accuracy issues. " \
-                              "https://github.com/onnx/onnx/issues/549"
-            logging.warning(pooling_warning)
-        ceil_mode = True
-
-    pad_dims = list(parse_helper(attrs, "pad", [0, 0]))
-    pad_dims = pad_dims + pad_dims
-    pool_types = {"max": "MaxPool", "avg": "AveragePool", "lp": "LpPool"}
-    global_pool_types = {"max": "GlobalMaxPool", "avg": "GlobalAveragePool",
-                         "lp": "GlobalLpPool"}
+    stride = convert_string_to_list(attrs.get('stride', '(1, 1)'))
+    pad = convert_string_to_list(attrs.get('pad', '()'))
+    p_value = int(attrs.get('p_value', '0'))
+    count_include_pad = attrs.get('count_include_pad', 'True')
+    layout = attrs.get('layout', 'NCHW')
 
+    if pooling_convention == 'same':
+        raise NotImplementedError('Pooling currently does not support '
+                                  'pooling_convention==\'same\'')
+    if pool_type == 'sum':
+        raise NotImplementedError('Pooling currently does not support pool_type==\'sum\'')
+    if pool_type == 'lp' and global_pool == 'False' and pooling_convention != 'valid':
+        raise NotImplementedError('Pooling currently does not support '
+                                  'pooling_convention!=\'valid\' when pool_type==\'lp\' and global_pool==False')
+    if layout != 'NCHW':
+        raise NotImplementedError('Pooling currently does not support layout!=\'NCHW\'')
+
+    kwargs_ = {}
+    if kernel:
+        kwargs_['kernel_shape'] = tuple(kernel)
+    if pad:
+        kwargs_['pads'] = tuple(pad) + tuple(pad)
+    if stride:
+        kwargs_['strides'] = stride
+
+    ceil_mode = 1 if pooling_convention == 'full' else 0
+    count_include_pad = 1 if count_include_pad == 'True' else 0
 
-    if pool_type == 'lp' and p_value == 'None':
-        raise AttributeError('ONNX requires a p value for LpPool and GlobalLpPool')
-
-    if global_pool:
-        if pool_type == 'lp':
-            node = onnx.helper.make_node(
-                global_pool_types[pool_type],
-                input_nodes,  # input
-                [name],
-                p=int(p_value),
-                name=name
-            )
-        else:
-            node = onnx.helper.make_node(
-                global_pool_types[pool_type],
-                input_nodes,  # input
-                [name],
-                name=name
-            )
+    nodes = []
+    if pool_type == 'avg' and global_pool == 'False':
+        nodes += [
+            make_node('AveragePool', [input_nodes[0]], [name], ceil_mode=ceil_mode,
+                      count_include_pad=count_include_pad, **kwargs_)
+        ]
+    elif pool_type == 'max' and global_pool == 'False':
+        nodes += [
+            make_node('MaxPool', [input_nodes[0]], [name], ceil_mode=ceil_mode, **kwargs_)
+        ]
+    elif pool_type == 'lp' and global_pool == 'False':
+        nodes += [
+            make_node('LpPool', [input_nodes[0]], [name], p=p_value, **kwargs_)
+        ]
+    elif pool_type == 'avg' and global_pool == 'True':
+        nodes += [
+            make_node('GlobalAveragePool', [input_nodes[0]], [name])
+        ]
+    elif pool_type == 'max' and global_pool == 'True':
+        nodes += [
+            make_node('GlobalMaxPool', [input_nodes[0]], [name])
+        ]
+    elif pool_type == 'lp' and global_pool == 'True':
+        nodes += [
+            make_node('GlobalLpPool', [input_nodes[0]], [name], p=p_value)
+        ]
     else:
-        if pool_type == 'lp':
-            node = onnx.helper.make_node(
-                pool_types[pool_type],
-                input_nodes,  # input
-                [name],
-                p=int(p_value),
-                kernel_shape=kernel,
-                pads=pad_dims,
-                strides=stride,
-                name=name
-            )
-        else:
-            if opset_version >= 10:
-                node = onnx.helper.make_node(
-                    pool_types[pool_type],
-                    input_nodes,  # input
-                    [name],
-                    kernel_shape=kernel,
-                    pads=pad_dims,
-                    strides=stride,
-                    name=name,
-                    ceil_mode=ceil_mode
-                )
-            else:
-                node = onnx.helper.make_node(
-                    pool_types[pool_type],
-                    input_nodes,  # input
-                    [name],
-                    kernel_shape=kernel,
-                    pads=pad_dims,
-                    strides=stride,
-                    name=name
-                )
+        raise NotImplementedError('Unknown parameter values in Pooling')
 
-    return [node]
+    return nodes
 
 
 @mx_op.register("exp")
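Read as a table, the new converter dispatches on (pool_type, global_pool) to one of six ONNX operators, and translates MXNet's 'full' pooling convention, where the output spatial size is ceil((x + 2*pad - kernel) / stride) + 1 rather than the floor used by 'valid', into ONNX's ceil_mode attribute. Below is a compact sketch of just that dispatch, assuming the string-typed attributes MXNet passes; the helper name pick_onnx_pool_op is hypothetical.

ONNX_POOL_OPS = {
    ('max', False): 'MaxPool',
    ('avg', False): 'AveragePool',
    ('lp', False): 'LpPool',
    ('max', True): 'GlobalMaxPool',
    ('avg', True): 'GlobalAveragePool',
    ('lp', True): 'GlobalLpPool',
}

def pick_onnx_pool_op(pool_type, global_pool, pooling_convention='valid'):
    # global_pool arrives as the string 'True'/'False', like all MXNet attrs.
    op = ONNX_POOL_OPS.get((pool_type, global_pool == 'True'))
    if op is None:
        raise NotImplementedError('Unknown parameter values in Pooling')
    ceil_mode = 1 if pooling_convention == 'full' else 0
    return op, ceil_mode

assert pick_onnx_pool_op('avg', 'False', 'full') == ('AveragePool', 1)
assert pick_onnx_pool_op('lp', 'True') == ('GlobalLpPool', 0)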
142 changes: 139 additions & 3 deletions tests/python-pytest/onnx/test_operators.py
@@ -39,7 +39,7 @@ def hybrid_forward(self, F, *inputs):
         return func(*inputs, **params)
     return Model
 
-def op_export_test(model_name, Model, inputs, tmp_path, dummy_input=False):
+def op_export_test(model_name, Model, inputs, tmp_path, dummy_input=False, onnx_map=None):
     def export_to_onnx(model, model_name, inputs):
         model_path = '{}/{}'.format(tmp_path, model_name)
         model.export(model_path, epoch=0)
@@ -69,9 +69,11 @@ def onnx_rt(onnx_file, inputs):
         pred_nat = pred_nat[0]
     if isinstance(pred_nat, list):
         for i in range(len(pred_nat)):
-            assert_almost_equal(pred_nat[i], pred_onx[i], equal_nan=True)
+            pred_onx_i = onnx_map(pred_onx[i]) if onnx_map else pred_onx[i]
+            assert_almost_equal(pred_nat[i], pred_onx_i, equal_nan=True)
     else:
-        assert_almost_equal(pred_nat, pred_onx[0], equal_nan=True)
+        pred_onx = onnx_map(pred_onx[0]) if onnx_map else pred_onx[0]
+        assert_almost_equal(pred_nat, pred_onx, equal_nan=True)
 
 
 def test_onnx_export_abs(tmp_path):
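The new onnx_map hook lets a test post-process onnxruntime's outputs before comparison. The average-pooling test below passes np.nan_to_num, whose default behavior of replacing NaN with 0.0 reconciles onnxruntime's NaN edge cases with MXNet's zeros:

import numpy as np

# np.nan_to_num maps NaN -> 0.0 by default, which is exactly the
# normalization the pooling_avg test relies on.
out = np.array([np.nan, 1.5, 2.0], dtype=np.float32)
assert (np.nan_to_num(out) == np.array([0.0, 1.5, 2.0], dtype=np.float32)).all()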
@@ -760,6 +762,107 @@ def test_onnx_export_batch_dot(tmp_path, dtype, transpose_a, transpose_b):
     op_export_test('batch_dot2', M2, [x2, y2], tmp_path)
 
 
+@pytest.mark.parametrize('dtype', ['float32'])
+@pytest.mark.parametrize('shape', [(1, 3, 64, 64), (2, 1, 60, 60)])
+@pytest.mark.parametrize('count_include_pad', [True, False])
+@pytest.mark.parametrize('pooling_convention', ['full', 'valid'])
+@pytest.mark.parametrize('kernel', [(3, 3), (4, 5), (14, 14)])
+@pytest.mark.parametrize('stride', [None, (1, 1), (2, 2), (3, 4), (4, 5)])
+@pytest.mark.parametrize('pad', [None, (1, 1), (3, 4), (4, 5)])
+def test_onnx_export_pooling_avg(tmp_path, dtype, shape, count_include_pad, pooling_convention,
+                                 kernel, stride, pad):
+    # mxnet and onnxruntime have different implementations of count_include_pad on the
+    # left column and bottom row
+    if pooling_convention == 'full' and count_include_pad == True:
+        return
+    # onnxruntime requires that pad is smaller than kernel
+    if pad and pad[0] >= kernel[0] and pad[1] >= kernel[1]:
+        return
+    x = mx.random.uniform(0, 1, shape, dtype=dtype)
+    kwargs = {}
+    if kernel:
+        kwargs['kernel'] = kernel
+    if stride:
+        kwargs['stride'] = stride
+    if pad:
+        kwargs['pad'] = pad
+    M = def_model('Pooling', count_include_pad=count_include_pad, pool_type='avg',
+                  pooling_convention=pooling_convention, **kwargs)
+    # Note here we use np.nan_to_num to map the onnx output because onnxruntime AveragePool
+    # will output NaN in some edge cases where mxnet outputs 0
+    op_export_test('pooling_avg', M, [x], tmp_path, onnx_map=np.nan_to_num)
+
+
+@pytest.mark.parametrize('dtype', ['float32'])
+@pytest.mark.parametrize('shape', [(1, 3, 64, 64), (2, 1, 60, 60)])
+@pytest.mark.parametrize('pooling_convention', ['full', 'valid'])
+@pytest.mark.parametrize('kernel', [(3, 3), (4, 5), (14, 14)])
+@pytest.mark.parametrize('stride', [None, (1, 1), (2, 2), (3, 4), (4, 5)])
+@pytest.mark.parametrize('pad', [None, (1, 1), (3, 4), (4, 5)])
+def test_onnx_export_pooling_max(tmp_path, dtype, shape, pooling_convention, kernel, stride, pad):
+    # onnxruntime requires that pad is smaller than kernel
+    if pad and pad[0] >= kernel[0] and pad[1] >= kernel[1]:
+        return
+    x = mx.random.uniform(0, 1, shape, dtype=dtype)
+    kwargs = {}
+    if kernel:
+        kwargs['kernel'] = kernel
+    if stride:
+        kwargs['stride'] = stride
+    if pad:
+        kwargs['pad'] = pad
+    M = def_model('Pooling', pool_type='max', pooling_convention=pooling_convention, **kwargs)
+    op_export_test('pooling_max', M, [x], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float32'])
+@pytest.mark.parametrize('shape', [(1, 3, 64, 64), (2, 1, 60, 60)])
+@pytest.mark.parametrize('p_value', [1, 2])
+@pytest.mark.parametrize('kernel', [(3, 3), (4, 5), (14, 14)])
+@pytest.mark.parametrize('stride', [None, (1, 1), (2, 2), (3, 4), (4, 5)])
+@pytest.mark.parametrize('pad', [None, (1, 1), (3, 4), (4, 5)])
+def test_onnx_export_pooling_lp(tmp_path, dtype, shape, p_value, kernel, stride, pad):
+    # onnxruntime requires that pad is smaller than kernel
+    if pad and pad[0] >= kernel[0] and pad[1] >= kernel[1]:
+        return
+    x = mx.random.uniform(0, 1, shape, dtype=dtype)
+    kwargs = {}
+    if kernel:
+        kwargs['kernel'] = kernel
+    if stride:
+        kwargs['stride'] = stride
+    if pad:
+        kwargs['pad'] = pad
+    M = def_model('Pooling', pool_type='lp', pooling_convention='valid',
+                  p_value=p_value, **kwargs)
+    op_export_test('pooling_lp', M, [x], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float32'])
+@pytest.mark.parametrize('shape', [(1, 3, 64, 64), (2, 1, 60, 60)])
+@pytest.mark.parametrize('pool_type', ['avg', 'max', 'lp'])
+@pytest.mark.parametrize('p_value', [1, 2])
+@pytest.mark.parametrize('kernel', [(3, 3), (14, 14)])
+@pytest.mark.parametrize('stride', [None, (3, 4)])
+@pytest.mark.parametrize('pad', [None, (3, 4)])
+def test_onnx_export_pooling_global(tmp_path, dtype, shape, pool_type, p_value, kernel, stride, pad):
+    # onnxruntime requires that pad is smaller than kernel
+    if pad and pad[0] >= kernel[0] and pad[1] >= kernel[1]:
+        return
+    x = mx.random.uniform(0, 1, shape, dtype=dtype)
+    kwargs = {}
+    if kernel:
+        kwargs['kernel'] = kernel
+    if stride:
+        kwargs['stride'] = stride
+    if pad:
+        kwargs['pad'] = pad
+    # kernel, stride, and pad should have no effect on the results
+    M = def_model('Pooling', global_pool=True, pool_type=pool_type, pooling_convention='valid',
+                  p_value=p_value, **kwargs)
+    op_export_test('pooling_global', M, [x], tmp_path)
+
+
 @pytest.mark.parametrize('dtype', ['float16', 'float32'])
 def test_onnx_export_log2(tmp_path, dtype):
     x = mx.random.normal(0, 10, (2, 3, 4, 5)).astype(dtype)
@@ -787,3 +890,36 @@ def test_onnx_export_broadcast_mul(tmp_path, dtype):
     x = mx.nd.array([[1,2,3],[4,5,6]], dtype=dtype)
     y = mx.nd.array([[0],[3]], dtype=dtype)
     op_export_test('broadcast_mul', M, [x, y], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['float32'])
+@pytest.mark.parametrize('shape', [(1, 3, 64, 64), (2, 6, 60, 60)])
+@pytest.mark.parametrize('num_filter', [2, 4, 32])
+@pytest.mark.parametrize('num_group', [1, 2])
+@pytest.mark.parametrize('no_bias', [True, False])
+@pytest.mark.parametrize('kernel', [(3, 3), (4, 5), (14, 14)])
+@pytest.mark.parametrize('stride', [None, (1, 1), (2, 2), (3, 4), (4, 5)])
+@pytest.mark.parametrize('pad', [None, (1, 1), (3, 4), (4, 5)])
+@pytest.mark.parametrize('dilate', [None, (1, 1)])
+def test_onnx_export_convolution(tmp_path, dtype, shape, num_filter, num_group, no_bias,
+                                 kernel, stride, pad, dilate):
+    if shape[1] % num_group:
+        return
+    x = mx.random.uniform(0, 1, shape, dtype=dtype)
+    w_shape = (num_filter,) + (shape[1] // num_group,) + kernel
+    w = mx.random.uniform(0, 1, w_shape, dtype=dtype)
+    b_shape = (num_filter,)
+    b = mx.random.uniform(0, 1, b_shape, dtype=dtype)
+    kwargs = {}
+    if kernel:
+        kwargs['kernel'] = kernel
+    if stride:
+        kwargs['stride'] = stride
+    if pad:
+        kwargs['pad'] = pad
+    if dilate:
+        kwargs['dilate'] = dilate
+    M = def_model('Convolution', num_filter=num_filter, num_group=num_group, no_bias=no_bias,
+                  **kwargs)
+    inputs = [x, w] if no_bias else [x, w, b]
+    op_export_test('convolution', M, inputs, tmp_path)
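One detail of the convolution test worth spelling out: with grouped convolution each filter sees only in_channels / num_group input channels, which is why the weight shape is (num_filter, shape[1] // num_group) + kernel. A quick worked check using values from the parametrization above:

# shape = (2, 6, 60, 60), num_filter = 4, num_group = 2, kernel = (3, 3)
in_channels, num_filter, num_group, kernel = 6, 4, 2, (3, 3)
w_shape = (num_filter,) + (in_channels // num_group,) + kernel
assert w_shape == (4, 3, 3, 3)  # each of 4 filters convolves 3 of the 6 channels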
