Skip to content

Commit

Permalink
[Bugfix][Relay][Frontend] Fix bug in mxnet converter for slick_like (a…
Browse files Browse the repository at this point in the history
…pache#2744)

* Fix bug in mxnet converter for slick_like

* More tolerance for topi_conv2d_NCHWc
  • Loading branch information
icemelon authored and wweic committed Mar 9, 2019
1 parent 789240b commit 6897580
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
9 changes: 8 additions & 1 deletion python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,13 @@ def _mx_slice(inputs, attrs):
return _op.strided_slice(inputs[0], **new_attrs)


def _mx_slice_like(inputs, attrs):
assert len(inputs) == 2
new_attrs = {}
new_attrs["axes"] = attrs.get_int_tuple("axes", None)
return _op.slice_like(*inputs, **new_attrs)


def _mx_slice_axis(inputs, attrs):
assert len(inputs) == 1
shape = ir_pass.infer_type(inputs[0]).checked_type.shape
Expand Down Expand Up @@ -383,7 +390,6 @@ def _mx_proposal(inputs, attrs):
"exp",
"negative",
"reshape_like",
"slice_like",
"zeros_like",
"ones_like",
"where",
Expand Down Expand Up @@ -473,6 +479,7 @@ def _mx_proposal(inputs, attrs):
"BatchNorm_v1" : _mx_batch_norm,
"LRN" : _mx_lrn,
"slice" : _mx_slice,
"slice_like" : _mx_slice_like,
"slice_axis" : _mx_slice_axis,
"SliceChannel" : _mx_split,
"split" : _mx_split,
Expand Down
23 changes: 22 additions & 1 deletion tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,6 @@ def test_forward_scalar_ops():
op_res = intrp.evaluate(new_sym)(a_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())


def test_forward_slice_axis():
def verify(shape, axis, begin, end):
data_np = np.random.uniform(size=shape).astype("float32")
Expand All @@ -354,6 +353,27 @@ def verify(shape, axis, begin, end):
verify((3, 4), 1, -3, -1)
verify((3, 4), -1, -3, -1)

def test_forward_slice_like():
def verify(x_shape, y_shape, axes):
x_np = np.random.uniform(size=x_shape).astype("float32")
y_np = np.random.uniform(size=y_shape).astype("float32")
if axes is None:
ref_res = mx.nd.slice_like(mx.nd.array(x_np), mx.nd.array(y_np))
mx_sym = mx.sym.slice_like(mx.sym.var("x"), mx.sym.var("y"))
else:
ref_res = mx.nd.slice_like(mx.nd.array(x_np), mx.nd.array(y_np), axes=axes)
mx_sym = mx.sym.slice_like(mx.sym.var("x"), mx.sym.var("y"), axes=axes)
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": x_shape, "y": y_shape})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(x_np, y_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((3, 4), (2, 3), None)
verify((3, 4), (2, 3), (0, 1))
verify((3, 4), (2, 3), (0))
verify((3, 4), (2, 3), (-1))


if __name__ == '__main__':
test_forward_mlp()
Expand Down Expand Up @@ -382,3 +402,4 @@ def verify(shape, axis, begin, end):
test_forward_elemwise_ops()
test_forward_scalar_ops()
test_forward_slice_axis()
test_forward_slice_like()
4 changes: 2 additions & 2 deletions topi/tests/python/test_topi_conv2d_NCHWc.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def check_device(device):
name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-3)

# test llvm only for now since conv2d_NCHWc implement is missing in other backend.
for device in ["llvm"]:
Expand Down Expand Up @@ -202,4 +202,4 @@ def test_conv2d_NCHWc():
verify_conv2d_NCHWc(1, 256, 3, 126, 3, 1, 1)

if __name__ == "__main__":
test_conv2d_NCHWc()
test_conv2d_NCHWc()

0 comments on commit 6897580

Please sign in to comment.