From af1d611ed2321a484f4328474893520de64a8c15 Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Mon, 11 Feb 2019 13:10:40 -0800 Subject: [PATCH 1/3] ONNX export: Support equal length splits --- .../contrib/onnx/mx2onnx/_op_translations.py | 6 ++++-- .../mxnet/contrib/onnx/mx2onnx/export_onnx.py | 21 ++++++++++--------- .../contrib/onnx/onnx2mx/_op_translations.py | 6 ++++-- tests/python-pytest/onnx/test_cases.py | 4 ++-- 4 files changed, 21 insertions(+), 16 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 8e3c46dceb42..32a8a3e3ca82 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1537,12 +1537,14 @@ def convert_slice_channel(node, **kwargs): ) return [node] elif squeeze_axis == 0 and num_outputs > 1: + in_shape = kwargs.get('in_shape')[0] + split = in_shape[axis] // num_outputs node = onnx.helper.make_node( "Split", input_nodes, - [name], + [name+'_output'+str(i) for i in range(num_outputs)], axis=axis, - split=[num_outputs], + split=[split for _ in range(num_outputs)], name=name, ) return [node] diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py index d0d4501d89f4..5c636f83ac66 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py @@ -262,17 +262,18 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False) # If converted node is NodeProto, add it in processed nodes list elif isinstance(converted_node, NodeProto): onnx_processed_nodes.append(converted_node) - node_name = converted_node.name if converted_node.name else converted_node.output[0] - if node_name in graph_outputs: - onnx_processed_outputs.append( - make_tensor_value_info( - name=node_name, - elem_type=in_type, - shape=graph_outputs[node_name] + node_names = list(converted_node.output) + for nodename in node_names: + if nodename in graph_outputs: + onnx_processed_outputs.append( + make_tensor_value_info( + name=nodename, + elem_type=in_type, + shape=graph_outputs[nodename] + ) ) - ) - if verbose: - logging.info("Output node is: %s", converted_node.name) + if verbose: + logging.info("Output node is: %s", nodename) elif isinstance(converted_node, TensorProto): raise ValueError("Did not expect TensorProto") else: diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index dc00feee815b..37905bad7358 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -484,13 +484,15 @@ def split(attrs, inputs, proto_obj): if not split_list: num_outputs = len(proto_obj.model_metadata.get('output_tensor_data')) else: - raise NotImplementedError("Operator {} in MXNet does not support variable splits." + if len(set(split_list)) == 1: + num_outputs = len(split_list) + else: + raise NotImplementedError("Operator {} in MXNet does not support variable splits." "Tracking the issue to support variable split here: " "https://github.com/apache/incubator-mxnet/issues/11594" .format('split')) new_attrs['num_outputs'] = num_outputs - return 'split', new_attrs, inputs def _slice(attrs, inputs, proto_obj): diff --git a/tests/python-pytest/onnx/test_cases.py b/tests/python-pytest/onnx/test_cases.py index b20db23aa1fd..89b60d15e84f 100644 --- a/tests/python-pytest/onnx/test_cases.py +++ b/tests/python-pytest/onnx/test_cases.py @@ -77,7 +77,8 @@ 'test_elu', 'test_max_', 'test_softplus', - 'test_reduce_' + 'test_reduce_', + 'test_split_equal' ], 'import': ['test_gather', 'test_softsign', @@ -88,7 +89,6 @@ 'test_averagepool_2d_precomputed_strides', 'test_averagepool_2d_strides', 'test_averagepool_3d', - 'test_split_equal', 'test_hardmax' ], 'export': ['test_random_uniform', From 8c901e80082820e1adbe3ae0290de398ca72b63a Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Mon, 11 Feb 2019 14:05:39 -0800 Subject: [PATCH 2/3] Fix lint error --- python/mxnet/contrib/onnx/onnx2mx/_op_translations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index 37905bad7358..a7cef7674496 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -488,9 +488,9 @@ def split(attrs, inputs, proto_obj): num_outputs = len(split_list) else: raise NotImplementedError("Operator {} in MXNet does not support variable splits." - "Tracking the issue to support variable split here: " - "https://github.com/apache/incubator-mxnet/issues/11594" - .format('split')) + "Tracking the issue to support variable split here: " + "https://github.com/apache/incubator-mxnet/issues/11594" + .format('split')) new_attrs['num_outputs'] = num_outputs return 'split', new_attrs, inputs From cafe86574d05a3d506ceb04149d67ee375849aaa Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Tue, 12 Feb 2019 10:32:02 -0800 Subject: [PATCH 3/3] Add comment about checking for multiple outputs --- python/mxnet/contrib/onnx/mx2onnx/export_onnx.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py index 5c636f83ac66..a7b11fc902db 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py @@ -262,6 +262,8 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False) # If converted node is NodeProto, add it in processed nodes list elif isinstance(converted_node, NodeProto): onnx_processed_nodes.append(converted_node) + # some operators have multiple outputs, + # therefore, check all output node names node_names = list(converted_node.output) for nodename in node_names: if nodename in graph_outputs: