diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 57ef546de29f..ce20b7da4418 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2784,6 +2784,47 @@ def convert_arange(node, **kwargs): return nodes +@mx_op.register("reverse") +def convert_reverse(node, **kwargs): + """Map MXNet's reverse operator attributes to ONNX + """ + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) + + axis = int(attrs.get('axis', 0)) + + # Transpose takes perm as a parameter, so we must 'pad' the input to a known dim (10 here) + perm = [i for i in range(10)] + perm[0], perm[axis] = axis, 0 + + nodes = [ + create_tensor([10], name+'_10', kwargs['initializer']), + create_tensor([0], name+'_0', kwargs['initializer']), + create_tensor([1], name+'_1', kwargs['initializer']), + create_tensor([-1], name+'_m1', kwargs['initializer']), + create_tensor([axis], name+'_axis', kwargs['initializer']), + create_tensor([axis+1], name+'_axis_p1', kwargs['initializer']), + create_tensor([], name+'_void', kwargs['initializer']), + create_const_scalar_node(name+'_m1_s', np.array([-1], dtype='int64'), kwargs), + make_node('Shape', [input_nodes[0]], [name+'_shape']), + make_node('Shape', [name+'_shape'], [name+'_dim']), + make_node('Sub', [name+'_10', name+'_dim'], [name+'_sub']), + make_node('Concat', [name+'_0', name+'_sub'], [name+'_concat'], axis=0), + make_node('Pad', [name+'_shape', name+'_concat', name+'_1'], [name+'_shape_10_dim']), + make_node('Reshape', [input_nodes[0], name+'_shape_10_dim'], [name+'_data_10_dim']), + make_node('Transpose', [name+'_data_10_dim'], [name+'_data_t'], perm=perm), + make_node('Slice', [name+'_shape', name+'_axis', name+'_axis_p1'], [name+'_axis_len']), + make_node('Sub', [name+'_axis_len', name+'_1'], [name+'_axis_len_m1']), + make_node('Reshape', [name+'_axis_len_m1', name+'_void'], [name+'_axis_len_m1_s']), + make_node('Range', [name+'_axis_len_m1_s', name+'_m1_s', name+'_m1_s'], [name+'_indices']), + make_node('Gather', [name+'_data_t', name+'_indices'], [name+'_gather']), + make_node('Transpose', [name+'_gather'], [name+'_data_reversed'], perm=perm), + make_node('Reshape', [name+'_data_reversed', name+'_shape'], [name], name=name) + ] + + return nodes + + @mx_op.register('repeat') def convert_repeat(node, **kwargs): """Map MXNet's repeat operator attributes to onnx's Tile operator. diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index c17a03bc3276..2efa582d4c95 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -353,6 +353,14 @@ def test_onnx_export_softmax(tmp_path, dtype): op_export_test('softmax_4', M4, [x, l4], tmp_path) +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64']) +@pytest.mark.parametrize('axis', [0, 1, 2, 3]) +def test_onnx_export_reverse(tmp_path, dtype, axis): + x = mx.nd.arange(0, 120, dtype=dtype).reshape((2, 3, 4, 5)) + M = def_model('reverse', axis=axis) + op_export_test('reverse', M, [x], tmp_path) + + @pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64']) @pytest.mark.parametrize('axis', [None, 0, 1, 2, -1, -2, -3]) @pytest.mark.parametrize('repeats', [2, 1, 3])