Skip to content

Commit

Permalink
[TFLITE]TOP_K op parser support
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel committed Mar 22, 2020
1 parent 050f2bd commit 5be8ba3
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def __init__(self, model, subgraph, exp_tab):
'TAN': self.convert_tan,
'TANH':self.convert_tanh,
'TILE': self.convert_tile,
'TOPK_V2': self.convert_topk_v2,
'TRANSPOSE_CONV': self.convert_transpose_conv,
'TRANSPOSE': self.convert_transpose,
'UNPACK': self.convert_unpack,
Expand Down Expand Up @@ -1550,6 +1551,24 @@ def convert_tile(self, op):

return out

def convert_topk_v2(self, op):
""" Convert TFLite TOPK_v2 """
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"
input_tensor = input_tensors[0]
input_tensor_idx = input_tensor.tensor_idx
in_expr = self.get_expr(input_tensor_idx)
k = self.get_tensor_value(input_tensors[1])
out = _op.topk(in_expr, int(k))

return out

def convert_pool2d(self, op, pool_type):
"""pool2d implementation."""
try:
Expand Down

0 comments on commit 5be8ba3

Please sign in to comment.