From efa3eb2d527e242c3063acb520cf47fb613d2616 Mon Sep 17 00:00:00 2001 From: Zhaoqi Zhu Date: Thu, 4 Feb 2021 14:49:20 -0800 Subject: [PATCH] new cases (#19835) --- .../contrib/onnx/mx2onnx/_op_translations.py | 38 ++++++++++++++++++- tests/python-pytest/onnx/test_operators.py | 8 ++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index b9a7ef03528c..7e0a24e04e67 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1623,9 +1623,17 @@ def convert_reshape(node, **kwargs): # In general -2, -3, -4 in the target shape are not supoorted, but there are # a few special cases that we can convert to supported scenarios - # If -2 and -3 are not used, then we can just remove the -4 + # If -2 and -3 are not used and there is no 0 to the right of -4, then we can just remove -4 if -4 in targ_shape and -3 not in targ_shape and -2 not in targ_shape and reverse != 'True': - targ_shape = [i for i in targ_shape if i != -4] + if 0 not in targ_shape: + targ_shape = [i for i in targ_shape if i != -4] + else: + # index of first -4 + ind_4 = targ_shape.index(-4) + # index of last 0 + ind0 = len(targ_shape) - 1 - targ_shape[::-1].index(0) + if ind_4 > ind0: + targ_shape = [i for i in targ_shape if i != -4] if targ_shape == [-3, 0] and reverse != 'True': targ_shape = [-1, 0] @@ -1645,6 +1653,32 @@ def convert_reshape(node, **kwargs): ] return nodes + if targ_shape == [0, -4, -1, 4, 0, 0] and reverse != 'True': + create_tensor([4], name+'_4', kwargs['initializer']) + nodes = [ + make_node('Shape', [input_nodes[0]], [name+'_shape']), + make_node('Split', [name+'_shape'], [name+'_dim0', name+'_dim1', name+'_dim2', + name+'_dim3'], axis=0), + make_node('Div', [name+'_dim1', name+'_4'], [name+'_div']), + make_node('Concat', [name+'_dim0', name+'_div', name+'_4', name+'_dim2', name+'_dim3'], + [name+'_shape_new'], axis=0), + make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name) + ] + return nodes + + if targ_shape == [0, 0, -4, 2, 2, 0, 0] and reverse != 'True': + create_tensor([2], name+'_2', kwargs['initializer']) + nodes = [ + make_node('Shape', [input_nodes[0]], [name+'_shape']), + make_node('Split', [name+'_shape'], [name+'_dim0', name+'_dim1', name+'_dim2', + name+'_dim3', name+'_dim4'], axis=0), + make_node('Concat', [name+'_dim0', name+'_dim1', name+'_2', name+'_2', + name+'_dim3', name+'_dim4'], [name+'_shape_new'], axis=0), + make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name) + ] + return nodes + + not_supported_shape = [-2, -3, -4] for val in targ_shape: if val in not_supported_shape: diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index de3b1e93461d..77c317b80557 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -235,6 +235,14 @@ def test_onnx_export_reshape_special_cases(tmp_path, dtype): M5 = def_model('reshape', shape=(0, 0, -3, -3)) op_export_test('reshape_spec_5', M5, [x3], tmp_path) + x4 = mx.nd.ones((5, 8, 6, 7), dtype=dtype) + M6 = def_model('reshape', shape=(0, -4, -1, 4, 0, 0)) + op_export_test('reshape_spec_6', M6, [x4], tmp_path) + + x5 = mx.nd.ones((2, 3, 4, 5, 6), dtype=dtype) + M7 = def_model('reshape', shape=(0, 0, -4, 2, 2, 0, 0)) + op_export_test('reshape_spec_7', M7, [x5], tmp_path) + @pytest.mark.parametrize('dtype', ['int32', 'int64']) def test_onnx_export_embedding(tmp_path, dtype):