diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index edb23f624a113..26db88a4fdc46 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -317,7 +317,7 @@ def _mx_l2_normalize(inputs, attrs): if mode != 'channel': raise RuntimeError('mode %s is not supported.' % mode) new_attrs['eps'] = attrs.get_float('eps', 1e-10) - new_attrs['axis'] = 1 + new_attrs['axis'] = [1] return _op.nn.l2_normalize(inputs[0], **new_attrs) diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 81a12b041ed7d..45e1cdba0a029 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -190,6 +190,21 @@ def test_forward_argmin(): mx_sym = mx.sym.argmin(data, axis=0) verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,)) +def test_forward_box_nms(): + data = mx.sym.var('data') + mx_sym = mx.sym.contrib.box_nms(data, topk=-1, force_suppress=True) + verify_mxnet_frontend_impl(mx_sym, (3, 100, 6), (3, 100, 6)) + +def test_forward_slice_axis(): + data = mx.sym.var('data') + mx_sym = mx.sym.slice_axis(data, axis=1, begin=-5, end=None) + verify_mxnet_frontend_impl(mx_sym, (1, 10, 6), (1, 5, 6)) + +def test_forward_l2_normalize(): + data = mx.sym.var('data') + mx_sym = mx.sym.L2Normalization(data, mode="channel") + verify_mxnet_frontend_impl(mx_sym, (2, 3, 4, 5), (2, 3, 4, 5)) + if __name__ == '__main__': test_forward_mlp() @@ -212,3 +227,6 @@ def test_forward_argmin(): test_forward_zeros_like() test_forward_argmax() test_forward_argmin() + test_forward_box_nms() + test_forward_slice_axis() + test_forward_l2_normalize()