From fd51d3092bce92525cde4d4617d0fda9548b55c8 Mon Sep 17 00:00:00 2001 From: Samuel Date: Thu, 7 May 2020 12:57:39 +0530 Subject: [PATCH] [TFLITE]Select op support for tflite frontend (#5486) * [TFLITE]Select/Where op support for tflite frontend * Review comment fixed * Review comment fixed --- python/tvm/relay/frontend/tflite.py | 28 ++++++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 22 +++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 915240e59344..a55a57f8a32e 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -119,6 +119,7 @@ def __init__(self, model, subgraph, exp_tab): 'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor, 'ROUND': self.convert_round, 'RSQRT': self.convert_rsqrt, + 'SELECT': self.convert_select, 'SIN': self.convert_sin, 'SLICE': self.convert_slice, 'SOFTMAX': self.convert_softmax, @@ -140,6 +141,7 @@ def __init__(self, model, subgraph, exp_tab): 'TRANSPOSE_CONV': self.convert_transpose_conv, 'TRANSPOSE': self.convert_transpose, 'UNPACK': self.convert_unpack, + 'WHERE': self.convert_select, 'ZEROS_LIKE': self.convert_zeros_like, } @@ -1697,6 +1699,18 @@ def convert_slice(self, op): return out + def convert_select(self, op): + """Convert TFLite SELECT""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 3, "input tensors length should be == 3" + cond = self.get_tensor_expr(input_tensors[0]) + x = self.get_tensor_expr(input_tensors[1]) + y = self.get_tensor_expr(input_tensors[2]) + + out = _op.where(cond, x, y) + + return out + def convert_transpose(self, op): """transpose implementation.""" input_tensors = self.get_input_tensors(op) @@ -2357,6 +2371,20 @@ def get_expr(self, input_tensor_idx): def has_expr(self, input_tensor_idx): return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx)) + def get_tensor_expr(self, tensor): + """ Returns constant expr for constant else a tensor expr""" + if self.has_expr(tensor.tensor_idx): + # In most cases, we can assume that TOCO fuses elemwise operators + # with constants - it means both will be tensors. + expr = self.get_expr(tensor.tensor_idx) + else: + # However, in some corner cases, the elemwise operator is not fused, + # we can receive as constant. + type_str = self.get_tensor_type_str(tensor.tensor.Type()) + expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str) + + return expr + def get_scalar_from_constant(expr): """ Returns scalar value from Relay constant scalar. """ diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 283d87d5078a..725774895723 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1378,6 +1378,27 @@ def test_all_reduce(): ####################################################################### +# Select, Where +# ------------- + +def test_forward_select(): + with tf.Graph().as_default(): + with tf.Session() as sess: + input1 = tf.placeholder( + tf.int32, shape=[1, 4, 4, 3], name='input1') + input2 = tf.placeholder( + tf.int32, shape=[1, 4, 4, 3], name='input2') + mask = input1 > input2 + out = tf.where(mask, input1 + 1, input2 * 2) + in_data1 = np.random.uniform( + 0, 10, size=(1, 4, 4, 3)).astype("int32") + in_data2 = np.random.uniform( + 0, 10, size=(1, 4, 4, 3)).astype("int32") + + compare_tflite_with_tvm([in_data1, in_data2], [ + 'input1:0', 'input2:0'], [input1, input2], [out]) + + # Squeeze # ------- @@ -1997,6 +2018,7 @@ def test_forward_mediapipe_hand_landmark(): test_forward_stridedslice() test_forward_depthtospace() test_forward_spacetodepth() + test_forward_select() # NN test_forward_convolution()