From 614793b08efc10bebf05cf189c6f0a030e9a5c6c Mon Sep 17 00:00:00 2001
From: Siva <sivar.b@huawei.com>
Date: Tue, 9 Oct 2018 11:14:52 +0530
Subject: [PATCH] [RELAY][OP] take (#1863)

---
 docs/langref/relay_op.rst            |  2 +
 include/tvm/relay/attrs/transform.h  |  9 +++
 nnvm/src/top/tensor/transform.cc     |  2 +-
 python/tvm/relay/op/transform.py     | 23 +++++++
 src/relay/op/tensor/transform.cc     | 89 ++++++++++++++++++++++++++++
 tests/python/relay/test_op_level3.py | 22 +++++++
 6 files changed, 146 insertions(+), 1 deletion(-)

diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index 0ac6851ba9de..d5f92f567b17 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -73,6 +73,7 @@ This level enables additional math and transform operators.
    tvm.relay.round
    tvm.relay.abs
    tvm.relay.negative
+   tvm.relay.take
 
 
 
@@ -143,6 +144,7 @@ Level 3 Definitions
 .. autofunction:: tvm.relay.reshape
 .. autofunction:: tvm.relay.copy
 .. autofunction:: tvm.relay.transpose
+.. autofunction:: tvm.relay.take
 
 Level 4 Definitions
 -------------------
diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index d501e6cb7255..5c4cbca4a4a8 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -59,6 +59,15 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
   }
 };  // struct ReshapeAttrs
 
+struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
+  IndexExpr axis;
+
+  TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
+    TVM_ATTR_FIELD(axis).set_default(NullValue<IndexExpr>())
+        .describe("The axis over which to select values.");
+  }
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc
index 40c8c930a029..270172856a75 100644
--- a/nnvm/src/top/tensor/transform.cc
+++ b/nnvm/src/top/tensor/transform.cc
@@ -1135,7 +1135,7 @@ Examples::
 .set_attr<FCorrectLayout>("FCorrectLayout", TakeCorrectLayout)
 .set_num_inputs(2)
 .set_num_outputs(1)
-.set_support_level(1)
+.set_support_level(3)
 .set_attr<FTVMCompute>(
     "FTVMCompute", [](const NodeAttrs& attrs,
                       const Array<Tensor>& inputs,
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index b530883d006c..830c1b18e42c 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -116,3 +116,26 @@ def reshape(data, newshape):
     if isinstance(newshape, int):
         newshape = [newshape]
     return _make.reshape(data, list(newshape))
+
+
+def take(data, indices, axis=None):
+    """Take elements from an array along an axis.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The source array.
+
+    indices : relay.Expr
+        The indices of the values to extract.
+
+    axis : int, optional
+        The axis over which to select values. By default,
+        the flattened input array is used.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result.
+    """
+    return _make.take(data, indices, axis)
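
A minimal usage sketch of the new wrapper, mirroring the IRBuilder style of the
test added in this patch; the concrete shapes are illustrative only::

  import tvm
  from tvm import relay

  ib = relay.ir_builder.IRBuilder()
  x = ib.param("x", relay.ty.TensorType((3, 4), "float32"))
  indices = ib.param("indices", relay.ty.TensorType((2,), "int32"))
  with ib.function(x, indices) as func:
      # Select two rows of x along axis 0; the inferred result type is (2, 4) float32.
      ib.ret(relay.take(x.var, indices.var, axis=0))
  ib.ret(func)
  func = relay.ir_pass.infer_type(ib.env, func.to_func())
  print(func.checked_type.ret_type)  # TensorType([2, 4], float32)
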
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index f85fd706a52f..ac9763a0f562 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -315,5 +315,94 @@ Example::
 .set_support_level(3)
 .add_type_rel("Reshape", ReshapeRel);
 
+// Take
+TVM_REGISTER_NODE_TYPE(TakeAttrs);
+
+bool TakeRel(const Array<Type>& types,
+             int num_inputs,
+             const Attrs& attrs,
+             const TypeReporter& reporter) {
+  // `types` contains: [data, indices, result]
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  CHECK(data != nullptr);
+  const auto* indices = types[1].as<TensorTypeNode>();
+  CHECK(indices != nullptr);
+  const auto param = attrs.as<TakeAttrs>();
+  CHECK(param != nullptr);
+
+  if (!param->axis.defined()) {
+    std::vector<IndexExpr> oshape = AsVector(indices->shape);
+    reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+    return true;
+  }
+
+  std::vector<IndexExpr> oshape;
+  const auto ndim_data = static_cast<int>(data->shape.size());
+  const auto ndim_indices = static_cast<int>(indices->shape.size());
+  auto axis = (*as_const_int(param->axis));
+  if (axis < 0) axis += ndim_data;
+  CHECK_LT(axis, ndim_data)
+    << "axis should be within the data shape"
+    << ", but got = " << axis;
+
+  oshape.reserve(ndim_data - 1 + ndim_indices);
+  for (int i = 0; i < axis; ++i) {
+    oshape.emplace_back(data->shape[i]);
+  }
+  for (int i = 0; i < ndim_indices; ++i) {
+    oshape.emplace_back(indices->shape[i]);
+  }
+  for (int i = axis+1; i < ndim_data; ++i) {
+    oshape.emplace_back(data->shape[i]);
+  }
+
+  reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+  return true;
+}
+
+Expr MakeTake(Expr data,
+              Expr indices,
+              IndexExpr axis) {
+  auto attrs = make_node<TakeAttrs>();
+  attrs->axis = axis;
+  static const Op& op = Op::Get("take");
+  return CallNode::make(op, {data, indices}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.take")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 3>(MakeTake, args, rv);
+});
+
+RELAY_REGISTER_OP("take")
+.describe(R"code(Take elements from an array along an axis.
+
+When axis is not None, this function does the same thing as 'fancy' indexing
+(indexing arrays using arrays); however, it can be easier to use if you need
+elements along a given axis.
+
+**Note** that when axis is None, the flattened input array is used.
+
+Examples::
+
+  a = [[ 1, 2],
+       [ 3, 4]]
+  indices = [3, 0, 2]
+  take(a, indices) = [ 4, 1, 3]
+
+  a = [[ 1., 2.],
+       [ 3., 4.]]
+  indices = [1, 0]
+  take(a, indices, axis=1) = [[ 2., 1.],
+                              [ 4., 3.]]
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(2)
+.add_argument("data", "Tensor", "The input tensor.")
+.add_argument("indices", "Tensor", "The indices tensor.")
+.set_support_level(3)
+.add_type_rel("Take", TakeRel);
+
 }  // namespace relay
 }  // namespace tvm
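
The shape rule implemented by TakeRel follows numpy.take: the selected axis of
`data` is replaced by the full shape of `indices`, and with no axis the result
simply takes the shape of `indices`. A hypothetical Python helper (not part of
the patch) stating the same rule::

  def take_output_shape(data_shape, indices_shape, axis=None):
      # axis=None: index into the flattened array, so the output
      # has exactly the shape of `indices`.
      if axis is None:
          return list(indices_shape)
      if axis < 0:
          axis += len(data_shape)
      assert 0 <= axis < len(data_shape), "axis out of range"
      # data_shape[:axis] ++ indices_shape ++ data_shape[axis+1:]
      return (list(data_shape[:axis]) + list(indices_shape)
              + list(data_shape[axis + 1:]))

  # Matches the static cases exercised in test_op_level3.py:
  assert take_output_shape((3, 3, 3), (1, 2)) == [1, 2]
  assert take_output_shape((5, 7), (2, 3, 4), axis=1) == [5, 2, 3, 4]
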
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index c6b83b39c276..55717bbe23df 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -91,6 +91,27 @@ def check_single_op(opfunc):
                    tvm.relay.round, tvm.relay.abs, tvm.relay.negative]:
         check_single_op(opfunc)
 
+def test_take_infer_type():
+    def verify_take(dshape, indices_shape, oshape, axis=None):
+        ib = relay.ir_builder.IRBuilder()
+        x = ib.param("x", relay.ty.TensorType(dshape, "float32"))
+        indices = ib.param("indices", relay.ty.TensorType(indices_shape, "int32"))
+        with ib.function(x, indices) as func:
+            ib.ret(relay.take(x.var, indices.var, axis=axis))
+        ib.ret(func)
+        func = relay.ir_pass.infer_type(ib.env, func.to_func())
+        ftype = func.checked_type
+        assert ftype.ret_type == relay.ty.TensorType(oshape, "float32")
+
+    d1, d2, d3 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3")
+    d4, d5, d6 = tvm.var("d4"), tvm.var("d5"), tvm.var("d6")
+    verify_take((d1,), (1,), (1,), 0)
+    verify_take((4,), (d1, d2), (d1, d2))
+    verify_take((3, 3, 3), (1, d2), (1, d2))
+    verify_take((d1, d2), (d3, d4, d5), (d3, d4, d5, d2), 0)
+    verify_take((d1, d2), (d3, d4, d5), (d1, d3, d4, d5), 1)
+    verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2)
+
 
 if __name__ == "__main__":
     test_single_op()
@@ -99,3 +120,4 @@ def check_single_op(opfunc):
     test_copy_infer_type()
     test_transpose_infer_type()
     test_reshape_infer_type()
+    test_take_infer_type()
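
The test above only checks type inference; the numerical semantics documented in
the op's describe() block can be sanity-checked against numpy.take, whose
behaviour those examples mirror::

  import numpy as np

  a = np.array([[1, 2],
                [3, 4]])
  np.testing.assert_array_equal(np.take(a, [3, 0, 2]), [4, 1, 3])

  a = np.array([[1., 2.],
                [3., 4.]])
  np.testing.assert_array_equal(np.take(a, [1, 0], axis=1),
                                [[2., 1.], [4., 3.]])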