Skip to content

Commit

Permalink
[topi][relay] Add operation gather to relay. (apache#5716)
Browse files Browse the repository at this point in the history
  • Loading branch information
notoraptor authored and trevor-m committed Jun 18, 2020
1 parent 7530d24 commit 5cb72f3
Show file tree
Hide file tree
Showing 14 changed files with 358 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ List of operators
topi.concatenate
topi.split
topi.take
topi.gather
topi.gather_nd
topi.full
topi.full_like
Expand Down Expand Up @@ -160,6 +161,7 @@ topi
.. autofunction:: topi.concatenate
.. autofunction:: topi.split
.. autofunction:: topi.take
.. autofunction:: topi.gather
.. autofunction:: topi.gather_nd
.. autofunction:: topi.full
.. autofunction:: topi.full_like
Expand Down
1 change: 1 addition & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ This level enables additional math and transform operators.
tvm.relay.zeros_like
tvm.relay.ones
tvm.relay.ones_like
tvm.relay.gather
tvm.relay.gather_nd
tvm.relay.full
tvm.relay.full_like
Expand Down
10 changes: 10 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,16 @@ struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> {
}
};

struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {
Integer axis;

TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherAttrs") {
TVM_ATTR_FIELD(axis)
.set_default(NullValue<Integer>())
.describe("The axis over which to select values.");
}
};

struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
Integer axis;
std::string mode;
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
_reg.register_injective_schedule("transpose")
_reg.register_injective_schedule("stack")
_reg.register_injective_schedule("_contrib_reverse_reshape")
_reg.register_injective_schedule("gather")
_reg.register_injective_schedule("gather_nd")
_reg.register_injective_schedule("sequence_mask")
_reg.register_injective_schedule("one_hot")
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ class TransposeAttrs(Attrs):
class ReshapeAttrs(Attrs):
"""Attributes for transform.reshape"""

@tvm._ffi.register_object("relay.attrs.GatherAttrs")
class GatherAttrs(Attrs):
"""Attributes for transform.gather"""

@tvm._ffi.register_object("relay.attrs.TakeAttrs")
class TakeAttrs(Attrs):
"""Attributes for transform.take"""
Expand Down
37 changes: 37 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,43 @@ def reverse_reshape(data, newshape):
return _make._contrib_reverse_reshape(data, list(newshape))


def gather(data, axis, indices):
"""Gather values along given axis from given indices.
E.g. for a 3D tensor, output is computed as:
.. code-block:: python
out[i][j][k] = data[indices[i][j][k]][j][k] # if axis == 0
out[i][j][k] = data[i][indices[i][j][k]][k] # if axis == 1
out[i][j][k] = data[i][j][indices[i][j][k]] # if axis == 2
``indices`` must have same shape as ``data``, except at dimension ``axis``
which must just be not null. Output will have same shape as ``indices``.
Parameters
----------
data: relay.Expr
The input data to the operator.
axis: int
The axis along which to index.
indices: relay.Expr
The indices of values to gather.
Examples
--------
.. code-block:: python
data = [[1, 2], [3, 4]]
axis = 1
indices = [[0, 0], [1, 0]]
relay.gather(data, axis, indices) = [[1, 1], [4, 3]]
"""
return _make.gather(data, axis, indices)


