From 79ce87f8da5fd5bcfd107e2def243eb02f3819cc Mon Sep 17 00:00:00 2001 From: Ina Dobreva <55383260+inadob@users.noreply.github.com> Date: Wed, 5 Feb 2020 20:12:44 +0000 Subject: [PATCH] [Relay][Frontend][TFLite] Add parser support for logical operators (#4642) * [Relay][Frontend][TFLite] Add parser support for logical operators * Add parser support for logical_and, logical_or * Add boolean dtype as a valid tensor type * BOOLEAN dtype is supported only from tf 1.15 so logical ops work only in that and newer versions * Logical_not is ommited since tflite can't convert it --> throws errors for addv2 * Add TFLite vesion check in tests for logical ops * Check is added because of boolean dtype lack of support --- python/tvm/relay/frontend/tflite.py | 34 ++++++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 31 ++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 791c056c4a3d..7e4c37ad8235 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -117,6 +117,8 @@ def __init__(self, model, subgraph, exp_tab): 'PRELU': self.convert_prelu, 'TRANSPOSE_CONV': self.convert_transpose_conv, 'SQUARED_DIFFERENCE': self.convert_squared_difference, + 'LOGICAL_AND': self.convert_logical_and, + 'LOGICAL_OR': self.convert_logical_or, } def check_unsupported_ops(self): @@ -222,6 +224,9 @@ def get_tensor_value(self, tensor_wrapper): if tensor_wrapper.tensor.Type() == TensorType.INT64: return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int64).reshape( tensor_wrapper.tensor.ShapeAsNumpy()) + if tensor_wrapper.tensor.Type() == TensorType.BOOL: + return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.bool_).reshape( + tensor_wrapper.tensor.ShapeAsNumpy()) raise NotImplementedError("Tensor type {} is currently not supported" .format(str(tensor_wrapper.tensor.Type()))) @@ -240,6 +245,8 @@ def get_tensor_type_str(self, tensor_type): return "int32" if tensor_type == TensorType.INT64: return "int64" + if tensor_type == TensorType.BOOL: + return "bool" raise NotImplementedError("Tensor type {} is currently not supported" .format(str(tensor_type))) @@ -792,6 +799,33 @@ def convert_not_equal(self, op): 'TFlite quantized NOT_EQUAL operator is not supported yet.') return self._convert_elemwise(_op.not_equal, op) + def _convert_logical_binary(self, relay_op, op): + """Generic method to convert logical binary ops""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + lhs_tensor = input_tensors[0] + lhs_expr = self.get_expr(lhs_tensor.tensor_idx) + rhs_tensor = input_tensors[1] + rhs_expr = self.get_expr(rhs_tensor.tensor_idx) + out = relay_op(lhs_expr, rhs_expr) + + return out + + def convert_logical_and(self, op): + """Convert tflite LOGICAL_AND""" + return self._convert_logical_binary(_op.logical_and, op) + + def convert_logical_or(self, op): + """Convert tflite LOGICAL_OR""" + return self._convert_logical_binary(_op.logical_or, op) + def convert_zeros_like(self, op): """Convert TFLite ZEROS LIKE""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index aa29cf587bb2..acc25d968f43 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -965,6 +965,34 @@ def test_all_elemwise(): _test_forward_elemwise(_test_equal) _test_forward_elemwise(_test_not_equal) +####################################################################### +# Logical operators +# ----------------- + +def _test_logical_binary(logical_bin_op, data): + + with tf.Graph().as_default(): + in_data = [array_ops.placeholder(shape=data[0].shape, dtype='bool', name='in_0'), + array_ops.placeholder(shape=data[1].shape, dtype='bool', name='in_1')] + out = logical_bin_op(in_data[0], in_data[1], name='out') + compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out]) + +def _test_forward_logical_and(data): + """ One iteration of logical and """ + return _test_logical_binary(math_ops.logical_and, data) + +def _test_forward_logical_or(data): + """ One iteration of logical or """ + return _test_logical_binary(math_ops.logical_or, data) + +def test_all_logical(): + data = [np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool'), + np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool')] + # boolean dtype is not supported by older versions than TFLite 1.15.0 + if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'): + _test_forward_logical_and(data) + _test_forward_logical_or(data) + ####################################################################### # Zeros like # -------- @@ -1530,6 +1558,9 @@ def test_forward_mediapipe_hand_landmark(): # Reduce test_all_reduce() + # Logical + test_all_logical() + # End to End test_forward_mobilenet_v1() test_forward_mobilenet_v2()