Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TFLITE] SELECT #5488

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(self, model, subgraph, exp_tab):
'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor,
'ROUND': self.convert_round,
'RSQRT': self.convert_rsqrt,
'SELECT': self.convert_select,
'SIN': self.convert_sin,
'SLICE': self.convert_slice,
'SOFTMAX': self.convert_softmax,
Expand All @@ -140,6 +141,7 @@ def __init__(self, model, subgraph, exp_tab):
'TRANSPOSE_CONV': self.convert_transpose_conv,
'TRANSPOSE': self.convert_transpose,
'UNPACK': self.convert_unpack,
'WHERE': self.convert_select,
'ZEROS_LIKE': self.convert_zeros_like,
}

Expand Down Expand Up @@ -2002,6 +2004,38 @@ def convert_unpack(self, op):

return squeezed

def convert_select(self, op):
"""Convert TFLite select"""
try:
from tflite.TensorType import TensorType
except ImportError:
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)

for t in input_tensors:
assert not t.qnn_params, "Quantized input is not expected."

assert len(input_tensors) == 3

condition, x, y = input_tensors[0], input_tensors[1], input_tensors[1]
assert condition.tensor.Type() in (TensorType.INT32, TensorType.INT64, TensorType.BOOL)

for t_type in [x.tensor.Type(), y.tensor.Type()]:
assert t_type in (TensorType.INT32, TensorType.INT64)

expressions = []

for t in input_tensors:
if self.has_expr(t.tensor_idx):
expressions.append(self.get_expr(t.tensor_idx))
else:
tensor_type = self.get_tensor_type_str(t.tensor.Type())
tensor_value = self.get_tensor_value(t)
expressions.append(self.exp_tab.new_const(tensor_value, dtype=tensor_type))

return _op.where(expressions[0], expressions[1], expressions[2])

def convert_batch_to_space_nd(self, op):
"""batch_to_space_nd implementation."""

Expand Down
39 changes: 39 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1658,6 +1658,44 @@ def test_forward_spacetodepth():
_test_spacetodepth(np.random.normal(size=[1, 32, 32, 4]).astype("float32"), 2)
_test_spacetodepth(np.random.normal(size=[1, 16, 8, 32]).astype("float32"), 4)

#######################################################################
# Select
# ------

def _test_select(data, use_placeholder = True):
""" One iteration of select with placeholders """
assert len(data) == 3

data[0] = np.array(data[0], dtype='bool')
data[1] = None if data[1] is None else np.array(data[1]).astype('int32')
data[2] = None if data[2] is None else np.array(data[2]).astype('int32')

with tf.Graph().as_default():
condition = tf.placeholder(dtype='bool', shape=data[0].shape, name="condition")
if use_placeholder:
x = tf.placeholder(dtype='int32', shape=data[1].shape, name="x")
y = tf.placeholder(dtype='int32', shape=data[2].shape, name="y")
out = tf.where(condition, x, y)

compare_tflite_with_tvm(data, ['condition:0', 'x:0', 'y:0'], [condition, x, y], [out])
else:
x = tf.constant(data[1], dtype='int32', shape=data[1].shape, name="x")
y = tf.constant(data[2], dtype='int32', shape=data[2].shape, name="y")
out = tf.where(condition, x, y)

compare_tflite_with_tvm([data[0]], ['condition:0'], [condition], [out])


def test_forward_select():
#tf converter to tflite has bool data type support in tf version 1.15
if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'):
_test_select([[1, 0], [[1, 2], [3, 4]], [[5, 6], [7, 8]]])
_test_select([[[False, True], [True, False]], [[1, 2], [3, 4]], [[5, 6], [7, 8]]])
_test_select([[[False, True], [True, False]], [[1, 2], [3, 4]], [[5, 6], [7, 8]]], False)

#Not supported at topi/relay layer
#_test_where([[[False, False], [True, True]], None, None])

#######################################################################
# Fully Connected
# ---------------
Expand Down Expand Up @@ -2014,6 +2052,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_stridedslice()
test_forward_depthtospace()
test_forward_spacetodepth()
test_forward_select()

# NN
test_forward_convolution()
Expand Down