Skip to content

Commit

Permalink
make dynamic tile compatible with numpy API
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart committed Jul 6, 2020
1 parent 946e6ef commit fe8ed08
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 21 deletions.
33 changes: 25 additions & 8 deletions python/tvm/relay/op/dyn/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,37 @@ def dynamic_reshape_shape_func(attrs, inputs, out_ndims):


@script
def _tile_shape_func(data, reps, ndim):
out = output_tensor((ndim,), "int64")
def _tile_shape_func(data, reps, ndim, tndim, rndim):
out = output_tensor((tndim,), "int64")

for i in const_range(ndim):
out[i] = data.shape[i] * int64(reps[i])
if ndim == rndim:
for i in const_range(tndim):
out[i] = int64(data.shape[i] * reps[i])
elif ndim > rndim:
ngap = ndim - rndim
for i in const_range(ndim):
if i < ngap:
out[i] = int64(data.shape[i])
else:
out[i] = int64(data.shape[i] * reps[i - ngap])
else:
rgap = rndim - ndim
for i in const_range(rndim):
if i < rgap:
out[i] = int64(reps[i])
else:
out[i] = int64(reps[i] * data.shape[i - rgap])
return out


@_reg.register_shape_func("dyn.tile", True)
def tile_shape_func(attrs, inputs, _):
"""
Shape function for tile op.
Shape function for dyn.tile op.
"""
reps = inputs[1]
ndim = len(inputs[0].shape)
rdim = inputs[1].shape[0].value
assert ndim == rdim, "tile data and reps ranks don't match"
return [_tile_shape_func(inputs[0], inputs[1], convert(ndim))]
rndim = inputs[1].shape[0].value
tndim = ndim if ndim > rndim else rndim
return [_tile_shape_func(inputs[0], reps, convert(ndim),
convert(tndim), convert(rndim))]
13 changes: 7 additions & 6 deletions src/relay/op/dyn/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,14 @@ bool TileRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
<< "tile: expect input type to be TensorType but get " << types[1];
return false;
}
const size_t ndim = data->shape.size();
const IntImmNode* reps_shape = reps->shape[0].as<IntImmNode>();
CHECK(reps_shape) << "Parameter reps must have static shape";
// check dimension match
CHECK_EQ(ndim, reps_shape->value) << "tile: the shape of reps must match the rank of data";
const size_t ndim = data->shape.size();
const size_t rndim = reps_shape->value;
size_t tndim = (ndim > rndim) ? ndim : rndim;
std::vector<IndexExpr> oshape;
oshape.reserve(ndim);
for (size_t i = 0; i < ndim; ++i) {
oshape.reserve(tndim);
for (size_t i = 0; i < tndim; ++i) {
oshape.emplace_back(Any());
}
reporter->Assign(types[2], TensorType(oshape, data->dtype));
Expand All @@ -167,7 +167,8 @@ Array<te::Tensor> TileCompute(const Attrs& attrs, const Array<te::Tensor>& input
const Type& out_type) {
CHECK_EQ(inputs.size(), 2);
const auto* out_ttype = out_type.as<TensorTypeNode>();
return {topi::dyn_tile(inputs[0], out_ttype->shape)};
size_t rndim = inputs[1]->shape[0].as<IntImmNode>()->value;
return {topi::dyn_tile(inputs[0], out_ttype->shape, rndim)};
}

Expr MakeTile(Expr data, Expr reps) {
Expand Down
6 changes: 3 additions & 3 deletions tests/python/relay/dyn/test_dynamic_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,16 @@ def verify_reshape(shape, newshape, oshape):
def test_dyn_tile():
def verify_tile(dshape, reps):
x = relay.var("x", relay.TensorType(dshape, "float32"))
r = relay.var("reps", relay.TensorType((len(dshape), ), "float32"))
r = relay.var("reps", relay.TensorType((len(reps), ), "float32"))
z = relay.tile(x, r)

func = relay.Function([x, r], z)
x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32")
ref_res = np.tile(x_data, reps=reps)
verify_func(func, [x_data, np.array(reps).astype("float32")], ref_res)
verify_tile((2, 3, 4), (3, 2, 1))
verify_tile((2, 3, 4), (1, 2, 1))
verify_tile((1, 2, 3), (3, 2, 1))
verify_tile((2, 3, 4), (1, 2))
verify_tile((2, 3), (3, 2, 1))

if __name__ == "__main__":
test_dyn_reshape()
Expand Down
14 changes: 10 additions & 4 deletions topi/include/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -1026,8 +1026,8 @@ inline Tensor tile(const Tensor& x, Array<Integer> reps, std::string name = "T_t
*
* \return A Tensor whose op member is the tile operation
*/
inline Tensor dyn_tile(const Tensor& x, Array<PrimExpr> new_shape, std::string name = "T_tile",
std::string tag = kBroadcast) {
inline Tensor dyn_tile(const Tensor& x, Array<PrimExpr> new_shape, size_t rdim,
std::string name = "T_tile", std::string tag = kBroadcast) {
size_t ndim = x->shape.size();
if (is_empty_shape(new_shape)) {
return compute(
Expand All @@ -1037,8 +1037,14 @@ inline Tensor dyn_tile(const Tensor& x, Array<PrimExpr> new_shape, std::string n
new_shape,
[&](const Array<Var>& indices) {
Array<PrimExpr> idx;
for (size_t i = 0; i < ndim; ++i) {
idx.push_back(indexmod(indices[i], x->shape[i]));
if (ndim >= rdim) {
for (size_t i = 0; i < ndim; ++i) {
idx.push_back(indexmod(indices[i], x->shape[i]));
}
} else {
for (size_t i = 0; i < ndim; ++i) {
idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i]));
}
}
return x(idx);
},
Expand Down

0 comments on commit fe8ed08

Please sign in to comment.