diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 4ca875023f912..a0ea8f2e60a36 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1137,6 +1137,12 @@ bool TileRel(const Array& types, CHECK(reps.defined()) << "repetition array is not defined. data.ndim = " << ndim; const size_t rndim = reps.size(); + for (size_t i = 0; i < rndim; ++i) { + if (const tvm::ir::IntImm* val = reps[i].as()) { + CHECK_GT(val->value, 0) + << "Tile reps value should always be larger than 0, but get: " << val->value; + } + } size_t tndim = (ndim > rndim) ? ndim : rndim; // re-construct data shape or reps shape std::vector data_shape; diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index ec92c0a9a0eb8..1933c97c48b61 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -504,9 +504,9 @@ def verify_tile(dshape, reps): intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) - verify_tile((2, 3, 4), (0, 2, 1)) - verify_tile((2, 3, 4), (0, 2)) - verify_tile((2, 3), (0, 2, 1)) + verify_tile((2, 3, 4), (3, 2, 1)) + verify_tile((2, 3, 4), (1, 2)) + verify_tile((2, 3), (3, 2, 1)) def test_repeat(): def verify_repeat(dshape, repeats, axis): @@ -525,21 +525,22 @@ def verify_repeat(dshape, repeats, axis): def test_stack(): def verify_stack(dshapes, axis): - dshapes = np.array(dshapes, dtype="int32") - x = relay.var("input_shapes", relay.TensorType(dshapes.shape, "int32")) + y = [] + for shape in dshapes: + y.append(relay.var("input", relay.TensorType(shape, "float32"))) + x = relay.Tuple(y) z = relay.stack(x, axis=axis) - func = relay.Function([x], z) + func = relay.Function(y, z) x_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes] ref_res = np.stack(x_data, axis=axis) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(func)(x_data) + op_res = intrp.evaluate(func)(*x_data) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) - verify_stack([[2,], [2,], [2,]], -1) - verify_stack([(2,), (2,), (2,)], 1) + verify_stack([(2,), (2,), (2,)], -1) verify_stack([(2,), (2,), (2,)], 0) verify_stack([(2, 2, 4), (2, 2, 4), (2, 2, 4)], 1) verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1) @@ -591,5 +592,5 @@ def verify_reverse(dshape, axis): test_arange() test_reverse() test_stack() - test_repeat() test_tile() + test_repeat()