diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 9e13b053894e..37fc5422400f 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1758,7 +1758,7 @@ def convert_slice_channel(node, **kwargs): num_outputs = int(attrs.get("num_outputs")) axis = int(attrs.get("axis", 1)) - squeeze_axis = int(attrs.get("squeeze_axis", 0)) + squeeze_axis = int(attrs.get("squeeze_axis", 0) in [1, 'True']) if squeeze_axis == 1 and num_outputs == 1: node = onnx.helper.make_node( @@ -1810,17 +1810,22 @@ def convert_squeeze(node, **kwargs): axis = attrs.get("axis", None) if not axis: - raise AttributeError("Squeeze: Missing axis attribute: ONNX currently requires axis to " - "be specified for squeeze operator") - axis = convert_string_to_list(axis) + node = onnx.helper.make_node( + "Squeeze", + input_nodes, + [name], + name=name + ) + else: + axis = convert_string_to_list(axis) - node = onnx.helper.make_node( - "Squeeze", - input_nodes, - [name], - axes=axis, - name=name, - ) + node = onnx.helper.make_node( + "Squeeze", + input_nodes, + [name], + axes=axis, + name=name, + ) return [node] @@ -3141,8 +3146,7 @@ def convert_greater_scalar(node, **kwargs): tensor_value = make_tensor(name+"_scalar", input_type, [1], [scalar]) nodes = [ - make_node("Shape", [input_nodes[0]], [name+"_shape"]), - make_node("ConstantOfShape", [name+"_shape"], [name+"_rhs"], value=tensor_value), + make_node("Constant", [], [name+"_rhs"], value=tensor_value), make_node("Greater", [input_nodes[0], name+"_rhs"], [name+"_gt"]), make_node("Cast", [name+"_gt"], [name], to=input_type, name=name) ] @@ -3171,14 +3175,41 @@ def convert_lesser_scalar(node, **kwargs): tensor_value = make_tensor(name+"_scalar", input_type, [1], [scalar]) nodes = [ - make_node("Shape", [input_nodes[0]], [name+"_shape"]), - make_node("ConstantOfShape", [name+"_shape"], [name+"_rhs"], value=tensor_value), + make_node("Constant", [], [name+"_rhs"], value=tensor_value), make_node("Less", [input_nodes[0], name+"_rhs"], [name+"_lt"]), make_node("Cast", [name+"_lt"], [name], to=input_type, name=name) ] return nodes +@mx_op.register("_equal_scalar") +def convert_equal_scalar(node, **kwargs): + """Map MXNet's equal_scalar operator attributes to onnx. + """ + from onnx.helper import make_node, make_tensor + name, input_nodes, attrs = get_inputs(node, kwargs) + + scalar = float(attrs.get('scalar')) + input_type = kwargs['in_type'] + dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type] + + if str(dtype).startswith('int'): + scalar = int(scalar) + else: + if dtype == 'float16': + # when using float16, we must convert it to np.uint16 view first + # pylint: disable=too-many-function-args + scalar = np.float16(scalar).view(np.uint16) + + tensor_value = make_tensor(name+"_scalar", input_type, [1], [scalar]) + nodes = [ + make_node("Constant", [], [name+"_rhs"], value=tensor_value), + make_node("Equal", [input_nodes[0], name+"_rhs"], [name+"_eq"]), + make_node("Cast", [name+"_eq"], [name], to=input_type, name=name) + ] + return nodes + + @mx_op.register("where") def convert_where(node, **kwargs): """Map MXNet's where operator attributes to onnx's Where diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 8d09d9896b31..049f33e36bbc 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -511,6 +511,18 @@ def test_onnx_export_lesser_scalar(tmp_path, dtype, scalar): op_export_test('_internal._lesser_scalar', M, [x], tmp_path) +@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"]) +@pytest.mark.parametrize("scalar", [0., 0.1, 0.5, 1., 5, 555.]) +def test_onnx_export_equal_scalar(tmp_path, dtype, scalar): + if 'int' in dtype: + scalar = int(scalar) + x = mx.nd.arange(0, 12, dtype=dtype).reshape((3, 4)) + else: + x = mx.random.uniform(0, 9999, (5,10), dtype=dtype) + M = def_model('_internal._equal_scalar', scalar=scalar) + op_export_test('_internal._equal_scalar', M, [x], tmp_path) + + @pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"]) @pytest.mark.parametrize("shape", [(1,1), (3,3), (10,2), (20,30,40)]) def test_onnx_export_where(tmp_path, dtype, shape):