Skip to content

Commit

Permalink
[RELAY][OP] take (apache#1863)
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 authored and tqchen committed Oct 9, 2018
1 parent 6ffdd28 commit 614793b
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ This level enables additional math and transform operators.
tvm.relay.round
tvm.relay.abs
tvm.relay.negative
tvm.relay.take



Expand Down Expand Up @@ -143,6 +144,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.reshape
.. autofunction:: tvm.relay.copy
.. autofunction:: tvm.relay.transpose
.. autofunction:: tvm.relay.take

Level 3 Definitions
-------------------
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
}
}; // struct ReshapeAttrs

struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
IndexExpr axis;

TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
TVM_ATTR_FIELD(axis).set_default(NullValue<IndexExpr>())
.describe("The axis over which to select values.");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_TRANSFORM_H_
2 changes: 1 addition & 1 deletion nnvm/src/top/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1135,7 +1135,7 @@ Examples::
.set_attr<FCorrectLayout>("FCorrectLayout", TakeCorrectLayout)
.set_num_inputs(2)
.set_num_outputs(1)
.set_support_level(1)
.set_support_level(3)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
Expand Down
23 changes: 23 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,26 @@ def reshape(data, newshape):
if isinstance(newshape, int):
newshape = [newshape]
return _make.reshape(data, list(newshape))


def take(data, indices, axis=None):
"""Take elements from an array along an axis.
Parameters
----------
a : relay.Expr
The source array.
indices : rely.Expr
The indices of the values to extract.
axis : int, optional
The axis over which to select values. By default,
the flattened input array is used.
Returns
-------
ret : relay.Expr
The computed result.
"""
return _make.take(data, indices, axis)
89 changes: 89 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,5 +315,94 @@ Example::
.set_support_level(3)
.add_type_rel("Reshape", ReshapeRel);

// Take
TVM_REGISTER_NODE_TYPE(TakeAttrs);

bool TakeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, indices, result]
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
CHECK(data != nullptr);
const auto* indices = types[1].as<TensorTypeNode>();
CHECK(indices != nullptr);
const auto param = attrs.as<TakeAttrs>();
CHECK(param != nullptr);

if (!param->axis.defined()) {
std::vector<IndexExpr>&& oshape = AsVector(indices->shape);
reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
return true;
}

std::vector<IndexExpr> oshape;
const auto ndim_data = static_cast<int>(data->shape.size());
const auto ndim_indices = static_cast<int>(indices->shape.size());
auto axis = (*as_const_int(param->axis));
if (axis < 0) axis += ndim_data;
CHECK_LE(axis, ndim_data)
<< "axis should be with in data shape"
<< ", but got = " << axis;

oshape.reserve(ndim_data - 1 + ndim_indices);
for (int i = 0; i < axis; ++i) {
oshape.emplace_back(data->shape[i]);
}
for (int i = 0; i < ndim_indices; ++i) {
oshape.emplace_back(indices->shape[i]);
}
for (int i = axis+1; i < ndim_data; ++i) {
oshape.emplace_back(data->shape[i]);
}

reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
return true;
}

Expr MakeTake(Expr data,
Expr indices,
IndexExpr axis) {
auto attrs = make_node<TakeAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("take");
return CallNode::make(op, {data, indices}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op._make.take")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeTake, args, rv);
});

RELAY_REGISTER_OP("take")
.describe(R"code(Take elements from an array along an axis.
When axis is not None, this function does the same thing as 'fancy' indexing
(indexing arrays using arrays); however, it can be easier to use if you need
elements along a given axis.
**Note** that when axis is none the flattened input array is used.
Examples::
a = [[ 1, 2],
[ 3, 4]]
indices = [3, 0, 2]
take(a, indices) = [ 4, 1, 3]
a = [[ 1., 2.],
[ 3., 4.]]
indices = [1, 0]
take(a, indices, axis=1) = [[ 2., 1.],
[ 4., 3.]]
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("indices", "Tensor", "The indices tensor.")
.set_support_level(2)
.add_type_rel("Take", TakeRel);

} // namespace relay
} // namespace tvm
22 changes: 22 additions & 0 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,27 @@ def check_single_op(opfunc):
tvm.relay.round, tvm.relay.abs, tvm.relay.negative]:
check_single_op(opfunc)

def test_take_infer_type():
def verify_take(dshape, indices_shape, oshape, axis=None):
ib = relay.ir_builder.IRBuilder()
x = ib.param("x", relay.ty.TensorType(dshape, "float32"))
indices = ib.param("indices", relay.ty.TensorType(indices_shape, "int32"))
with ib.function(x, indices) as func:
ib.ret(relay.take(x.var, indices.var, axis=axis))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(oshape, "float32")

d1, d2, d3 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3")
d4, d5, d6 = tvm.var("d4"), tvm.var("d5"), tvm.var("d6")
verify_take((d1,), (1,), (1,), 0)
verify_take((4,), (d1, d2), (d1, d2))
verify_take((3, 3, 3), (1, d2), (1, d2))
verify_take((d1, d2), (d3, d4, d5), (d3, d4, d5, d2), 0)
verify_take((d1, d2), (d3, d4, d5), (d1, d3, d4, d5), 1)
verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2)


if __name__ == "__main__":
test_single_op()
Expand All @@ -99,3 +120,4 @@ def check_single_op(opfunc):
test_copy_infer_type()
test_transpose_infer_type()
test_reshape_infer_type()
test_take_infer_type()

0 comments on commit 614793b

Please sign in to comment.