From c2ab10263e607c139956d3008f87185223785ce0 Mon Sep 17 00:00:00 2001
From: waytrue17 <52505574+waytrue17@users.noreply.github.com>
Date: Tue, 25 May 2021 17:45:42 -0700
Subject: [PATCH] fix embedding and output order (#20305)

Co-authored-by: Wei Chu
---
 python/mxnet/onnx/mx2onnx/_export_onnx.py           | 14 ++++++++++++--
 .../_op_translations/_op_translations_opset12.py    |  3 ++-
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/python/mxnet/onnx/mx2onnx/_export_onnx.py b/python/mxnet/onnx/mx2onnx/_export_onnx.py
index 3af870e6a557..20bf2fe7f980 100644
--- a/python/mxnet/onnx/mx2onnx/_export_onnx.py
+++ b/python/mxnet/onnx/mx2onnx/_export_onnx.py
@@ -392,8 +392,18 @@ def __init__(self, name, dtype):
         # if node_output_names is empty then we use the last returned node as output
         if not node_output_names:
             node_output_names = [converted[-1].name]
-        # process node outputs (sort by alphabetical order)
-        node_output_names.sort()
+        # process node outputs (stable in-place sort by trailing output index)
+        def str2int(s):
+            """Return the last one or two trailing digits of `s` as an int, or 0 if none."""
+            import re
+            i = re.search(r'\d{0,2}$', s).group()
+            if i == '':
+                return 0
+            else:
+                return int(i)
+
+        node_output_names.sort(key=str2int)
+
         # match the output names to output dtypes
         if dtypes is not None:
             assert len(node_output_names) == len(dtypes)
diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
index b73c5bf837cc..4799a1df5837 100644
--- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
+++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
@@ -3094,13 +3094,14 @@ def convert_embedding(node, **kwargs):
 
     name, input_nodes, attrs = get_inputs(node, kwargs)
     axis = int(attrs.get('axis', 0))
+    dtype = str(attrs.get('dtype', 'float32'))
 
     nodes = [
         make_node('Cast', [input_nodes[0]], [name+'_indices_casted'], to=int(TensorProto.INT64)),
         make_node('Gather', [input_nodes[1], name+'_indices_casted'], [name], axis=axis, name=name)
         ]
 
-    return nodes
+    return nodes, (dtype, )
 
 
 @mx_op.register("stack")