Skip to content

Commit

Permalink
[Frontend][Tensorflow] Gather_nd bug fix for one dim support in tensorflow (apache#5588)
Browse files Browse the repository at this point in the history

* [Frontend][Tensorflow] Gather_nd one dim support added

* Test case added

* Doc error handled

* Review comment handled: reverting new attr introduced

* Check added at mxnet frontend

* Doc error handled

* TFLite test case failure resolved
  • Loading branch information
ANSHUMAN TRIPATHY authored and trevor-m committed Jun 18, 2020
1 parent c03e1da commit a1ffd50
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 9 deletions.
6 changes: 5 additions & 1 deletion python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,10 @@ def _mx_take(inputs, attrs):
axis = attrs.get_int("axis", 0)
return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode)

def _mx_gather_nd(inputs, attrs):
    """Convert MXNet gather_nd to Relay gather_nd.

    inputs[0] is the data tensor, inputs[1] the indices tensor.
    MXNet requires the indices tensor to be at least 2-dimensional,
    so reject rank-1 indices up front instead of failing later in
    the op type-relation.
    """
    assert len(inputs) == 2
    # Grammar fix in the message; the check itself is unchanged.
    assert len(_infer_shape(inputs[1])) > 1, \
        "index tensor must have at least 2 dimensions"
    return _op.gather_nd(inputs[0], inputs[1])

def _mx_reverse(inputs, attrs):
assert len(inputs) == 1
Expand Down Expand Up @@ -1770,7 +1774,6 @@ def impl(inputs, input_types):
"zeros_like",
"ones_like",
"where",
"gather_nd",
"cos",
"cosh",
"sin",
Expand Down Expand Up @@ -1918,6 +1921,7 @@ def impl(inputs, input_types):
"pad" : _mx_pad,
"Pad" : _mx_pad,
"take" : _mx_take,
"gather_nd" : _mx_gather_nd,
"reverse" : _mx_reverse,
"squeeze" : _mx_squeeze,
"broadcast_axis": _mx_broadcast_axis,
Expand Down
1 change: 0 additions & 1 deletion python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,6 @@ class TransposeAttrs(Attrs):
class ReshapeAttrs(Attrs):
"""Attributes for transform.reshape"""


@tvm._ffi.register_object("relay.attrs.TakeAttrs")
class TakeAttrs(Attrs):
"""Attributes for transform.take"""
Expand Down
3 changes: 3 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2248,6 +2248,9 @@ bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
Array<IndexExpr> oshape;
for (size_t i = 1; i < kdim + 1; ++i) oshape.push_back(indices->shape[i]);
for (size_t i = mdim->value; i < ndim; ++i) oshape.push_back(data->shape[i]);
if (oshape.size() == 0) {
oshape.push_back(tir::make_const(DataType::Int(32), 1));
}
reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}
Expand Down
4 changes: 3 additions & 1 deletion tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ def verify(shape, indices_src, axis, mode="clip"):
verify((3,4), [-1, 5], 1, mode="wrap")

def test_forward_gather_nd():
def verify(xshape, yshape, y_data):
def verify(xshape, yshape, y_data, error=False):
x_data = np.random.uniform(size=xshape).astype("float32")
ref_res = mx.nd.gather_nd(mx.nd.array(x_data), mx.nd.array(y_data))
mx_sym = mx.sym.gather_nd(mx.sym.var("x_data"), mx.sym.var("y_data"))
Expand All @@ -618,10 +618,12 @@ def verify(xshape, yshape, y_data):
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

verify((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]])
verify((2, 2, 2), (2, 2), [[0, 1], [1, 0]])
verify((3, 2, 2), (2, 2), [[0, 1], [1, 0]])
verify((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]])
verify((1, 4), (1, 1), [[0]])

def test_forward_bilinear_resize():
# add tests including scale_height and scale_width when mxnet is updated to version 1.5
Expand Down
26 changes: 21 additions & 5 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,7 +1379,7 @@ def test_forward_truncatemod():


