Skip to content

Commit

Permalink
[Bugfix] Repeat and tile bug fixed, relay tests added (apache#2804)
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurawly authored and wweic committed Mar 20, 2019
1 parent 5a6d2a3 commit 368f4d1
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 11 deletions.
18 changes: 14 additions & 4 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ RELAY_REGISTER_OP("stack")
.set_attrs_type_key("relay.attrs.StackAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input list of tensors.")
.set_support_level(1)
.set_support_level(3)
.add_type_rel("Stack", StackRel)
.set_attr<FTVMCompute>("FTVMCompute", StackCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
Expand Down Expand Up @@ -1109,7 +1109,7 @@ RELAY_REGISTER_OP("repeat")
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.Repeat")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.set_support_level(3)
.add_type_rel("Repeat", RepeatRel)
.set_attr<FTVMCompute>("FTVMCompute", RepeatCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);
Expand All @@ -1134,9 +1134,15 @@ bool TileRel(const Array<Type>& types,
const size_t ndim = data->shape.size();
const Array<Integer>& reps = param->reps;
// check dimension match
CHECK(!reps.defined())
CHECK(reps.defined())
<< "repetition array is not defined. data.ndim = " << ndim;
const size_t rndim = reps.size();
for (size_t i = 0; i < rndim; ++i) {
if (const tvm::ir::IntImm* val = reps[i].as<tvm::ir::IntImm>()) {
CHECK_GT(val->value, 0)
<< "Tile reps value should always be larger than 0, but get: " << val->value;
}
}
size_t tndim = (ndim > rndim) ? ndim : rndim;
// re-construct data shape or reps shape
std::vector<IndexExpr> data_shape;
Expand All @@ -1158,6 +1164,10 @@ bool TileRel(const Array<Type>& types,
} else {
for (size_t i = 0; i < rndim; ++i)
reps_shape.emplace_back(reps[i]);
for (size_t i = 0; i < (rndim - ndim); ++i)
data_shape.emplace_back(1);
for (size_t i = 0; i < ndim; ++i)
data_shape.emplace_back(data->shape[i]);
}
std::vector<IndexExpr> oshape;
oshape.reserve(tndim);
Expand Down Expand Up @@ -1199,7 +1209,7 @@ RELAY_REGISTER_OP("tile")
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.Tile")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.set_support_level(3)
.add_type_rel("Tile", TileRel)
.set_attr<FTVMCompute>("FTVMCompute", TileCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);
Expand Down
59 changes: 59 additions & 0 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,62 @@ def verify_arange(start, stop, step):
verify_arange(20, 1, -1)
verify_arange(20, 1, -1.5)

def test_tile():
def verify_tile(dshape, reps):
x = relay.var("x", relay.TensorType(dshape, "float32"))
z = relay.tile(x, reps=reps)

func = relay.Function([x], z)
x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32")
ref_res = np.tile(x_data, reps=reps)

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_tile((2, 3, 4), (3, 2, 1))
verify_tile((2, 3, 4), (1, 2))
verify_tile((2, 3), (3, 2, 1))

def test_repeat():
def verify_repeat(dshape, repeats, axis):
x = relay.Var("x", relay.TensorType(dshape, "float32"))
func = relay.Function([x], relay.repeat(x, repeats, axis))
data = np.random.uniform(size=dshape).astype("float32")
ref_res = np.repeat(data, repeats, axis)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_repeat((3,), 2, 0)
verify_repeat((3, 10), 2, -1)
verify_repeat((3, 2, 4), 3, 1)

def test_stack():
def verify_stack(dshapes, axis):
y = []
for shape in dshapes:
y.append(relay.var("input", relay.TensorType(shape, "float32")))
x = relay.Tuple(y)
z = relay.stack(x, axis=axis)

func = relay.Function(y, z)
x_data = [np.random.normal(size=shape).astype("float32") for shape in dshapes]
ref_res = np.stack(x_data, axis=axis)

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(*x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_stack([(2,), (2,), (2,)], -1)
verify_stack([(2,), (2,), (2,)], 0)
verify_stack([(2, 2, 4), (2, 2, 4), (2, 2, 4)], 1)
verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1)



def test_reverse():
def verify_reverse(dshape, axis):
Expand Down Expand Up @@ -536,3 +592,6 @@ def verify_reverse(dshape, axis):
test_split_infer_type()
test_arange()
test_reverse()
test_stack()
test_tile()
test_repeat()
14 changes: 7 additions & 7 deletions topi/tests/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def check_device(device):
check_device(device)


def verify_tranpose(in_shape, axes):
def verify_transpose(in_shape, axes):
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.transpose(A, axes)
def check_device(device):
Expand All @@ -40,7 +40,7 @@ def check_device(device):
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(B)
foo = tvm.build(s, [A, B], device, name="tranpose")
foo = tvm.build(s, [A, B], device, name="transpose")
data_npy = np.arange(np.prod(in_shape)).reshape(in_shape).astype(A.dtype)
out_npy = data_npy.transpose(axes)
data_nd = tvm.nd.array(data_npy, ctx)
Expand Down Expand Up @@ -416,10 +416,10 @@ def test_expand_dims():
verify_expand_dims((3, 10), (1, 3, 10), -3, 1)


def test_tranpose():
verify_tranpose((3, 10, 2), (1, 0, 2))
verify_tranpose((3, 10, 5), (2, 0, 1))
verify_tranpose((3, 10), None)
def test_transpose():
verify_transpose((3, 10, 2), (1, 0, 2))
verify_transpose((3, 10, 5), (2, 0, 1))
verify_transpose((3, 10), None)


def test_reshape():
Expand Down Expand Up @@ -595,7 +595,7 @@ def check_device(device):
test_strided_slice()
test_concatenate()
test_stack()
test_tranpose()
test_transpose()
test_expand_dims()
test_reshape()
test_squeeze()
Expand Down

0 comments on commit 368f4d1

Please sign in to comment.