From 5902568705ca1e4cb0990fa28a97b739f595b2f0 Mon Sep 17 00:00:00 2001 From: Samuel Date: Thu, 30 Apr 2020 13:37:49 +0530 Subject: [PATCH] [FRONTEND][TFLITE]Logical not op support (#5475) --- python/tvm/relay/frontend/tflite.py | 11 +++++++++++ tests/python/frontend/tflite/test_forward.py | 12 +++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 66d0ff326ce0..5c8bbfb3c8f9 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -94,6 +94,7 @@ def __init__(self, model, subgraph, exp_tab): 'LOCAL_RESPONSE_NORMALIZATION': self.convert_lrn, 'LOG': self.convert_log, 'LOGICAL_AND': self.convert_logical_and, + 'LOGICAL_NOT': self.convert_logical_not, 'LOGICAL_OR': self.convert_logical_or, 'LOGISTIC': self.convert_logistic, 'MAX_POOL_2D': self.convert_max_pool2d, @@ -992,6 +993,16 @@ def convert_logical_or(self, op): """Convert tflite LOGICAL_OR""" return self._convert_logical_binary(_op.logical_or, op) + def convert_logical_not(self, op): + """Convert tflite LOGICAL_NOT""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + + data = self.get_expr(input_tensors[0].tensor_idx) + out = _op.logical_not(data) + + return out + def convert_gather(self, op): """Method to Convert TFLite GATHER operator""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index ae2a4c6b5f3d..75146c3cc74a 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1183,7 +1183,12 @@ def _test_logical_binary(logical_bin_op, data): with tf.Graph().as_default(): in_data = [array_ops.placeholder(shape=data[0].shape, dtype='bool', name='in_0'), array_ops.placeholder(shape=data[1].shape, dtype='bool', name='in_1')] - out = logical_bin_op(in_data[0], in_data[1], name='out') + if logical_bin_op == math_ops.logical_not: + out = math_ops.logical_or(in_data[0], in_data[1], name='out1') + out = logical_bin_op(out, name='out') + else: + out = logical_bin_op(in_data[0], in_data[1], name='out') + compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out]) def _test_forward_logical_and(data): @@ -1194,6 +1199,10 @@ def _test_forward_logical_or(data): """ One iteration of logical or """ return _test_logical_binary(math_ops.logical_or, data) +def _test_forward_logical_not(data): + """ One iteration of logical not """ + return _test_logical_binary(math_ops.logical_not, data) + def test_all_logical(): data = [np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool'), np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool')] @@ -1201,6 +1210,7 @@ def test_all_logical(): if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'): _test_forward_logical_and(data) _test_forward_logical_or(data) + _test_forward_logical_not(data) ####################################################################### # Zeros like