From 42a961d21ba0e6c01f709ea862d8eae806645a63 Mon Sep 17 00:00:00 2001 From: Dmitriy Smirnov Date: Tue, 18 Aug 2020 17:43:29 +0100 Subject: [PATCH] Gather operation with indices as tensor expr in TFLite frontend (#6168) * gather with indices as tensor expr Added handling of indices as tensor expr to gather operation, unit tests amended Code cheking out of boundary error refactored in more "pythonic" way. Fixed bug in negative axis value normalisation * replaced with get_tensor_expr --- python/tvm/relay/frontend/tflite.py | 54 ++++++++------------ tests/python/frontend/tflite/test_forward.py | 41 ++++++++++----- 2 files changed, 49 insertions(+), 46 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 48c88d042ab87..f2a9e5852990d 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -1347,14 +1347,10 @@ def convert_gather(self, op): input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 2, "input tensors length should be 2" - data = self.get_expr(input_tensors[0].tensor_idx) - + data = self.get_tensor_expr(input_tensors[0]) indices = input_tensors[1] indices_type = indices.tensor.Type() assert indices_type in (TensorType.INT32, TensorType.INT64) - indices_type_str = self.get_tensor_type_str(indices_type) - indices = self.exp_tab.new_const(self.get_tensor_value(indices), - dtype=indices_type_str) assert op.BuiltinOptionsType() == BuiltinOptions.GatherOptions op_options = op.BuiltinOptions() @@ -1366,37 +1362,31 @@ def convert_gather(self, op): data_shape = list(input_tensors[0].tensor.ShapeAsNumpy()) data_dim = len(data_shape) - axis_n = axis - if axis_n < 0: - axis_n += axis_n + data_dim - assert axis_n >= 0, "Axis out of bounds" - assert axis_n < data_dim, "Axis out of bounds" - - indices_val = self.get_tensor_value(input_tensors[1]) - indices_shape = list(indices_val.shape) - indices_len = len(indices_shape) - - out_shape = [] - for i in range(data_dim): - if axis_n == i: - for j in range(indices_len): - out_shape.append(indices_shape[j]) - else: - out_shape.append(data_shape[i]) - - loopover = [range(s) for s in out_shape] - for idx in list(itertools.product(*loopover)): - indices_position = [idx[j] for j in range(axis_n, axis_n+indices_len)] + axis = data_dim + axis if axis < 0 else axis + assert axis >= 0, "Axis out of bounds" + assert axis < data_dim, "Axis out of bounds" - real_indices = [idx[j] for j in range(axis_n)] - real_indices.append(indices_val[tuple(indices_position)]) - real_indices.extend([idx[j] for j in range(axis_n + indices_len, len(idx))]) - for r, d in zip(real_indices, data_shape): - if r >= d: + if self.has_expr(indices.tensor_idx): + indices_expr = self.get_expr(indices.tensor_idx) + else: + indices_val = self.get_tensor_value(indices) + indices_expr = self.exp_tab.new_const(indices_val, + dtype=self.get_tensor_type_str(indices_type)) + indices_shape = list(indices_val.shape) + indices_len = len(indices_shape) + + out_shape = data_shape[:axis] + indices_shape[:] + data_shape[axis+1:] + + loopover = [range(s) for s in out_shape] + for idx in list(itertools.product(*loopover)): + real_indices = list(idx[:axis]) \ + + [indices_val[idx[axis: axis + indices_len]]] \ + + list(idx[axis + indices_len:]) + if np.any(np.subtract(data_shape, real_indices) < 0): raise ValueError("TFLite out of bound indices are not supported.") # Use mode 'fast' since indices are already checked within bounds. - out = _op.take(data, indices, axis=axis, mode="fast") + out = _op.take(data, indices_expr, axis=axis, mode="fast") return out def convert_gather_nd(self, op): diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index ebb4d77cce64c..ebfa10fc35fda 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -396,20 +396,31 @@ def test_forward_topk(): # Gather # ------ -def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False): +def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False, wrap_idx=False): """ One iteration of Gather """ indices = np.asarray(indices).astype('int32') data = np.random.uniform(1, 10, size=dshape) data = data.astype(np.uint8) if quantized else data.astype(dtype) with tf.Graph().as_default(): + if wrap_idx: + in_name = "in_indices" + indices_expr = array_ops.placeholder(shape=indices.shape, dtype=indices.dtype, name=in_name) + in_tensor_name = [in_name + ":0"] + in_indices = [indices_expr] + else: + indices_expr = indices + indices = [] + in_tensor_name = [] + in_indices = [] + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in_data") if axis: - out = array_ops.gather(in_data, indices, axis=axis) + out = array_ops.gather(in_data, indices_expr, axis=axis) else: - out = array_ops.gather(in_data, indices) #tflite conversion fails for None axis + out = array_ops.gather(in_data, indices_expr) #tflite conversion fails for None axis input_range = {'in_data': (-100, 100)} if quantized else None try: - compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], + compare_tflite_with_tvm([data] + indices, ['in_data:0'] + in_tensor_name, [in_data] + in_indices, [out], quantized=quantized, input_range=input_range) except ValueError as e: if not oob: @@ -420,16 +431,18 @@ def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False): def test_forward_gather(): """ GATHER """ for quantized in [False, True]: - _test_gather((4,), [1], 0, 'float32', quantized) - _test_gather((4,), [1], None, 'int32', quantized) - _test_gather((1, 4), [0], 0, 'int32', quantized) - _test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32', quantized) - _test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32', quantized) - _test_gather((2, 2), [[[1, 0], [0, 1]]], None, 'float32', quantized) - _test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32', quantized) - _test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32', quantized) - _test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32', quantized) - _test_gather((3, 3, 3), [[[2, 1]]], -1, 'int32', quantized) + for wrap_idx in [False, True]: + _test_gather((4,), [1], 0, 'float32', quantized, wrap_idx) + _test_gather((4,), [1], None, 'int32', quantized, wrap_idx) + _test_gather((1, 4), [0], 0, 'int32', quantized, wrap_idx) + _test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32', quantized, wrap_idx) + _test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32', quantized, wrap_idx) + _test_gather((2, 2), [[[1, 0], [0, 1]]], None, 'float32', quantized, wrap_idx) + _test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32', quantized, wrap_idx) + _test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32', quantized, wrap_idx) + _test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32', quantized, wrap_idx) + _test_gather((3, 3, 3), [[[2, 1]]], -1, 'int32', quantized, wrap_idx) + # Out of boundary error cannot be tested with wrapped index _test_gather((4,), [16], 0, 'float32', quantized, oob=True) _test_gather((1, 3, 3), [12], 0, 'int32', quantized, oob=True) _test_gather((1, 3, 3), [20], 1, 'float32', quantized, oob=True)