Skip to content

Commit

Permalink
[TFLite] Implemented REVERSE_V2 Operator for TFLite.
Browse files Browse the repository at this point in the history
* Added implementation for REVERSE_V2 Operator.
* Added tests for REVERSE_V2 Operator.
  • Loading branch information
Rishabh Jain committed Aug 19, 2020
1 parent 7aa2de3 commit 4e96846
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
17 changes: 17 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def __init__(self, model, subgraph, exp_tab):
'ROUND': self.convert_round,
'RSQRT': self.convert_rsqrt,
'REVERSE_SEQUENCE': self.convert_reverse_sequence,
'REVERSE_V2': self.convert_reverse_v2,
'SELECT': self.convert_select,
'SHAPE': self.convert_shape,
'SIN': self.convert_sin,
Expand Down Expand Up @@ -2973,6 +2974,22 @@ def convert_one_hot(self, op):

return out

def convert_reverse_v2(self, op):
"""Convert TFLite REVERSE_V2"""
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensor's length should be 2"

input_expr = self.get_expr(input_tensors[0].tensor_idx)

# Getting axis value
axis = self.get_tensor_value(input_tensors[1])
if isinstance(axis, np.ndarray):
assert len(axis) == 1, "TFLite does not support multi-axis yet"
axis = int(axis)

out = _op.reverse(input_expr, axis)
return out


def get_expr(self, input_tensor_idx):
return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
Expand Down
27 changes: 27 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2620,6 +2620,32 @@ def test_forward_fully_connected():
_test_fully_connected([5, 1, 1, 150], [150, 100], [100])


#######################################################################
# REVERSE_V2
# ----------

def _test_reverse_v2(input_shape, axis, dtype):
""" One iteration of REVERSE_V2 """
with tf.Graph().as_default():
input = np.random.randint(0, 100, size=input_shape).astype(dtype)
in_input = tf.placeholder(dtype=input.dtype, shape=input.shape, name="input")
in_axis = ops.convert_to_tensor(axis, dtype=axis.dtype)

out = array_ops.reverse(in_input, in_axis)

compare_tflite_with_tvm(
[input],
["input"],
[in_input],
[out])

def test_forward_reverse_v2():
""" REVERSE_V2 """
for dtype in ['float32', 'int32']:
_test_reverse_v2((5), np.array([0], dtype='int32'), dtype)
_test_reverse_v2((5, 6, 4, 2), np.array([2], dtype='int32'), dtype)


#######################################################################
# Custom Operators
# ----------------
Expand Down Expand Up @@ -3098,6 +3124,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_quantize_dequantize()
test_forward_arg_min_max()
test_forward_expand_dims()
test_forward_reverse_v2()

# NN
test_forward_convolution()
Expand Down

0 comments on commit 4e96846

Please sign in to comment.