diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index f2a9e5852990d..dee5c55ac9dc3 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -136,6 +136,7 @@ def __init__(self, model, subgraph, exp_tab): 'ROUND': self.convert_round, 'RSQRT': self.convert_rsqrt, 'REVERSE_SEQUENCE': self.convert_reverse_sequence, + 'REVERSE_V2': self.convert_reverse_v2, 'SELECT': self.convert_select, 'SHAPE': self.convert_shape, 'SIN': self.convert_sin, @@ -2973,6 +2974,22 @@ def convert_one_hot(self, op): return out + def convert_reverse_v2(self, op): + """Convert TFLite REVERSE_V2""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensor's length should be 2" + + input_expr = self.get_expr(input_tensors[0].tensor_idx) + + # Getting axis value + axis = self.get_tensor_value(input_tensors[1]) + if isinstance(axis, np.ndarray): + assert len(axis) == 1, "TFLite does not support multi-axis yet" + axis = int(axis) + + out = _op.reverse(input_expr, axis) + return out + def get_expr(self, input_tensor_idx): return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx)) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index ebfa10fc35fda..6164ab7cd7024 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -2620,6 +2620,32 @@ def test_forward_fully_connected(): _test_fully_connected([5, 1, 1, 150], [150, 100], [100]) +####################################################################### +# REVERSE_V2 +# ---------- + +def _test_reverse_v2(input_shape, axis, dtype): + """ One iteration of REVERSE_V2 """ + with tf.Graph().as_default(): + input = np.random.randint(0, 100, size=input_shape).astype(dtype) + in_input = tf.placeholder(dtype=input.dtype, shape=input.shape, name="input") + in_axis = ops.convert_to_tensor(axis, dtype=axis.dtype) + + out = array_ops.reverse(in_input, in_axis) + + compare_tflite_with_tvm( + [input], + ["input"], + [in_input], + [out]) + +def test_forward_reverse_v2(): + """ REVERSE_V2 """ + for dtype in ['float32', 'int32']: + _test_reverse_v2((5), np.array([0], dtype='int32'), dtype) + _test_reverse_v2((5, 6, 4, 2), np.array([2], dtype='int32'), dtype) + + ####################################################################### # Custom Operators # ---------------- @@ -3098,6 +3124,7 @@ def test_forward_mediapipe_hand_landmark(): test_forward_quantize_dequantize() test_forward_arg_min_max() test_forward_expand_dims() + test_forward_reverse_v2() # NN test_forward_convolution()