From b2e35abe2cff299373601155d827b3b01e8578e1 Mon Sep 17 00:00:00 2001 From: Wang Yucheng Date: Thu, 16 Jan 2020 23:33:11 +0800 Subject: [PATCH] [Relay][Frontend][TFLite] Add parser support for squared difference (#4652) * [Relay][Frontend][TFLite] Add parser support for squared difference * fix some error * fix exp_type * add comment --- python/tvm/relay/frontend/tflite.py | 12 ++++++++++++ tests/python/frontend/tflite/test_forward.py | 9 +++++++++ 2 files changed, 21 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 02b8ed980c0b..5902b92c3f56 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -111,6 +111,7 @@ def __init__(self, model, subgraph, exp_tab): 'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd, 'PRELU': self.convert_prelu, 'TRANSPOSE_CONV': self.convert_transpose_conv, + 'SQUARED_DIFFERENCE': self.convert_squared_difference, } def check_unsupported_ops(self): @@ -735,6 +736,17 @@ def convert_greater(self, op): 'TFlite quantized greater operator is not supported yet.') return self._convert_elemwise(_op.greater, op) + def convert_squared_difference(self, op): + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized squared difference operator is not supported yet.') + difference = self._convert_elemwise(_op.subtract, op) + # _convert_elemwise has guaranteed only have one output tensor + exp_type = self.get_tensor_type_str(self.get_output_tensors(op)[0].tensor.Type()) + out = _op.power(difference, relay.const(2, exp_type)) + return out + def convert_zeros_like(self, op): """Convert TFLite ZEROS LIKE""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 837f0f611018..b7550f40af1e 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -864,6 +864,14 @@ def _test_greater(data): """ One iteration of greater """ return _test_elemwise(math_ops.greater, data) +####################################################################### +# Squared_difference +# ------------------ + +def _test_squared_difference(data): + """ One iteration of squared difference """ + return _test_elemwise(math_ops.squared_difference, data) + def _test_forward_elemwise(testop): """ Elewise""" testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)), @@ -906,6 +914,7 @@ def test_all_elemwise(): _test_forward_elemwise(_test_maximum) _test_forward_elemwise(_test_minimum) _test_forward_elemwise(_test_greater) + _test_forward_elemwise(_test_squared_difference) ####################################################################### # Zeros like