From 3ad0f878593523dbcfb285713ffe5c9d9cc794df Mon Sep 17 00:00:00 2001 From: Dhruva Ray Date: Mon, 4 May 2020 19:45:12 +0530 Subject: [PATCH] [TFLITE]GATHER_ND Signed-off-by: Dhruva Ray --- python/tvm/relay/frontend/tflite.py | 39 ++++++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 31 ++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 703ef9c8b6b06..8b54e3913be33 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -84,6 +84,7 @@ def __init__(self, model, subgraph, exp_tab): 'FLOOR': self.convert_floor, 'FULLY_CONNECTED': self.convert_fully_connected, 'GATHER': self.convert_gather, + 'GATHER_ND' : self.convert_gather_nd, 'GREATER_EQUAL': self.convert_greater_equal, 'GREATER': self.convert_greater, 'HARD_SWISH': self.convert_hard_swish, @@ -1067,6 +1068,31 @@ def convert_gather(self, op): out = _op.take(data, indices, axis=axis, mode="fast") return out + def convert_gather_nd(self, op): + """Method to Convert TFLite GATHER_ND operator""" + try: + from tflite.TensorType import TensorType + except ImportError: + raise ImportError("The tflite package must be installed") + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + for t in input_tensors: + assert not t.qnn_params, "Quantized input is not expected." + + data = self.get_tensor_expr(input_tensors[0]) + indices = self.get_tensor_expr(input_tensors[1]) + + indices_type = input_tensors[1].tensor.Type() + assert indices_type in (TensorType.INT32, TensorType.INT64) + + indices_dims = len(_infer_shape(indices)) + indices_t = _op.transpose(indices, axes=[-1] + list(range(indices_dims-1))) + + out = _op.gather_nd(data, indices_t) + return out + def convert_strided_slice(self, op): """Method to Convert TFLite STRIDED_SLICE operator. NOTE: Eventhough tensorflow supports begin_mask, end_mask, ellipsis_mask, new_axis_mask @@ -2357,6 +2383,19 @@ def get_expr(self, input_tensor_idx): def has_expr(self, input_tensor_idx): return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx)) + def get_tensor_expr(self, tensor): + """ Returns constant expr for constant else a tensor expr""" + if self.has_expr(tensor.tensor_idx): + # In most cases, we can assume that TOCO fuses elemwise operators + # with constants - it means both will be tensors. + expr = self.get_expr(tensor.tensor_idx) + else: + # However, in some corner cases, the elemwise operator is not fused, + # we can receive as constant. + type_str = self.get_tensor_type_str(tensor.tensor.Type()) + expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str) + + return expr def get_scalar_from_constant(expr): """ Returns scalar value from Relay constant scalar. """ diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 283d87d5078a1..f77186808de1b 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -343,6 +343,36 @@ def test_forward_gather(): _test_gather((1, 3, 3), [20], 1, 'float32', quantized, oob=True) _test_gather((1, 3, 3), [20, 20], 2, 'float32', quantized, oob=True) +####################################################################### +# Gather_ND +# --------- + +def _test_gather_nd(data, indices): + """ One iteration of GATHER_ND """ + with tf.Graph().as_default(): + in_data = tf.placeholder(shape=data.shape, dtype=data.dtype, name="data") + indices_data = tf.placeholder(shape=indices.shape, dtype=indices.dtype, + name="indices") + out = tf.gather_nd(in_data, indices_data) + + compare_tflite_with_tvm([data, indices], ['data:0', 'indices:0'], + [in_data, indices_data], [out]) + +def test_forward_gather_nd(): + """ GATHER_ND """ + _test_gather_nd( + np.array([[[1.2, 2.0], [3.1, 4.1]], [[5.1, 6.1], [7.1, 8.1]]]).astype('float32'), + np.asarray([[0, 1], [1, 0]]).astype('int32') + ) + _test_gather_nd( + np.reshape(np.arange(30), [5, 6]).astype('int32'), + np.asarray([[1, 2]]).astype('int32') + ) + _test_gather_nd( + np.reshape(np.arange(12), [2, 3, 2]).astype('int32'), + np.asarray([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]).astype('int32') + ) + ####################################################################### # StridedSlice # ------------ @@ -1994,6 +2024,7 @@ def test_forward_mediapipe_hand_landmark(): test_forward_slice() test_forward_topk() test_forward_gather() + test_forward_gather_nd() test_forward_stridedslice() test_forward_depthtospace() test_forward_spacetodepth()