From b0934bc0e656bf53610d861f0520ad896b6303cd Mon Sep 17 00:00:00 2001 From: wyc-ruiker Date: Wed, 8 Jan 2020 13:38:03 +0800 Subject: [PATCH] [Relay][Frontend][TFLite] Add parser support for squared difference --- python/tvm/relay/frontend/tflite.py | 10 ++++++++++ tests/python/frontend/tflite/test_forward.py | 9 +++++++++ 2 files changed, 19 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 284a8c8cd8582..41e221c3a680a 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -100,6 +100,7 @@ def __init__(self, model, subgraph, exp_tab): 'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd, 'PRELU': self.convert_prelu, 'TRANSPOSE_CONV': self.convert_transpose_conv, + 'SQUARED_DIFFERENCE': self.convert_squared_difference, } def check_unsupported_ops(self): @@ -628,6 +629,15 @@ def convert_greater(self, op): 'TFlite quantized greater operator is not supported yet.') return self._convert_elemwise(_op.greater, op) + def convert_squared_difference(self, op): + # Check if the input tensor is quantized, call QNN op + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized greater operator is not supported yet.') + difference = self._convert_elemwise(_op.subtract, op) + out = _op.multiply(difference, difference) + return out + def convert_zeros_like(self, op): """Convert TFLite ZEROS LIKE""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index fd43ff3f42290..bc0cf47f91ffe 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -723,6 +723,14 @@ def _test_greater(data): """ One iteration of greater """ return _test_elemwise(math_ops.greater, data) +####################################################################### +# Squared_difference +# ------- + +def _test_squared_difference(data): + """ One iteration of squared difference """ + return _test_elemwise(math_ops.squared_difference, data) + def _test_forward_elemwise(testop): """ Elewise""" testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)), @@ -765,6 +773,7 @@ def test_all_elemwise(): _test_forward_elemwise(_test_maximum) _test_forward_elemwise(_test_minimum) _test_forward_elemwise(_test_greater) + _test_forward_elemwise(_test_squared_difference) ####################################################################### # Zeros like