Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] ONNX support for broadcast_mod (#19770)
Browse files Browse the repository at this point in the history
* broadcast_mod

* Update _op_translations.py

* Update _op_translations.py

* improve tests
  • Loading branch information
Zha0q1 authored Jan 22, 2021
1 parent 93ae5cf commit c31bed1
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
28 changes: 28 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3107,6 +3107,34 @@ def convert_contrib_AdaptiveAvgPooling2D(node, **kwargs):
return nodes


@mx_op.register('broadcast_mod')
def convert_broadcast_mod(node, **kwargs):
"""Map MXNet's broadcast_mod operator
"""
from onnx.helper import make_node
name, input_nodes, _ = get_inputs(node, kwargs)

# The behavior of MXNet mod is a mixture of np.mod and np.fmod
# note: the behavior when divison by 0 is supposed to be platform dependent
# but here we set the result to 0 to be consistent with MXNet
nodes = [
make_node('Sub', [input_nodes[1], input_nodes[1]], [name+'_zero']),
make_node('Mod', [input_nodes[0], input_nodes[1]], [name+'_mod'], fmod=1),
make_node('Less', [input_nodes[0], name+'_zero'], [name+'_mask_0']),
make_node('Less', [input_nodes[1], name+'_zero'], [name+'_mask_1']),
make_node('Equal', [name+'_mod', name+'_zero'], [name+'_mask_2_']),
make_node('Not', [name+'_mask_2_'], [name+'_mask_2']),
make_node('Xor', [name+'_mask_0', name+'_mask_1'], [name+'_mask_']),
make_node('And', [name+'_mask_', name+'_mask_2'], [name+'_mask']),
make_node('Where', [name+'_mask', input_nodes[1], name+'_zero'], [name+'_adjustment']),
make_node('Add', [name+'_mod', name+'_adjustment'], [name+'_adjusted']),
make_node('Equal', [input_nodes[1], name+'_zero'], [name+'_mask_div_0']),
make_node('Where', [name+'_mask_div_0', name+'_zero', name+'_adjusted'], [name])
]

return nodes


@mx_op.register("reshape_like")
def convert_reshape_like(node, **kwargs):
"""Map MXNet's reshape_like operator attributes to onnx's operator.
Expand Down
11 changes: 11 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,17 @@ def test_onnx_export_contrib_AdaptiveAvgPooling2D(tmp_path, dtype):
op_export_test('contrib_AdaptiveAvgPooling2D', M4, [x], tmp_path)


@pytest.mark.parametrize('dtype', ['float16', 'float32', 'int32', 'int64'])
@pytest.mark.parametrize('shapes', [((3, 3, 3), (1, 3)), ((4, 5, 6, 7), (6, 7))])
def test_onnx_export_broadcast_mod(tmp_path, dtype, shapes):
A = mx.nd.random.uniform(-300, 300, shapes[0]).astype(dtype)
B = mx.nd.random.uniform(-30, 30, shapes[1]).astype(dtype)
# test when dividend is zero
B[-1] = 0
M = def_model('broadcast_mod')
op_export_test('broadcast_mod', M, [A, B], tmp_path)


@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
def test_onnx_export_reshape_like(tmp_path, dtype):
if 'int' in dtype:
Expand Down

0 comments on commit c31bed1

Please sign in to comment.