Skip to content

Commit

Permalink
[Relay][OP] Gather_nd exposed to relay (#2945)
Browse files Browse the repository at this point in the history
* gather_nd added

* gather_nd test added

* more test added

* fix lint

* fix build error

* fix lint

* comments addressed
  • Loading branch information
Laurawly authored and icemelon committed Apr 2, 2019
1 parent fbd1c16 commit ae21edd
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ This level enables additional math and transform operators.
tvm.relay.zeros_like
tvm.relay.ones
tvm.relay.ones_like
tvm.relay.gather_nd
tvm.relay.full
tvm.relay.full_like
tvm.relay.cast
Expand Down Expand Up @@ -225,6 +226,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.zeros_like
.. autofunction:: tvm.relay.ones
.. autofunction:: tvm.relay.ones_like
.. autofunction:: tvm.relay.gather_nd
.. autofunction:: tvm.relay.full
.. autofunction:: tvm.relay.full_like
.. autofunction:: tvm.relay.cast
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,7 @@ def _mx_deformable_convolution(inputs, attrs):
"zeros_like",
"ones_like",
"where",
"gather_nd",
]

_convert_map = {
Expand Down Expand Up @@ -782,7 +783,6 @@ def _mx_deformable_convolution(inputs, attrs):
# TODO(tvm-tvm): support all operators.
#
# "broadcast_to",
# "gather_nd",
# "Crop" : _crop_like,
}

Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
_reg.register_schedule("stack", schedule_injective)
_reg.register_schedule("concatenate", schedule_injective)
_reg.register_schedule("_contrib_reverse_reshape", schedule_injective)
_reg.register_schedule("gather_nd", schedule_injective)

# layout_transform
_reg.register_schedule("layout_transform", schedule_injective)
Expand Down
33 changes: 33 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,3 +651,36 @@ def reverse_reshape(data, newshape):
if isinstance(newshape, int):
newshape = [newshape]
return _make._contrib_reverse_reshape(data, list(newshape))


def gather_nd(data, indices):
"""Gather elements or slices from data and store to a tensor whose shape is
defined by indices.
Parameters
----------
data : relay.Expr
The input data to the operator.
indices : relay.Expr
The shape of output tensor.
Returns
-------
ret : relay.Expr
The computed result.
Examples
--------
.. code-block:: python
data = [[0, 1], [2, 3]]
indices = [[1, 1, 0], [0, 1, 0]]
relay.gather_nd(data, indices) = [2, 3, 0]
data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
indices = [[0, 1], [1, 0]]
relay.gather_nd(data, indices) = [[3, 4], [5, 6]]
"""

return _make.gather_nd(data, indices)
70 changes: 70 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2122,5 +2122,75 @@ example below::
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// gather_nd operator
bool GatherNDRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, indices, result]
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* indices = types[1].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "GatherND: expect input data type to be TensorType but get "
<< types[0];
return false;
}
if (indices == nullptr) {
CHECK(types[1].as<IncompleteTypeNode>())
<< "GatherND: expect indices type to be TensorType but get "
<< types[1];
return false;
}
const size_t ndim = data->shape.size();
const IntImm* mdim = data->shape[0].as<IntImm>();
const size_t kdim = indices->shape.size() - 1;
CHECK(size_t(mdim->value) <= ndim)
<< "GatherND: indices shape does satisfy.";

Array<IndexExpr> oshape;
for (size_t i = 1; i < kdim + 1; ++i)
oshape.push_back(indices->shape[i]);
for (size_t i = mdim->value; i < ndim; ++i)
oshape.push_back(data->shape[i]);
reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
return true;
}

Array<Tensor> GatherNDCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return { topi::gather_nd(inputs[0], inputs[1]) };
}

Expr MakeGatherND(Expr data,
Expr indices) {
static const Op& op = Op::Get("gather_nd");
return CallNode::make(op, {data, indices}, {});
}

TVM_REGISTER_API("relay.op._make.gather_nd")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeGatherND, args, rv);
});

RELAY_REGISTER_OP("gather_nd")
.describe(R"code(Gather elements or slices from data and store to
a tensor whose shape is defined by indices.
Given data with shape (X_0, X_1, ..., X_{N-1}) and indices with
shape (M, Y_0, ..., Y_{K-1}), the output will have shape
(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), where M <= N. If M == N,
output shape will simply be (Y_0, ..., Y_{K-1}).
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("GatherND", GatherNDRel)
.set_attr<FTVMCompute>("FTVMCompute", GatherNDCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

} // namespace relay
} // namespace tvm
16 changes: 16 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,21 @@ def verify(shape, indices_src, axis, mode="clip"):
verify((3,4), [-1, 5], 1)
verify((3,4), [-1, 5], 1, mode="wrap")

def test_forward_gather_nd():
def verify(xshape, yshape, y_data):
x_data = np.random.uniform(size=xshape).astype("float32")
ref_res = mx.nd.gather_nd(mx.nd.array(x_data), mx.nd.array(y_data))
mx_sym = mx.sym.gather_nd(mx.sym.var("x_data"), mx.sym.var("y_data"))
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x_data": xshape, "y_data": yshape}, {"x_data": "float32", "y_data": "int32"})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]])
verify((2, 2, 2), (2, 2), [[0, 1], [1, 0]])


if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
Expand Down Expand Up @@ -527,3 +542,4 @@ def verify(shape, indices_src, axis, mode="clip"):
test_forward_embedding()
test_forward_smooth_l1()
test_forward_take()
test_forward_gather_nd()
21 changes: 20 additions & 1 deletion tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,6 @@ def verify_stack(dshapes, axis):
verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1)



def test_reverse():
def verify_reverse(dshape, axis):
x = relay.var("x", relay.TensorType(dshape, "float32"))
Expand All @@ -573,6 +572,25 @@ def verify_reverse(dshape, axis):
verify_reverse((2, 3, 4), -1)


def test_gather_nd():
def verify_gather_nd(xshape, yshape, y_data):
x = relay.var("x", relay.TensorType(xshape, "float32"))
y = relay.var("y", relay.TensorType(yshape, "int32"))
z = relay.gather_nd(x, y)

func = relay.Function([x, y], z)
x_data = np.random.uniform(size=xshape).astype("float32")
ref_res = x_data[y_data]

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_gather_nd((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]])
verify_gather_nd((2, 2, 2), (2, 2), [[0, 1], [1, 0]])


if __name__ == "__main__":
test_cast()
test_zeros_ones()
Expand Down Expand Up @@ -601,3 +619,4 @@ def verify_reverse(dshape, axis):
test_stack()
test_tile()
test_repeat()
test_gather_nd()

0 comments on commit ae21edd

Please sign in to comment.