Skip to content

Commit

Permalink
[TFLITE]GATHER_ND
Browse files Browse the repository at this point in the history
Signed-off-by: Dhruva Ray <[email protected]>
  • Loading branch information
dhruvaray committed May 4, 2020
1 parent 0abf581 commit 3ad0f87
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
39 changes: 39 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(self, model, subgraph, exp_tab):
'FLOOR': self.convert_floor,
'FULLY_CONNECTED': self.convert_fully_connected,
'GATHER': self.convert_gather,
'GATHER_ND' : self.convert_gather_nd,
'GREATER_EQUAL': self.convert_greater_equal,
'GREATER': self.convert_greater,
'HARD_SWISH': self.convert_hard_swish,
Expand Down Expand Up @@ -1067,6 +1068,31 @@ def convert_gather(self, op):
out = _op.take(data, indices, axis=axis, mode="fast")
return out

def convert_gather_nd(self, op):
"""Method to Convert TFLite GATHER_ND operator"""
try:
from tflite.TensorType import TensorType
except ImportError:
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"

for t in input_tensors:
assert not t.qnn_params, "Quantized input is not expected."

data = self.get_tensor_expr(input_tensors[0])
indices = self.get_tensor_expr(input_tensors[1])

indices_type = input_tensors[1].tensor.Type()
assert indices_type in (TensorType.INT32, TensorType.INT64)

indices_dims = len(_infer_shape(indices))
indices_t = _op.transpose(indices, axes=[-1] + list(range(indices_dims-1)))

out = _op.gather_nd(data, indices_t)
return out

def convert_strided_slice(self, op):
"""Method to Convert TFLite STRIDED_SLICE operator.
NOTE: Eventhough tensorflow supports begin_mask, end_mask, ellipsis_mask, new_axis_mask
Expand Down Expand Up @@ -2357,6 +2383,19 @@ def get_expr(self, input_tensor_idx):
def has_expr(self, input_tensor_idx):
return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx))

def get_tensor_expr(self, tensor):
""" Returns constant expr for constant else a tensor expr"""
if self.has_expr(tensor.tensor_idx):
# In most cases, we can assume that TOCO fuses elemwise operators
# with constants - it means both will be tensors.
expr = self.get_expr(tensor.tensor_idx)
else:
# However, in some corner cases, the elemwise operator is not fused,
# we can receive as constant.
type_str = self.get_tensor_type_str(tensor.tensor.Type())
expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str)

return expr

def get_scalar_from_constant(expr):
""" Returns scalar value from Relay constant scalar. """
Expand Down
31 changes: 31 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,36 @@ def test_forward_gather():
_test_gather((1, 3, 3), [20], 1, 'float32', quantized, oob=True)
_test_gather((1, 3, 3), [20, 20], 2, 'float32', quantized, oob=True)

#######################################################################
# Gather_ND
# ---------

def _test_gather_nd(data, indices):
""" One iteration of GATHER_ND """
with tf.Graph().as_default():
in_data = tf.placeholder(shape=data.shape, dtype=data.dtype, name="data")
indices_data = tf.placeholder(shape=indices.shape, dtype=indices.dtype,
name="indices")
out = tf.gather_nd(in_data, indices_data)

compare_tflite_with_tvm([data, indices], ['data:0', 'indices:0'],
[in_data, indices_data], [out])

def test_forward_gather_nd():
""" GATHER_ND """
_test_gather_nd(
np.array([[[1.2, 2.0], [3.1, 4.1]], [[5.1, 6.1], [7.1, 8.1]]]).astype('float32'),
np.asarray([[0, 1], [1, 0]]).astype('int32')
)
_test_gather_nd(
np.reshape(np.arange(30), [5, 6]).astype('int32'),
np.asarray([[1, 2]]).astype('int32')
)
_test_gather_nd(
np.reshape(np.arange(12), [2, 3, 2]).astype('int32'),
np.asarray([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]).astype('int32')
)

#######################################################################
# StridedSlice
# ------------
Expand Down Expand Up @@ -1994,6 +2024,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_slice()
test_forward_topk()
test_forward_gather()
test_forward_gather_nd()
test_forward_stridedslice()
test_forward_depthtospace()
test_forward_spacetodepth()
Expand Down

0 comments on commit 3ad0f87

Please sign in to comment.