Skip to content

Commit

Permalink
[TFLITE]TOP_K op parser support (#5051)
Browse files Browse the repository at this point in the history
* [TFLITE]TOP_K op parser support

* Testcase updated
  • Loading branch information
siju-samuel authored Mar 30, 2020
1 parent ae482a3 commit 3326031
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
19 changes: 19 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def __init__(self, model, subgraph, exp_tab):
'TAN': self.convert_tan,
'TANH':self.convert_tanh,
'TILE': self.convert_tile,
'TOPK_V2': self.convert_topk_v2,
'TRANSPOSE_CONV': self.convert_transpose_conv,
'TRANSPOSE': self.convert_transpose,
'UNPACK': self.convert_unpack,
Expand Down Expand Up @@ -1550,6 +1551,24 @@ def convert_tile(self, op):

return out

def convert_topk_v2(self, op):
""" Convert TFLite TOPK_v2 """
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"
input_tensor = input_tensors[0]
input_tensor_idx = input_tensor.tensor_idx
in_expr = self.get_expr(input_tensor_idx)
k = self.get_tensor_value(input_tensors[1])
out = _op.topk(in_expr, int(k))

return out

def convert_pool2d(self, op, pool_type):
"""pool2d implementation."""
try:
Expand Down
19 changes: 19 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,24 @@ def test_forward_slice():
_test_slice(np.arange(8, dtype=np.int32).reshape((2, 4)), begin=[0, 1], size=[-1, -1])
_test_slice(np.arange(5, dtype=np.int32).reshape((5, )), begin=[4], size=[-1])

#######################################################################
# Topk
# ----
def _test_topk(in_shape, k=1):
""" One iteration of TOPK """
data = np.random.uniform(size=in_shape).astype('float32')
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = nn_ops.top_k(in_data, k, name='TopK')
compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out[0]])

def test_forward_topk():
""" TOPK """
_test_topk((3,), 1)
_test_topk((3,), 3)
_test_topk((3, 5, 7), 3)
_test_topk((3, 5, 7), 3)

#######################################################################
# transpose
# ---------
Expand Down Expand Up @@ -1775,6 +1793,7 @@ def test_forward_mediapipe_hand_landmark():
test_all_resize()
test_forward_squeeze()
test_forward_slice()
test_forward_topk()
test_forward_depthtospace()
test_forward_spacetodepth()

Expand Down

0 comments on commit 3326031

Please sign in to comment.