From 8632f249a6566beda77980a7af89cb67242fc365 Mon Sep 17 00:00:00 2001 From: Dhruva Ray Date: Thu, 30 Apr 2020 19:55:26 +0530 Subject: [PATCH 1/2] [TFLITE] SELECT Signed-off-by: Dhruva Ray --- python/tvm/relay/frontend/tflite.py | 34 +++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 39 ++++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 5c8bbfb3c8f9..2c63024aeb8a 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -119,6 +119,7 @@ def __init__(self, model, subgraph, exp_tab): 'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor, 'ROUND': self.convert_round, 'RSQRT': self.convert_rsqrt, + 'SELECT': self.convert_select, 'SIN': self.convert_sin, 'SLICE': self.convert_slice, 'SOFTMAX': self.convert_softmax, @@ -140,6 +141,7 @@ def __init__(self, model, subgraph, exp_tab): 'TRANSPOSE_CONV': self.convert_transpose_conv, 'TRANSPOSE': self.convert_transpose, 'UNPACK': self.convert_unpack, + 'WHERE': self.convert_select, 'ZEROS_LIKE': self.convert_zeros_like, } @@ -2002,6 +2004,38 @@ def convert_unpack(self, op): return squeezed + def convert_select(self, op): + """Convert TFLite select""" + try: + from tflite.TensorType import TensorType + except ImportError: + raise ImportError("The tflite package must be installed") + + input_tensors = self.get_input_tensors(op) + + for t in input_tensors: + assert not t.qnn_params, "Quantized input is not expected." + + assert len(input_tensors) == 3 + + condition, x, y = input_tensors[0], input_tensors[1], input_tensors[1] + assert condition.tensor.Type() in (TensorType.INT32, TensorType.INT64, TensorType.BOOL) + + for type in [x.tensor.Type(), y.tensor.Type()]: + assert type in (TensorType.INT32, TensorType.INT64) + + expressions = [] + + for t in input_tensors: + if self.has_expr(t.tensor_idx): + expressions.append(self.get_expr(t.tensor_idx)) + else: + tensor_type = self.get_tensor_type_str(t.tensor.Type()) + tensor_value = self.get_tensor_value(t) + expressions.append(self.exp_tab.new_const(tensor_value, dtype=tensor_type)) + + return _op.where(expressions[0], expressions[1], expressions[2]) + def convert_batch_to_space_nd(self, op): """batch_to_space_nd implementation.""" diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 26bb86dfe24a..de87bd9fa4c6 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1658,6 +1658,44 @@ def test_forward_spacetodepth(): _test_spacetodepth(np.random.normal(size=[1, 32, 32, 4]).astype("float32"), 2) _test_spacetodepth(np.random.normal(size=[1, 16, 8, 32]).astype("float32"), 4) +####################################################################### +# Select +# ------ + +def _test_select(data, use_placeholder = True): + """ One iteration of select with placeholders """ + assert len(data) == 3 + + data[0] = np.array(data[0], dtype='bool') + data[1] = None if data[1] is None else np.array(data[1]).astype('int32') + data[2] = None if data[2] is None else np.array(data[2]).astype('int32') + + with tf.Graph().as_default(): + condition = tf.placeholder(dtype='bool', shape=data[0].shape, name="condition") + if use_placeholder: + x = tf.placeholder(dtype='int32', shape=data[1].shape, name="x") + y = tf.placeholder(dtype='int32', shape=data[2].shape, name="y") + out = tf.where(condition, x, y) + + compare_tflite_with_tvm(data, ['condition:0', 'x:0', 'y:0'], [condition, x, y], [out]) + else: + x = tf.constant(data[1], dtype='int32', shape=data[1].shape, name="x") + y = tf.constant(data[2], dtype='int32', shape=data[2].shape, name="y") + out = tf.where(condition, x, y) + + compare_tflite_with_tvm([data[0]], ['condition:0'], [condition], [out]) + + +def test_forward_select(): + #tf converter to tflite has bool data type support in tf version 1.15 + if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'): + _test_select([[1, 0], [[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + _test_select([[[False, True], [True, False]], [[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + _test_select([[[False, True], [True, False]], [[1, 2], [3, 4]], [[5, 6], [7, 8]]], False) + + #Not supported at topi/relay layer + #_test_where([[[False, False], [True, True]], None, None]) + ####################################################################### # Fully Connected # --------------- @@ -2014,6 +2052,7 @@ def test_forward_mediapipe_hand_landmark(): test_forward_stridedslice() test_forward_depthtospace() test_forward_spacetodepth() + test_forward_select() # NN test_forward_convolution() From d82dd75efa7926851585b6da4e585cdc930cb874 Mon Sep 17 00:00:00 2001 From: Dhruva Ray Date: Thu, 30 Apr 2020 20:08:17 +0530 Subject: [PATCH 2/2] changed variable name Signed-off-by: Dhruva Ray --- python/tvm/relay/frontend/tflite.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 2c63024aeb8a..6dddd693c28b 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2021,8 +2021,8 @@ def convert_select(self, op): condition, x, y = input_tensors[0], input_tensors[1], input_tensors[1] assert condition.tensor.Type() in (TensorType.INT32, TensorType.INT64, TensorType.BOOL) - for type in [x.tensor.Type(), y.tensor.Type()]: - assert type in (TensorType.INT32, TensorType.INT64) + for t_type in [x.tensor.Type(), y.tensor.Type()]: + assert t_type in (TensorType.INT32, TensorType.INT64) expressions = []