Commit

tile data shape 0 error fixed; relay tests added
Laurawly committed Mar 12, 2019
1 parent e2dbbf2 commit da6985e
Showing 3 changed files with 68 additions and 6 deletions.
12 changes: 8 additions & 4 deletions src/relay/op/tensor/transform.cc
@@ -363,7 +363,7 @@ RELAY_REGISTER_OP("stack")
.set_attrs_type_key("relay.attrs.StackAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input list of tensors.")
-.set_support_level(1)
+.set_support_level(3)
.add_type_rel("Stack", StackRel)
.set_attr<FTVMCompute>("FTVMCompute", StackCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
@@ -1109,7 +1109,7 @@ RELAY_REGISTER_OP("repeat")
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.Repeat")
.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(1)
+.set_support_level(3)
.add_type_rel("Repeat", RepeatRel)
.set_attr<FTVMCompute>("FTVMCompute", RepeatCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);
@@ -1134,7 +1134,7 @@ bool TileRel(const Array<Type>& types,
  const size_t ndim = data->shape.size();
  const Array<Integer>& reps = param->reps;
  // check dimension match
-  CHECK(!reps.defined())
+  CHECK(reps.defined())
    << "repetition array is not defined. data.ndim = " << ndim;
  const size_t rndim = reps.size();
  size_t tndim = (ndim > rndim) ? ndim : rndim;
@@ -1158,6 +1158,10 @@ bool TileRel(const Array<Type>& types,
  } else {
    for (size_t i = 0; i < rndim; ++i)
      reps_shape.emplace_back(reps[i]);
+    for (size_t i = 0; i < (rndim - ndim); ++i)
+      data_shape.emplace_back(1);
+    for (size_t i = 0; i < ndim; ++i)
+      data_shape.emplace_back(data->shape[i]);
  }
  std::vector<IndexExpr> oshape;
  oshape.reserve(tndim);
@@ -1199,7 +1203,7 @@ RELAY_REGISTER_OP("tile")
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.Tile")
.add_argument("data", "Tensor", "The input tensor.")
-.set_support_level(1)
+.set_support_level(3)
.add_type_rel("Tile", TileRel)
.set_attr<FTVMCompute>("FTVMCompute", TileCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);
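For reference, the output-shape rule that the fixed TileRel implements mirrors np.tile: the shorter of the data shape and the reps array is left-padded with 1s, and the output shape is their element-wise product. Below is a minimal Python sketch of that rule; tile_out_shape is a hypothetical helper for illustration only, not code from the commit.

def tile_out_shape(data_shape, reps):
    # Left-pad the shorter of data_shape/reps with 1s, then multiply element-wise.
    # The missing data_shape padding was the "data shape 0" bug fixed above.
    ndim, rndim = len(data_shape), len(reps)
    tndim = max(ndim, rndim)
    data_shape = [1] * (tndim - ndim) + list(data_shape)
    reps = [1] * (tndim - rndim) + list(reps)
    return [d * r for d, r in zip(data_shape, reps)]

assert tile_out_shape((2, 3), (2, 2, 2)) == [2, 4, 6]    # rndim > ndim: data shape padded
assert tile_out_shape((2, 3, 4), (2, 2)) == [2, 6, 8]    # ndim > rndim: reps padded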
58 changes: 58 additions & 0 deletions tests/python/relay/test_op_level3.py
@@ -490,6 +490,61 @@ def verify_arange(start, stop, step):
    verify_arange(20, 1, -1)
    verify_arange(20, 1, -1.5)

def test_tile():
    def verify_tile(dshape, reps):
        x = relay.var("x", relay.TensorType(dshape, "float32"))
        z = relay.tile(x, reps=reps)

        func = relay.Function([x], z)
        x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32")
        ref_res = np.tile(x_data, reps=reps)

        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(func)(x_data)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
    verify_tile((2, 3, 4), (0, 2, 1))
    verify_tile((2, 3, 4), (0, 2))
    verify_tile((2, 3), (0, 2, 1))
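As a quick cross-check (not part of the commit), the NumPy reference used by verify_tile gives these output shapes for the cases above, including the rank-mismatched ones that previously triggered the data-shape error:

import numpy as np

x = np.ones((2, 3, 4), dtype="float32")
assert np.tile(x, (0, 2, 1)).shape == (0, 6, 4)                 # same rank
assert np.tile(x, (0, 2)).shape == (2, 0, 8)                    # reps padded to (1, 0, 2)
assert np.tile(np.ones((2, 3)), (0, 2, 1)).shape == (0, 4, 3)   # data padded to (1, 2, 3)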

def test_repeat():
    def verify_repeat(dshape, repeats, axis):
        x = relay.Var("x", relay.TensorType(dshape, "float32"))
        func = relay.Function([x], relay.repeat(x, repeats, axis))
        data = np.random.uniform(size=dshape).astype("float32")
        ref_res = np.repeat(data, repeats, axis)
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(func)(data)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
    verify_repeat((3,), 2, 0)
    verify_repeat((3, 10), 2, -1)
    verify_repeat((3, 2, 4), 3, 1)
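For the negative-axis case above, np.repeat, which these tests use as the reference, simply doubles the size of the last dimension; a small illustrative check, not from the commit:

import numpy as np

x = np.ones((3, 10), dtype="float32")
assert np.repeat(x, 2, axis=-1).shape == (3, 20)   # elements repeated along the last axis
assert np.repeat(x, 2, axis=0).shape == (6, 10)    # rows repeated when axis=0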

def test_stack():
    def verify_stack(dshapes, axis):
        # stack expects a tuple of tensors, so build one float32 var per input shape.
        y = [relay.var("input", relay.TensorType(shape, "float32")) for shape in dshapes]
        z = relay.stack(relay.Tuple(y), axis=axis)

        func = relay.Function(y, z)
        x_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes]
        ref_res = np.stack(x_data, axis=axis)

        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                intrp = relay.create_executor(kind, ctx=ctx, target=target)
                op_res = intrp.evaluate(func)(*x_data)
                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
    verify_stack([[2,], [2,], [2,]], -1)
    verify_stack([(2,), (2,), (2,)], 1)
    verify_stack([(2,), (2,), (2,)], 0)
    verify_stack([(2, 2, 4), (2, 2, 4), (2, 2, 4)], 1)
    verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1)
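Likewise, np.stack, the reference for these cases, inserts a new axis of length len(inputs) at the given position; a short illustrative check, not from the commit:

import numpy as np

xs = [np.zeros((2,), dtype="float32") for _ in range(3)]
assert np.stack(xs, axis=-1).shape == (2, 3)   # new axis appended last
assert np.stack(xs, axis=0).shape == (3, 2)    # new axis prepended

ys = [np.zeros((2, 2, 4), dtype="float32") for _ in range(3)]
assert np.stack(ys, axis=1).shape == (2, 3, 2, 4)   # new axis inserted at position 1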



if __name__ == "__main__":
    test_cast()
@@ -515,3 +570,6 @@ def verify_arange(start, stop, step):
    test_squeeze_bad_axes_infer_type()
    test_split_infer_type()
    test_arange()
    test_stack()
    test_repeat()
    test_tile()
4 changes: 2 additions & 2 deletions topi/tests/python/test_topi_transform.py
@@ -29,7 +29,7 @@ def check_device(device):
        check_device(device)


-def verify_tranpose(in_shape, axes):
+def verify_transpose(in_shape, axes):
    A = tvm.placeholder(shape=in_shape, name="A")
    B = topi.transpose(A, axes)
    def check_device(device):
@@ -568,7 +568,7 @@ def check_device(device):
    test_strided_slice()
    test_concatenate()
    test_stack()
-    test_tranpose()
+    test_transpose()
    test_expand_dims()
    test_reshape()
    test_squeeze()
