From 6897580c3ff73d78fd721866bd5eca5ce9a6b984 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Thu, 7 Mar 2019 14:45:30 -0800 Subject: [PATCH] [Bugfix][Relay][Frontend] Fix bug in mxnet converter for slick_like (#2744) * Fix bug in mxnet converter for slick_like * More tolerance for topi_conv2d_NCHWc --- python/tvm/relay/frontend/mxnet.py | 9 +++++++- tests/python/frontend/mxnet/test_forward.py | 23 ++++++++++++++++++++- topi/tests/python/test_topi_conv2d_NCHWc.py | 4 ++-- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 45329e1b3fe5..2e0ccd07fdc1 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -194,6 +194,13 @@ def _mx_slice(inputs, attrs): return _op.strided_slice(inputs[0], **new_attrs) +def _mx_slice_like(inputs, attrs): + assert len(inputs) == 2 + new_attrs = {} + new_attrs["axes"] = attrs.get_int_tuple("axes", None) + return _op.slice_like(*inputs, **new_attrs) + + def _mx_slice_axis(inputs, attrs): assert len(inputs) == 1 shape = ir_pass.infer_type(inputs[0]).checked_type.shape @@ -383,7 +390,6 @@ def _mx_proposal(inputs, attrs): "exp", "negative", "reshape_like", - "slice_like", "zeros_like", "ones_like", "where", @@ -473,6 +479,7 @@ def _mx_proposal(inputs, attrs): "BatchNorm_v1" : _mx_batch_norm, "LRN" : _mx_lrn, "slice" : _mx_slice, + "slice_like" : _mx_slice_like, "slice_axis" : _mx_slice_axis, "SliceChannel" : _mx_split, "split" : _mx_split, diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 74a87e29a0c0..2dfe20c503e6 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -336,7 +336,6 @@ def test_forward_scalar_ops(): op_res = intrp.evaluate(new_sym)(a_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) - def test_forward_slice_axis(): def verify(shape, axis, begin, end): data_np = np.random.uniform(size=shape).astype("float32") @@ -354,6 +353,27 @@ def verify(shape, axis, begin, end): verify((3, 4), 1, -3, -1) verify((3, 4), -1, -3, -1) +def test_forward_slice_like(): + def verify(x_shape, y_shape, axes): + x_np = np.random.uniform(size=x_shape).astype("float32") + y_np = np.random.uniform(size=y_shape).astype("float32") + if axes is None: + ref_res = mx.nd.slice_like(mx.nd.array(x_np), mx.nd.array(y_np)) + mx_sym = mx.sym.slice_like(mx.sym.var("x"), mx.sym.var("y")) + else: + ref_res = mx.nd.slice_like(mx.nd.array(x_np), mx.nd.array(y_np), axes=axes) + mx_sym = mx.sym.slice_like(mx.sym.var("x"), mx.sym.var("y"), axes=axes) + new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": x_shape, "y": y_shape}) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(new_sym)(x_np, y_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + verify((3, 4), (2, 3), None) + verify((3, 4), (2, 3), (0, 1)) + verify((3, 4), (2, 3), (0)) + verify((3, 4), (2, 3), (-1)) + if __name__ == '__main__': test_forward_mlp() @@ -382,3 +402,4 @@ def verify(shape, axis, begin, end): test_forward_elemwise_ops() test_forward_scalar_ops() test_forward_slice_axis() + test_forward_slice_like() diff --git a/topi/tests/python/test_topi_conv2d_NCHWc.py b/topi/tests/python/test_topi_conv2d_NCHWc.py index a3af43c8d810..73c1fdae2d66 100644 --- a/topi/tests/python/test_topi_conv2d_NCHWc.py +++ b/topi/tests/python/test_topi_conv2d_NCHWc.py @@ -105,7 +105,7 @@ def check_device(device): name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) func(a, w, c) - tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-3) # test llvm only for now since conv2d_NCHWc implement is missing in other backend. for device in ["llvm"]: @@ -202,4 +202,4 @@ def test_conv2d_NCHWc(): verify_conv2d_NCHWc(1, 256, 3, 126, 3, 1, 1) if __name__ == "__main__": - test_conv2d_NCHWc() \ No newline at end of file + test_conv2d_NCHWc()