diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 5c8e726f7be3..e6aa5f1fb05d 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -127,6 +127,17 @@ def _mx_unravel_index(inputs, attrs): return _op.unravel_index(inputs[0], shape_expr) +def _mx_swap_axis(inputs, attrs): + assert len(inputs) == 1 + dim1 = attrs.get_int('dim1') + dim2 = attrs.get_int('dim2') + shape = _infer_type(inputs[0]).checked_type.shape + axes = list(range(len(shape))) + axes[dim1] = dim2 + axes[dim2] = dim1 + return _op.transpose(inputs[0], axes=axes) + + def _mx_zeros(inputs, attrs): assert len(inputs) == 0 shape = attrs.get_int_tuple("shape") @@ -1813,6 +1824,7 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): "slice_axis" : _mx_slice_axis, "SliceChannel" : _mx_split, "split" : _mx_split, + "SwapAxis" : _mx_swap_axis, "expand_dims" : _mx_expand_dims, "Concat" : _mx_concat, "concat" : _mx_concat, diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index eb308c574858..4a9848e03b5e 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -983,6 +983,18 @@ def verify(x, shape, dtype): # verify([0, 1, 2, 5], [2, 2], dtype) +def test_forward_swap_axis(): + def _verify_swap_axis(in_shape, out_shape, dim1, dim2): + data = mx.sym.var('data') + mx_sym = mx.sym.swapaxes(data, dim1, dim2) + verify_mxnet_frontend_impl(mx_sym, in_shape, out_shape) + + _verify_swap_axis((4, 5), (5, 4), 0, 1) + _verify_swap_axis((2, 4, 4, 5), (2, 5, 4, 4), 1, 3) + # MXNet errors out when dim1 == dim2 + # _verify_swap_axis((4, 5), (5, 4), 0, 0) + + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -1040,3 +1052,4 @@ def verify(x, shape, dtype): test_forward_cond() test_forward_make_loss() test_forward_unravel_index() + test_forward_swap_axis() \ No newline at end of file