#######################################################################
# Gather, GatherV2, GatherNd
# Gather, GatherV2
# --------------------------

def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
Expand Down Expand Up @@ -1418,16 +1418,32 @@ def test_forward_gather():
_test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 2, 'int32')
_test_gather((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 0, 'float32')

#######################################################################
# GatherND
# --------------------------

def test_forward_gather_nd():
def _test_gather_nd(ip_shape, indice_value, dtype):
    """test operator GatherNd"""
    # Random input data matching the requested shape/dtype.
    data = np.random.uniform(1, 100, size=ip_shape).astype(dtype)
    tf.reset_default_graph()
    with tf.Graph().as_default():
        # Build a tiny graph: placeholder -> gather_nd, then compare
        # the TF result against TVM's.
        placeholder = tf.placeholder(dtype, ip_shape, name="in_data")
        tf.gather_nd(placeholder, indices=indice_value, name="gather_nd")
        compare_tf_with_tvm([data], ['in_data:0'], 'gather_nd:0')

def test_forward_gather_nd():
    """test operator GatherNd"""
    _test_gather_nd((2, 2), [[0, 0], [1, 1]], 'float32')
    _test_gather_nd((2, 2, 2), [[1, 0, 0], [0, 0, 0]], 'float32')
    _test_gather_nd((4,), [1], 'float32')
    _test_gather_nd((4,), [1], 'int32')
    _test_gather_nd((1, 4), [0, 3], 'int32')
    _test_gather_nd((2, 2), [[[1, 0], [0, 1]]], 'int32')
    _test_gather_nd((2, 2), [[[1, 0], [0, 1]]], 'float32')
    _test_gather_nd((3, 3, 3), [[[1, 0]]], 'int32')
    # Was a byte-identical duplicate of the line above; the other cases
    # come in int32/float32 pairs, so exercise float32 here instead.
    _test_gather_nd((3, 3, 3), [[[1, 0]]], 'float32')
    _test_gather_nd((4, 3, 5, 6), [[2, 1, 0, 0]], 'float32')
    _test_gather_nd((3, 3, 3), [[[2, 1]]], 'int32')

#######################################################################
# BiasAdd
Expand Down
12 changes: 12 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,18 @@ def test_forward_gather_nd():
np.reshape(np.arange(12), [2, 3, 2]).astype('int32'),
np.asarray([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]).astype('int32')
)
_test_gather_nd(
np.reshape(np.arange(4), [4]).astype('float32'),
np.asarray([1]).astype('int32')
)
_test_gather_nd(
np.reshape(np.arange(4), [1, 4]).astype('float32'),
np.asarray([0]).astype('int32')
)
_test_gather_nd(
np.reshape(np.arange(4), [1, 4]).astype('float32'),
np.asarray([0, 3]).astype('int32')
)

#######################################################################
# StridedSlice
Expand Down
5 changes: 4 additions & 1 deletion topi/include/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,7 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string n
std::string tag = kInjective) {
size_t ndim_d = data->shape.size();
size_t ndim_i = indices->shape.size();
CHECK_GT(ndim_i, 1) << "indices tensor must have at least 2 dimensions";
CHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimensions";
size_t indices_dim0 = static_cast<size_t>(GetConstInt(indices->shape[0]));
CHECK_LE(indices_dim0, ndim_d) << "dim 0 of indices tensor must be no more "
<< "than dimensions of data tensor";
Expand Down Expand Up @@ -1027,6 +1027,9 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string n
real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position)));
}
}
if (real_indices.size() == ndim_d) {
return data(real_indices);
}
for (size_t i = ndim_i - 1; i < out_index.size(); ++i) {
real_indices.push_back(out_index[i]);
}
Expand Down

0 comments on commit a1ffd50

Please sign in to comment.