def gather_nd(data, indices):
"""Gather elements or slices from data and store to a tensor whose shape is
defined by indices.
Expand Down
82 changes: 82 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2397,6 +2397,88 @@ example below::
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// gather operator
TVM_REGISTER_NODE_TYPE(GatherAttrs);

bool GatherRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, indices, result]
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* indices = types[1].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "Gather: expect input data type to be TensorType but get " << types[0];
return false;
}
if (indices == nullptr) {
CHECK(types[1].as<IncompleteTypeNode>())
<< "Gather: expect indices type to be TensorType but get " << types[1];
return false;
}
CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer";
const auto param = attrs.as<GatherAttrs>();
CHECK(param != nullptr);
CHECK(param->axis.defined());

const auto ndim_data = data->shape.size();
const auto ndim_indices = indices->shape.size();
int axis = param->axis->value;
CHECK_EQ(ndim_data, ndim_indices);
CHECK_GE(axis, 0);
CHECK_LT(axis, ndim_data);

std::vector<IndexExpr> oshape;
oshape.reserve(ndim_data);
for (size_t i = 0; i < ndim_data; ++i) {
if (i == (size_t)axis) {
const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]);
CHECK_GE(*indice_shape_i, 1);
} else {
CHECK(reporter->AssertEQ(indices->shape[i], data->shape[i]));
}
oshape.emplace_back(indices->shape[i]);
}
reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}

Array<te::Tensor> GatherCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<GatherAttrs>();
return {topi::gather(inputs[0], param->axis, inputs[1])};
}

Expr MakeGather(Expr data, Integer axis, Expr indices) {
auto attrs = make_object<GatherAttrs>();
attrs->axis = std::move(axis);
static const Op& op = Op::Get("gather");
return Call(op, {data, indices}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.gather").set_body_typed(MakeGather);

RELAY_REGISTER_OP("gather")
.describe(R"code(Gather values along given axis from given indices.
E.g. for a 3D tensor, output is computed as:
out[i][j][k] = data[indices[i][j][k]][j][k] # if axis == 0
out[i][j][k] = data[i][indices[i][j][k]][k] # if axis == 1
out[i][j][k] = data[i][j][indices[i][j][k]] # if axis == 2
``indices`` must have same shape as ``data``, except at dimension ``axis``
which must just be not null. Output will have same shape as ``indices``.
)code" TVM_ADD_FILELINE)
.set_attrs_type<GatherAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input data to the operator.")
.add_argument("indices", "Tensor", "The indices of values to gather.")
.set_support_level(3)
.add_type_rel("Gather", GatherRel)
.set_attr<FTVMCompute>("FTVMCompute", GatherCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// gather_nd operator
bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
Expand Down
52 changes: 52 additions & 0 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,58 @@ def verify_scatter(dshape, ishape, axis=0):
verify_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3)


def test_gather():
def verify_gather(data, axis, indices, ref_res):
data = np.asarray(data, dtype='float32')
indices = np.asarray(indices, dtype='int32')
ref_res = np.asarray(ref_res)

d = relay.var("x", relay.TensorType(data.shape, "float32"))
i = relay.var("y", relay.TensorType(indices.shape, "int32"))
z = relay.gather(d, axis, i)

func = relay.Function([d, i], z)

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(data, indices)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res,
rtol=1e-5)

verify_gather([[1, 2], [3, 4]],
1,
[[0, 0], [1, 0]],
[[1, 1], [4, 3]])
verify_gather([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]],
0,
[[[1, 0, 1], [1, 1, 0]]],
[[[6, 1, 8], [9, 10, 5]]])
verify_gather([[[-0.2321, -0.2024, -1.7624], [-0.3829, -0.4246, 0.2448],
[0.1822, 0.2360, -0.8965], [0.4497, -0.2224, 0.6103]],
[[0.0408, -0.7667, -0.4303], [-0.3216, 0.7489, -0.1502],
[0.0144, -0.4699, -0.0064], [-0.0768, -1.6064, 1.3390]]],
1,
[[[2, 2, 0], [1, 0, 3]], [[3, 2, 0], [1, 0, 0]]],
[[[0.1822, 0.2360, -1.7624], [-0.3829, -0.2024, 0.6103]],
[[-0.0768, -0.4699, -0.4303], [-0.3216, -0.7667, -0.4303]]])
verify_gather([[[0.3050, 1.6986, 1.1034], [0.7020, -0.6960, -2.1818],
[0.3116, -0.5773, -0.9912], [0.0835, -1.3915, -1.0720]],
[[0.1694, -0.6091, -0.6539], [-0.5234, -0.1218, 0.5084],
[0.2374, -1.9537, -2.0078], [-0.5700, -1.0302, 0.1558]]],
2,
[[[1, 1, 0, 1], [0, 0, 2, 2], [1, 2, 1, 2], [2, 2, 1, 0]],
[[0, 0, 1, 2], [2, 2, 1, 0], [1, 2, 0, 0], [0, 2, 0, 2]]],
[[[1.6986, 1.6986, 0.3050, 1.6986],
[0.7020, 0.7020, -2.1818, -2.1818],
[-0.5773, -0.9912, -0.5773, -0.9912],
[-1.0720, -1.0720, -1.3915, 0.0835]],
[[0.1694, 0.1694, -0.6091, -0.6539],
[0.5084, 0.5084, -0.1218, -0.5234],
[-1.9537, -2.0078, 0.2374, 0.2374],
[-0.5700, 0.1558, -0.5700, 0.1558]]])


def test_gather_nd():
def verify_gather_nd(xshape, yshape, y_data):
x = relay.var("x", relay.TensorType(xshape, "float32"))
Expand Down
48 changes: 48 additions & 0 deletions topi/include/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,54 @@ inline Tensor tile(const Tensor& x, Array<Integer> reps, std::string name = "T_t
}
}

/*!
* \brief Gather values along given axis from given indices.
*
* \param data The input data to the operator.
* \param axis The axis along which to index.
* \param indices The indices of values to gather.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
*
* \return A Tensor whose op member is the gather operation
*/
inline Tensor gather(const Tensor& data, int axis, const Tensor& indices,
std::string name = "T_gather", std::string tag = kInjective) {
size_t ndim_d = data->shape.size();
size_t ndim_i = indices->shape.size();
CHECK_GE(ndim_d, 1) << "Cannot gather from a scalar.";
CHECK_EQ(ndim_d, ndim_i);
CHECK_GE(axis, 0);
CHECK_LT(axis, ndim_d);
size_t indices_dim_i = static_cast<size_t>(GetConstInt(indices->shape[axis]));
CHECK_GE(indices_dim_i, 1);
CHECK(indices->dtype.is_int());

Array<PrimExpr> out_shape;
for (size_t i = 0; i < ndim_i; ++i) {
out_shape.push_back(indices->shape[i]);
}

return compute(
out_shape,
[&](const Array<Var>& out_index) {
Array<PrimExpr> indices_position;
for (size_t i = 0; i < ndim_i; ++i) {
indices_position.push_back(out_index[i]);
}
Array<PrimExpr> real_indices;
for (size_t i = 0; i < ndim_i; ++i) {
if (i == (size_t)axis) {
real_indices.push_back(indices(indices_position));
} else {
real_indices.push_back(indices_position[i]);
}
}
return data(real_indices);
},
name, tag);
}

/*!
* \brief Gather elements from a n-dimension array.
*
Expand Down
1 change: 1 addition & 0 deletions topi/python/topi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .roi_pool_python import roi_pool_nchw_python
from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python
from .gather_python import gather_python
from .gather_nd_python import gather_nd_python
from .strided_slice_python import strided_slice_python, strided_set_python
from .batch_matmul import batch_matmul
Expand Down
46 changes: 46 additions & 0 deletions topi/python/topi/testing/gather_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""gather in python"""
import numpy as np

def gather_python(data, axis, indices):
""" Python version of Gather operator
Parameters
----------
data : numpy.ndarray
Numpy array
axis: int
integer
indices : numpy.ndarray
Numpy array
Returns
-------
b_np : numpy.ndarray
Numpy array
"""
shape_indices = indices.shape
out = np.zeros(shape_indices, dtype=data.dtype)
for index in np.ndindex(*shape_indices):
new_index = list(index)
new_index[axis] = indices[index]
out[index] = data[tuple(new_index)]
return out
Loading

0 comments on commit 5cb72f3

Please sign in to comment.