diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index ebb58d972b8b..6d6116367b29 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -118,6 +118,7 @@ def convert_string_to_list(string_val): return result_list + def get_boolean_attribute_value(attrs, attr_name): """ Helper function to convert a string version of Boolean attributes to integer for ONNX. @@ -126,21 +127,35 @@ def get_boolean_attribute_value(attrs, attr_name): """ return 1 if attrs.get(attr_name, 0) in ["True", "1"] else 0 -def get_inputs(node, kwargs): + +def get_inputs(node, kwargs, with_shapes=False): """Helper function to get inputs""" name = node["name"] proc_nodes = kwargs["proc_nodes"] index_lookup = kwargs["index_lookup"] + graph_shapes = kwargs["graph_shapes"] inputs = node["inputs"] attrs = node.get("attrs", {}) input_nodes = [] + input_shapes = [] for ip in inputs: input_node_id = index_lookup[ip[0]] - input_nodes.append(proc_nodes[input_node_id].name) + try: + # ip[1] defines which output index to use + input_nodes.append(proc_nodes[input_node_id].output[ip[1]]) + except AttributeError: + # fallback to the name attribute as output if the output attribute does not exist (e.g. for data nodes) + input_nodes.append(proc_nodes[input_node_id].name) + + input_shapes.append(graph_shapes.get(input_nodes[-1])) + + if with_shapes: + return name, input_nodes, input_shapes, attrs return name, input_nodes, attrs + def create_basic_op_node(op_name, node, kwargs): """Helper function to create a basic operator node that doesn't contain op specific attrs""" @@ -154,6 +169,7 @@ def create_basic_op_node(op_name, node, kwargs): ) return [node] + @mx_op.register("null") def convert_weights_and_inputs(node, **kwargs): """Helper function to convert weights and inputs. @@ -1565,7 +1581,7 @@ def convert_slice_axis(node, **kwargs): """Map MXNet's slice_axis operator attributes to onnx's Slice operator and return the created node. """ - name, input_nodes, attrs = get_inputs(node, kwargs) + name, input_nodes, input_shapes, attrs = get_inputs(node, kwargs, with_shapes=True) axes = int(attrs.get("axis")) starts = int(attrs.get("begin")) @@ -1573,7 +1589,7 @@ def convert_slice_axis(node, **kwargs): if not ends or ends == 'None': # ONNX doesn't support None for ends. Since ends=None depicts # length of dimension, passing dimension in this case. - in_shape = kwargs['in_shape'][0] + in_shape = input_shapes[0] ends = in_shape[axes] export_nodes = [] @@ -1612,7 +1628,7 @@ def convert_slice_channel(node, **kwargs): operator based on squeeze_axis attribute and return the created node. """ - name, input_nodes, attrs = get_inputs(node, kwargs) + name, input_nodes, input_shapes, attrs = get_inputs(node, kwargs, with_shapes=True) num_outputs = int(attrs.get("num_outputs")) axis = int(attrs.get("axis", 1)) @@ -1628,7 +1644,7 @@ def convert_slice_channel(node, **kwargs): ) return [node] elif squeeze_axis == 0 and num_outputs > 1: - in_shape = kwargs.get('in_shape')[0] + in_shape = input_shapes[0] split = in_shape[axis] // num_outputs node = onnx.helper.make_node( "Split", diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py index 14aa52b29c6a..8e36685b2d40 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py @@ -116,7 +116,7 @@ def split_params(sym, params): return arg_params, aux_params @staticmethod - def get_outputs(sym, params, in_shape, in_label): + def get_outputs(sym, params, in_shape, in_label, verbose=True): """ Infer output shapes and return dictionary of output name to shape :param :class:`~mxnet.symbol.Symbol` sym: symbol to perform infer shape on @@ -124,6 +124,7 @@ def get_outputs(sym, params, in_shape, in_label): :param list of tuple(int, ...) in_shape: list of all input shapes :param in_label: name of label typically used in loss that may be left in graph. This name is removed from list of inputs required by symbol + :param verbose: If false, info logging messages are deactivated :return: dictionary of output name to shape :rtype: dict of (str, tuple(int, ...)) """ @@ -142,7 +143,8 @@ def get_outputs(sym, params, in_shape, in_label): if name.endswith('_output'): out_names.append(name[:-len('_output')]) else: - logging.info("output '%s' does not end with '_output'", name) + if verbose: + logging.info("output '%s' does not end with '_output'", name) out_names.append(name) assert len(out_shapes) == len(out_names) @@ -203,8 +205,9 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False) onnx_processed_outputs = [] index_lookup = [] - # Determine output shape + # Determine output and internal shapes graph_outputs = MXNetGraph.get_outputs(sym, params, in_shape, output_label) + graph_shapes = MXNetGraph.get_outputs(sym.get_internals(), params, in_shape, output_label, verbose=False) graph_input_idx = 0 for idx, node in enumerate(mx_graph): @@ -230,6 +233,7 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False) in_shape=in_shape[graph_input_idx], in_type=in_type, proc_nodes=all_processed_nodes, + graph_shapes=graph_shapes, initializer=initializer, index_lookup=index_lookup) graph_input_idx += 1 @@ -244,6 +248,7 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False) in_shape=in_shape, in_type=in_type, proc_nodes=all_processed_nodes, + graph_shapes=graph_shapes, initializer=initializer, index_lookup=index_lookup, idx=idx diff --git a/tests/python/unittest/onnx/mxnet_export_test.py b/tests/python/unittest/onnx/mxnet_export_test.py index 40d7d4e3e072..3b3f1c5ba1ee 100644 --- a/tests/python/unittest/onnx/mxnet_export_test.py +++ b/tests/python/unittest/onnx/mxnet_export_test.py @@ -28,6 +28,7 @@ from mxnet import nd, sym from mxnet.test_utils import set_default_context from mxnet.gluon import nn +from mxnet.gluon import HybridBlock from mxnet.contrib import onnx as onnx_mxnet import mxnet as mx @@ -80,6 +81,16 @@ def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params= mx.test_utils.assert_almost_equal(out, imp_out, atol=1e-5, rtol=1e-5) +class SplitConcatBlock(HybridBlock): + """Block which creates two splits and later concatenates them""" + def __init__(self, name): + super(SplitConcatBlock, self).__init__(name) + + def hybrid_forward(self, F, x): + splits = F.split(x, axis=1, num_outputs=2) + return F.concat(*splits) + + class TestExport(unittest.TestCase): """ Tests ONNX export. """ @@ -126,3 +137,17 @@ def test_onnx_export_extra_params(self): net.add(nn.Dense(100, activation='relu'), nn.Dense(10)) _check_onnx_export(net, extra_params={'extra_param': nd.array([1, 2])}) + @with_seed() + def test_onnx_export_slice(self): + net = nn.HybridSequential(prefix='slice_net') + with net.name_scope(): + net.add(nn.Dense(100, activation='relu'), SplitConcatBlock("splitConcat"), nn.Dense(10)) + _check_onnx_export(net) + + @with_seed() + def test_onnx_export_slice_changing_shape(self): + net = nn.HybridSequential(prefix='slice_net_changing_shape') + with net.name_scope(): + net.add(nn.Dense(100, activation='relu'), SplitConcatBlock("splitConcat"), + nn.Dense(50, activation='relu'), SplitConcatBlock("splitConcat2"), nn.Dense(10)) + _check_onnx_export(net)