From 47cca6d9108cb6b9b123a99781532d0b0c189f10 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Fri, 5 Feb 2021 10:27:23 -0800 Subject: [PATCH] rewrite take --- .../contrib/onnx/mx2onnx/_op_translations.py | 67 ++++++++++++++++--- tests/python-pytest/onnx/test_operators.py | 12 ++++ 2 files changed, 70 insertions(+), 9 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index f472e18d8b26..873cb5c67cb1 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2397,18 +2397,67 @@ def convert_topk(node, **kwargs): def convert_take(node, **kwargs): """Map MXNet's Take operator attributes to onnx's Gather operator. """ + from onnx.helper import make_node + from onnx import TensorProto name, input_nodes, attrs = get_inputs(node, kwargs) - axis = int(attrs.get('axis', 0)) + mode = str(attrs.get('mode', 'clip')) - node = onnx.helper.make_node( - "Gather", - input_nodes, - [name], - axis=axis, - name=name, - ) - return [node] + data = input_nodes[0] + indices = input_nodes[1] + + nodes = [ + make_node('Cast', [indices], [name+'_indices'], to=int(TensorProto.INT64)), + ] + + if mode == 'raise': + nodes += [ + make_node('Gather', [data, name+'_indices'], [name], axis=axis, name=name) + ] + + return nodes + + nodes += [ + create_tensor([-1], name+'_-1', kwargs["initializer"]), + make_node('Shape', [data], [name+'_data_shape']), + ] + + # cornor case + if axis == -1: + nodes += [ + make_node('Shape', [name+'_data_shape'], [name+'_data_dim']), + make_node('Add', [name+'_data_dim', name+'_-1'], [name+'_axis_max']), + make_node('Slice', [name+'_data_shape', name+'_axis_max', name+'_data_dim'], [name+'_slice0_out']), + ] + + else: + nodes += [ + create_tensor([axis], name+'_axis', kwargs["initializer"]), + create_tensor([axis+1], name+'_axis+1', kwargs["initializer"]), + make_node('Slice', [name+'_data_shape', name+'_axis', name+'_axis+1'], [name+'_slice0_out']), + ] + + if mode == 'clip': + nodes += [ + create_tensor([0], name+'_0', kwargs["initializer"]), + make_node('Add', [name+'_slice0_out', name+'_-1'], [name+'_max']), + make_node('Greater', [name+'_indices', name+'_max'], [name+'_max_mask']), + make_node('Where', [name+'_max_mask', name+'_max', name+'_indices'], [name+'_where0_out']), + make_node('Less', [name+'_indices', name+'_0'], [name+'_min_mask']), + make_node('Where', [name+'_min_mask', name+'_0', name+'_where0_out'], [name+'_where1_out']), + make_node('Gather', [data, name+'_where1_out'], [name], axis=axis, name=name) + ] + + elif mode == 'wrap': + nodes += [ + make_node('Mod', [name+'_indices', name+'_slice0_out'], [name+'_mod0_out']), + make_node('Gather', [data, name+'_mod0_out'], [name], axis=axis, name=name) + ] + + else: + raise NotImplementedError("mode must be clip, wrap or raise.") + + return nodes @mx_op.register("LayerNorm") diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 7dc5b5c9c6d3..af4160dd12d3 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -1112,3 +1112,15 @@ def test_onnx_export_argsort(tmp_path, dtype, axis, is_ascend, dtype_i): kwargs['is_ascend'] = is_ascend M = def_model('argsort', axis=axis, dtype=dtype_i, **kwargs) op_export_test('argsort', M, [A], tmp_path) + + +@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64']) +@pytest.mark.parametrize('axis', [-3, -2, -1, 0, 1, 2]) +@pytest.mark.parametrize('mode', ['clip', 'wrap']) +def test_onnx_export_take(tmp_path, dtype, axis, mode): + x = mx.nd.random.normal(0, 10, (3, 4, 5)).astype(dtype) + y = mx.random.randint(-100, 100, (6, 7)).astype(dtype) + M1 = def_model('take') + op_export_test('take1', M1, [x, y], tmp_path) + M2 = def_model('take', axis=axis, mode=mode) + op_export_test('take2', M2, [x, y], tmp_path)