diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index be18bf622196..2025a1f9c4c0 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -59,6 +59,16 @@ def __init__(self, model, subgraph, exp_tab): # Add more operators self.convert_map = { + 'ABS': self.convert_abs, + 'EXP': self.convert_exp, + 'FLOOR': self.convert_floor, + 'CEIL': self.convert_ceil, + 'LOG': self.convert_log, + 'SIN': self.convert_sin, + 'COS': self.convert_cos, + 'SQRT': self.convert_sqrt, + 'RSQRT': self.convert_rsqrt, + 'NEG': self.convert_neg, 'CONV_2D': self.convert_conv2d, 'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d, 'AVERAGE_POOL_2D': self.convert_average_pool2d, @@ -483,6 +493,93 @@ def convert_concatenation(self, op): .format('qnn.op.concatenate')) return out + def _convert_unary_elemwise(self, relay_op, op): + """Generic method to convert TFLite unary elemwise functions""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + out = relay_op(in_expr) + + return out + + def convert_abs(self, op): + """Convert TFLite ABS""" + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized ABS operator is not supported yet.') + return self._convert_unary_elemwise(_op.abs, op) + + def convert_ceil(self, op): + """Convert TFLite CEIL""" + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized CEIL operator is not supported yet.') + return self._convert_unary_elemwise(_op.ceil, op) + + def convert_floor(self, op): + """Convert TFLite FLOOR""" + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized FLOOR operator is not supported yet.') + return self._convert_unary_elemwise(_op.floor, op) + + def convert_exp(self, op): + """Convert TFLite EXP""" + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized EXP operator is not supported yet.') + return self._convert_unary_elemwise(_op.exp, op) + + def convert_log(self, op): + """Convert TFLite LOG""" + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized LOG operator is not supported yet.') + return self._convert_unary_elemwise(_op.log, op) + + def convert_sin(self, op): + """Convert TFLite SIN""" + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized SIN operator is not supported yet.') + return self._convert_unary_elemwise(_op.sin, op) + + def convert_cos(self, op): + """Convert TFLite COS""" + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized COS operator is not supported yet.') + return self._convert_unary_elemwise(_op.cos, op) + + def convert_sqrt(self, op): + """Convert TFLite SQRT""" + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized SQRT operator is not supported yet.') + return self._convert_unary_elemwise(_op.sqrt, op) + + def convert_rsqrt(self, op): + """Convert TFLite RSQRT""" + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized RSQRT operator is not supported yet.') + return self._convert_unary_elemwise(_op.rsqrt, op) + + def convert_neg(self, op): + """Convert TFLite NEG""" + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized NEG operator is not supported yet.') + return self._convert_unary_elemwise(_op.negative, op) + def _convert_elemwise(self, relay_op, op): """Generic method to Convert TFLite elemwise""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 12ea429983e8..c09740ae1b66 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -616,6 +616,109 @@ def test_forward_concatenation(): np.arange(6).reshape((2, 1, 1, 3)), np.arange(6).reshape((2, 1, 1, 3))], 1) +####################################################################### +# Unary elemwise +# -------------- + +def _test_unary_elemwise(math_op, data): + """ One iteration of unary elemwise """ + + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name='in') + out = math_op(in_data) + compare_tflite_with_tvm(data, ['in:0'], in_data, [out]) + +####################################################################### +# Abs +# --- + +def _test_abs(data): + """ One iteration of abs """ + return _test_unary_elemwise(math_ops.abs, data) +####################################################################### +# Ceil +# ---- + +def _test_ceil(data): + """ One iteration of ceil """ + return _test_unary_elemwise(math_ops.ceil, data) +####################################################################### +# Floor +# ----- + +def _test_floor(data): + """ One iteration of floor """ + return _test_unary_elemwise(math_ops.floor, data) +####################################################################### +# Exp +# --- + +def _test_exp(data): + """ One iteration of exp """ + return _test_unary_elemwise(math_ops.exp, data) +####################################################################### +# Log +# --- + +def _test_log(data): + """ One iteration of log """ + return _test_unary_elemwise(math_ops.log, data) +####################################################################### +# Sin +# --- + +def _test_sin(data): + """ One iteration of sin """ + return _test_unary_elemwise(math_ops.sin, data) +####################################################################### +# Cos +# --- + +def _test_cos(data): + """ One iteration of cos """ + return _test_unary_elemwise(math_ops.cos, data) +####################################################################### +# Sqrt +# ---- + +def _test_sqrt(data): + """ One iteration of sqrt """ + return _test_unary_elemwise(math_ops.sqrt, data) +####################################################################### +# Rsqrt +# ----- + +def _test_rsqrt(data): + """ One iteration of rsqrt """ + return _test_unary_elemwise(math_ops.rsqrt, data) +####################################################################### +# Neg +# --- + +def _test_neg(data): + """ One iteration of neg """ + return _test_unary_elemwise(math_ops.neg, data) +####################################################################### + +def _test_forward_unary_elemwise(test_op): + # functions that need positive input + if test_op in {'_test_log', '_test_sqrt', '_test_rsqrt'}: + test_op(np.arange(6.0, dtype=np.float32).reshape((2, 1, 3))) + test_op(np.arange(6.0, dtype=np.int32).reshape((2, 1, 3))) + else: + np.array(np.random.uniform(-5, 5, (3, 1)), dtype=np.int32) + +def test_all_unary_elemwise(): + _test_forward_unary_elemwise(_test_abs) + _test_forward_unary_elemwise(_test_ceil) + _test_forward_unary_elemwise(_test_floor) + _test_forward_unary_elemwise(_test_exp) + _test_forward_unary_elemwise(_test_log) + _test_forward_unary_elemwise(_test_sin) + _test_forward_unary_elemwise(_test_cos) + _test_forward_unary_elemwise(_test_sqrt) + _test_forward_unary_elemwise(_test_rsqrt) + _test_forward_unary_elemwise(_test_neg) ####################################################################### # Element-wise @@ -1320,6 +1423,9 @@ def test_forward_mediapipe_hand_landmark(): # Elemwise test_all_elemwise() + # Unary elemwise + test_all_unary_elemwise() + # Zeros Like test_forward_zeros_like()