From 614793b08efc10bebf05cf189c6f0a030e9a5c6c Mon Sep 17 00:00:00 2001 From: Siva <sivar.b@huawei.com> Date: Tue, 9 Oct 2018 11:14:52 +0530 Subject: [PATCH] [RELAY][OP] take (#1863) --- docs/langref/relay_op.rst | 2 + include/tvm/relay/attrs/transform.h | 9 +++ nnvm/src/top/tensor/transform.cc | 2 +- python/tvm/relay/op/transform.py | 23 +++++++ src/relay/op/tensor/transform.cc | 89 ++++++++++++++++++++++++++++ tests/python/relay/test_op_level3.py | 22 +++++++ 6 files changed, 146 insertions(+), 1 deletion(-) diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 0ac6851ba9de..d5f92f567b17 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -73,6 +73,7 @@ This level enables additional math and transform operators. tvm.relay.round tvm.relay.abs tvm.relay.negative + tvm.relay.take @@ -143,6 +144,7 @@ Level 3 Definitions .. autofunction:: tvm.relay.reshape .. autofunction:: tvm.relay.copy .. autofunction:: tvm.relay.transpose +.. autofunction:: tvm.relay.take Level 3 Definitions ------------------- diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index d501e6cb7255..5c4cbca4a4a8 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -59,6 +59,15 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> { } }; // struct ReshapeAttrs +struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> { + IndexExpr axis; + + TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") { + TVM_ATTR_FIELD(axis).set_default(NullValue<IndexExpr>()) + .describe("The axis over which to select values."); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index 40c8c930a029..270172856a75 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -1135,7 +1135,7 @@ Examples:: .set_attr<FCorrectLayout>("FCorrectLayout", TakeCorrectLayout) .set_num_inputs(2) .set_num_outputs(1) -.set_support_level(1) +.set_support_level(3) .set_attr<FTVMCompute>( "FTVMCompute", [](const NodeAttrs& attrs, const Array<Tensor>& inputs, diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index b530883d006c..830c1b18e42c 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -116,3 +116,26 @@ def reshape(data, newshape): if isinstance(newshape, int): newshape = [newshape] return _make.reshape(data, list(newshape)) + + +def take(data, indices, axis=None): + """Take elements from an array along an axis. + + Parameters + ---------- + a : relay.Expr + The source array. + + indices : rely.Expr + The indices of the values to extract. + + axis : int, optional + The axis over which to select values. By default, + the flattened input array is used. + + Returns + ------- + ret : relay.Expr + The computed result. + """ + return _make.take(data, indices, axis) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index f85fd706a52f..ac9763a0f562 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -315,5 +315,94 @@ Example:: .set_support_level(3) .add_type_rel("Reshape", ReshapeRel); +// Take +TVM_REGISTER_NODE_TYPE(TakeAttrs); + +bool TakeRel(const Array<Type>& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [data, indices, result] + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as<TensorTypeNode>(); + CHECK(data != nullptr); + const auto* indices = types[1].as<TensorTypeNode>(); + CHECK(indices != nullptr); + const auto param = attrs.as<TakeAttrs>(); + CHECK(param != nullptr); + + if (!param->axis.defined()) { + std::vector<IndexExpr>&& oshape = AsVector(indices->shape); + reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype)); + return true; + } + + std::vector<IndexExpr> oshape; + const auto ndim_data = static_cast<int>(data->shape.size()); + const auto ndim_indices = static_cast<int>(indices->shape.size()); + auto axis = (*as_const_int(param->axis)); + if (axis < 0) axis += ndim_data; + CHECK_LE(axis, ndim_data) + << "axis should be with in data shape" + << ", but got = " << axis; + + oshape.reserve(ndim_data - 1 + ndim_indices); + for (int i = 0; i < axis; ++i) { + oshape.emplace_back(data->shape[i]); + } + for (int i = 0; i < ndim_indices; ++i) { + oshape.emplace_back(indices->shape[i]); + } + for (int i = axis+1; i < ndim_data; ++i) { + oshape.emplace_back(data->shape[i]); + } + + reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype)); + return true; +} + +Expr MakeTake(Expr data, + Expr indices, + IndexExpr axis) { + auto attrs = make_node<TakeAttrs>(); + attrs->axis = axis; + static const Op& op = Op::Get("take"); + return CallNode::make(op, {data, indices}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op._make.take") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call<Expr, 3>(MakeTake, args, rv); +}); + +RELAY_REGISTER_OP("take") +.describe(R"code(Take elements from an array along an axis. + +When axis is not None, this function does the same thing as 'fancy' indexing +(indexing arrays using arrays); however, it can be easier to use if you need +elements along a given axis. + +**Note** that when axis is none the flattened input array is used. + +Examples:: + + a = [[ 1, 2], + [ 3, 4]] + indices = [3, 0, 2] + take(a, indices) = [ 4, 1, 3] + + a = [[ 1., 2.], + [ 3., 4.]] + indices = [1, 0] + take(a, indices, axis=1) = [[ 2., 1.], + [ 4., 3.]] + +)code" TVM_ADD_FILELINE) +.set_num_inputs(2) +.add_argument("data", "Tensor", "The input tensor.") +.add_argument("indices", "Tensor", "The indices tensor.") +.set_support_level(2) +.add_type_rel("Take", TakeRel); + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index c6b83b39c276..55717bbe23df 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -91,6 +91,27 @@ def check_single_op(opfunc): tvm.relay.round, tvm.relay.abs, tvm.relay.negative]: check_single_op(opfunc) +def test_take_infer_type(): + def verify_take(dshape, indices_shape, oshape, axis=None): + ib = relay.ir_builder.IRBuilder() + x = ib.param("x", relay.ty.TensorType(dshape, "float32")) + indices = ib.param("indices", relay.ty.TensorType(indices_shape, "int32")) + with ib.function(x, indices) as func: + ib.ret(relay.take(x.var, indices.var, axis=axis)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType(oshape, "float32") + + d1, d2, d3 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3") + d4, d5, d6 = tvm.var("d4"), tvm.var("d5"), tvm.var("d6") + verify_take((d1,), (1,), (1,), 0) + verify_take((4,), (d1, d2), (d1, d2)) + verify_take((3, 3, 3), (1, d2), (1, d2)) + verify_take((d1, d2), (d3, d4, d5), (d3, d4, d5, d2), 0) + verify_take((d1, d2), (d3, d4, d5), (d1, d3, d4, d5), 1) + verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2) + if __name__ == "__main__": test_single_op() @@ -99,3 +120,4 @@ def check_single_op(opfunc): test_copy_infer_type() test_transpose_infer_type() test_reshape_infer_type() + test_take_infer_type()