Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

ONNX export: Support equal length splits #14121

Merged
merged 3 commits into from
Feb 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1537,12 +1537,14 @@ def convert_slice_channel(node, **kwargs):
)
return [node]
elif squeeze_axis == 0 and num_outputs > 1:
in_shape = kwargs.get('in_shape')[0]
split = in_shape[axis] // num_outputs
node = onnx.helper.make_node(
"Split",
input_nodes,
[name],
[name+'_output'+str(i) for i in range(num_outputs)],
axis=axis,
split=[num_outputs],
split=[split for _ in range(num_outputs)],
name=name,
)
return [node]
Expand Down
23 changes: 13 additions & 10 deletions python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,17 +262,20 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
# If converted node is NodeProto, add it in processed nodes list
elif isinstance(converted_node, NodeProto):
onnx_processed_nodes.append(converted_node)
node_name = converted_node.name if converted_node.name else converted_node.output[0]
if node_name in graph_outputs:
onnx_processed_outputs.append(
make_tensor_value_info(
name=node_name,
elem_type=in_type,
shape=graph_outputs[node_name]
# some operators have multiple outputs,
# therefore, check all output node names
node_names = list(converted_node.output)
for nodename in node_names:
vandanavk marked this conversation as resolved.
Show resolved Hide resolved
if nodename in graph_outputs:
onnx_processed_outputs.append(
make_tensor_value_info(
name=nodename,
elem_type=in_type,
shape=graph_outputs[nodename]
)
)
)
if verbose:
logging.info("Output node is: %s", converted_node.name)
if verbose:
logging.info("Output node is: %s", nodename)
elif isinstance(converted_node, TensorProto):
raise ValueError("Did not expect TensorProto")
else:
Expand Down
12 changes: 7 additions & 5 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,13 +484,15 @@ def split(attrs, inputs, proto_obj):
if not split_list:
num_outputs = len(proto_obj.model_metadata.get('output_tensor_data'))
else:
raise NotImplementedError("Operator {} in MXNet does not support variable splits."
"Tracking the issue to support variable split here: "
"https://github.com/apache/incubator-mxnet/issues/11594"
.format('split'))
if len(set(split_list)) == 1:
num_outputs = len(split_list)
else:
raise NotImplementedError("Operator {} in MXNet does not support variable splits."
"Tracking the issue to support variable split here: "
"https://github.com/apache/incubator-mxnet/issues/11594"
.format('split'))

new_attrs['num_outputs'] = num_outputs

return 'split', new_attrs, inputs

def _slice(attrs, inputs, proto_obj):
Expand Down
4 changes: 2 additions & 2 deletions tests/python-pytest/onnx/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@
'test_elu',
'test_max_',
'test_softplus',
'test_reduce_'
'test_reduce_',
'test_split_equal'
],
'import': ['test_gather',
'test_softsign',
Expand All @@ -88,7 +89,6 @@
'test_averagepool_2d_precomputed_strides',
'test_averagepool_2d_strides',
'test_averagepool_3d',
'test_split_equal',
'test_hardmax'
],
'export': ['test_random_uniform',
Expand Down