Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
new cases (#19835)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zha0q1 authored Feb 4, 2021
1 parent 8216e20 commit efa3eb2
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 2 deletions.
38 changes: 36 additions & 2 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1623,9 +1623,17 @@ def convert_reshape(node, **kwargs):
# In general -2, -3, -4 in the target shape are not supoorted, but there are
# a few special cases that we can convert to supported scenarios

# If -2 and -3 are not used, then we can just remove the -4
# If -2 and -3 are not used and there is no 0 to the right of -4, then we can just remove -4
if -4 in targ_shape and -3 not in targ_shape and -2 not in targ_shape and reverse != 'True':
targ_shape = [i for i in targ_shape if i != -4]
if 0 not in targ_shape:
targ_shape = [i for i in targ_shape if i != -4]
else:
# index of first -4
ind_4 = targ_shape.index(-4)
# index of last 0
ind0 = len(targ_shape) - 1 - targ_shape[::-1].index(0)
if ind_4 > ind0:
targ_shape = [i for i in targ_shape if i != -4]

if targ_shape == [-3, 0] and reverse != 'True':
targ_shape = [-1, 0]
Expand All @@ -1645,6 +1653,32 @@ def convert_reshape(node, **kwargs):
]
return nodes

if targ_shape == [0, -4, -1, 4, 0, 0] and reverse != 'True':
create_tensor([4], name+'_4', kwargs['initializer'])
nodes = [
make_node('Shape', [input_nodes[0]], [name+'_shape']),
make_node('Split', [name+'_shape'], [name+'_dim0', name+'_dim1', name+'_dim2',
name+'_dim3'], axis=0),
make_node('Div', [name+'_dim1', name+'_4'], [name+'_div']),
make_node('Concat', [name+'_dim0', name+'_div', name+'_4', name+'_dim2', name+'_dim3'],
[name+'_shape_new'], axis=0),
make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name)
]
return nodes

if targ_shape == [0, 0, -4, 2, 2, 0, 0] and reverse != 'True':
create_tensor([2], name+'_2', kwargs['initializer'])
nodes = [
make_node('Shape', [input_nodes[0]], [name+'_shape']),
make_node('Split', [name+'_shape'], [name+'_dim0', name+'_dim1', name+'_dim2',
name+'_dim3', name+'_dim4'], axis=0),
make_node('Concat', [name+'_dim0', name+'_dim1', name+'_2', name+'_2',
name+'_dim3', name+'_dim4'], [name+'_shape_new'], axis=0),
make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name)
]
return nodes


not_supported_shape = [-2, -3, -4]
for val in targ_shape:
if val in not_supported_shape:
Expand Down
8 changes: 8 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,14 @@ def test_onnx_export_reshape_special_cases(tmp_path, dtype):
M5 = def_model('reshape', shape=(0, 0, -3, -3))
op_export_test('reshape_spec_5', M5, [x3], tmp_path)

x4 = mx.nd.ones((5, 8, 6, 7), dtype=dtype)
M6 = def_model('reshape', shape=(0, -4, -1, 4, 0, 0))
op_export_test('reshape_spec_6', M6, [x4], tmp_path)

x5 = mx.nd.ones((2, 3, 4, 5, 6), dtype=dtype)
M7 = def_model('reshape', shape=(0, 0, -4, 2, 2, 0, 0))
op_export_test('reshape_spec_7', M7, [x5], tmp_path)


@pytest.mark.parametrize('dtype', ['int32', 'int64'])
def test_onnx_export_embedding(tmp_path, dtype):
Expand Down

0 comments on commit efa3eb2

Please sign in to comment.