Skip to content

Commit

Permalink
tests updated
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurawly committed Mar 13, 2019
1 parent b74570a commit 9842d9f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
6 changes: 6 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,12 @@ bool TileRel(const Array<Type>& types,
CHECK(reps.defined())
<< "repetition array is not defined. data.ndim = " << ndim;
const size_t rndim = reps.size();
for (size_t i = 0; i < rndim; ++i) {
if (const tvm::ir::IntImm* val = reps[i].as<tvm::ir::IntImm>()) {
CHECK_GT(val->value, 0)
<< "Tile reps value should always be larger than 0, but get: " << val->value;
}
}
size_t tndim = (ndim > rndim) ? ndim : rndim;
// re-construct data shape or reps shape
std::vector<IndexExpr> data_shape;
Expand Down
21 changes: 11 additions & 10 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,9 +504,9 @@ def verify_tile(dshape, reps):
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_tile((2, 3, 4), (0, 2, 1))
verify_tile((2, 3, 4), (0, 2))
verify_tile((2, 3), (0, 2, 1))
verify_tile((2, 3, 4), (3, 2, 1))
verify_tile((2, 3, 4), (1, 2))
verify_tile((2, 3), (3, 2, 1))

def test_repeat():
def verify_repeat(dshape, repeats, axis):
Expand All @@ -525,21 +525,22 @@ def verify_repeat(dshape, repeats, axis):

def test_stack():
def verify_stack(dshapes, axis):
dshapes = np.array(dshapes, dtype="int32")
x = relay.var("input_shapes", relay.TensorType(dshapes.shape, "int32"))
y = []
for shape in dshapes:
y.append(relay.var("input", relay.TensorType(shape, "float32")))
x = relay.Tuple(y)
z = relay.stack(x, axis=axis)

func = relay.Function([x], z)
func = relay.Function(y, z)
x_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes]
ref_res = np.stack(x_data, axis=axis)

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
op_res = intrp.evaluate(func)(*x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_stack([[2,], [2,], [2,]], -1)
verify_stack([(2,), (2,), (2,)], 1)
verify_stack([(2,), (2,), (2,)], -1)
verify_stack([(2,), (2,), (2,)], 0)
verify_stack([(2, 2, 4), (2, 2, 4), (2, 2, 4)], 1)
verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1)
Expand Down Expand Up @@ -571,5 +572,5 @@ def verify_stack(dshapes, axis):
test_split_infer_type()
test_arange()
test_stack()
test_repeat()
test_tile()
test_repeat()

0 comments on commit 9842d9f

Please sign in to comment.