diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 352bc6302ee0..e132d4ca3585 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -121,7 +121,8 @@ def __init__(self, model, subgraph, exp_tab): 'SQUARED_DIFFERENCE': self.convert_squared_difference, 'LOGICAL_AND': self.convert_logical_and, 'LOGICAL_OR': self.convert_logical_or, - 'DETECTION_POSTPROCESS': self.convert_detection_postprocess + 'DETECTION_POSTPROCESS': self.convert_detection_postprocess, + 'SQUARE': self.convert_square, } def check_unsupported_ops(self): @@ -636,6 +637,32 @@ def convert_neg(self, op): 'TFlite quantized NEG operator is not supported yet.') return self._convert_unary_elemwise(_op.negative, op) + def convert_square(self, op): + """Convert TFLite SQUARE""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "output tensors length should be 1" + output_tensor = output_tensors[0] + + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized SQUARE operator is not supported yet.') + + exp_type = self.get_tensor_type_str(output_tensor.tensor.Type()) + out = _op.power(in_expr, relay.const(2, exp_type)) + + return out + def _convert_elemwise(self, relay_op, op): """Generic method to Convert TFLite elemwise""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 427d4bfe2810..f4b7ee0cd8b1 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -742,6 +742,12 @@ def _test_neg(data): """ One iteration of neg """ return _test_unary_elemwise(math_ops.neg, data) ####################################################################### +# Square +# ------ + +def _test_square(data): + """ One iteration of square """ + return _test_unary_elemwise(math_ops.square, data) def _test_forward_unary_elemwise(test_op): # functions that need positive input @@ -759,6 +765,7 @@ def test_all_unary_elemwise(): _test_forward_unary_elemwise(_test_sqrt) _test_forward_unary_elemwise(_test_rsqrt) _test_forward_unary_elemwise(_test_neg) + _test_forward_unary_elemwise(_test_square) # ceil and cos come with TFLite 1.14.0.post1 fbs schema if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): _test_forward_unary_elemwise(_test_ceil)