diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 5902b92c3f56..791c056c4a3d 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -89,6 +89,11 @@ def __init__(self, model, subgraph, exp_tab): 'MAXIMUM': self.convert_maximum, 'MINIMUM': self.convert_minimum, 'GREATER': self.convert_greater, + 'GREATER_EQUAL': self.convert_greater_equal, + 'LESS': self.convert_less, + 'LESS_EQUAL': self.convert_less_equal, + 'EQUAL': self.convert_equal, + 'NOT_EQUAL': self.convert_not_equal, 'ZEROS_LIKE': self.convert_zeros_like, 'REDUCE_MIN': self._convert_reduce_min, 'REDUCE_MAX': self._convert_reduce_max, @@ -690,7 +695,7 @@ def convert_sub(self, op): # Check if the input tensor is quantized, call QNN op if self.is_quantized(op): raise tvm.error.OpNotImplemented( - 'TFlite quantized sub operator is not supported yet.') + 'TFlite quantized SUB operator is not supported yet.') return self._convert_elemwise(_op.subtract, op) def convert_mul(self, op): @@ -705,38 +710,43 @@ def convert_div(self, op): # Check if the input tensor is quantized, call QNN op if self.is_quantized(op): raise tvm.error.OpNotImplemented( - 'TFlite quantized div operator is not supported yet.') + 'TFlite quantized DIV operator is not supported yet.') return self._convert_elemwise(_op.divide, op) def convert_pow(self, op): + """Convert TFLite POW""" # Check if the input tensor is quantized, call QNN op if self.is_quantized(op): raise tvm.error.OpNotImplemented( - 'TFlite quantized pow operator is not supported yet.') + 'TFlite quantized POW operator is not supported yet.') return self._convert_elemwise(_op.power, op) def convert_maximum(self, op): + """Convert TFLite MAXIMUM""" # Check if the input tensor is quantized, call QNN op if self.is_quantized(op): raise tvm.error.OpNotImplemented( - 'TFlite quantized maximum operator is not supported yet.') + 'TFlite quantized MAXIMUM operator is not supported yet.') return self._convert_elemwise(_op.maximum, op) def convert_minimum(self, op): + """Convert TFLite MINIMUM""" # Check if the input tensor is quantized, call QNN op if self.is_quantized(op): raise tvm.error.OpNotImplemented( - 'TFlite quantized minimum operator is not supported yet.') + 'TFlite quantized MINIMUM operator is not supported yet.') return self._convert_elemwise(_op.minimum, op) def convert_greater(self, op): + """Convert TFLite GREATER""" # Check if the input tensor is quantized, call QNN op if self.is_quantized(op): raise tvm.error.OpNotImplemented( - 'TFlite quantized greater operator is not supported yet.') + 'TFlite quantized GREATER operator is not supported yet.') return self._convert_elemwise(_op.greater, op) def convert_squared_difference(self, op): + """Convert TFLite SQUARED DIFFERENCE""" # Check if the input tensor is quantized, call QNN op if self.is_quantized(op): raise tvm.error.OpNotImplemented( @@ -747,6 +757,41 @@ def convert_squared_difference(self, op): out = _op.power(difference, relay.const(2, exp_type)) return out + def convert_greater_equal(self, op): + """Convert TFLite GREATER_EQUAL""" + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized GREATER_EQUAL operator is not supported yet.') + return self._convert_elemwise(_op.greater_equal, op) + + def convert_less(self, op): + """Convert TFLite LESS""" + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized LESS operator is not supported yet.') + return self._convert_elemwise(_op.less, op) + + def convert_less_equal(self, op): + """Convert TFLite LESS_EQUAL""" + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized LESS_EQUAL operator is not supported yet.') + return self._convert_elemwise(_op.less_equal, op) + + def convert_equal(self, op): + """Convert TFLite EQUAL""" + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized EQUAL operator is not supported yet.') + return self._convert_elemwise(_op.equal, op) + + def convert_not_equal(self, op): + """Convert TFLite NOT_EQUAL""" + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFlite quantized NOT_EQUAL operator is not supported yet.') + return self._convert_elemwise(_op.not_equal, op) + def convert_zeros_like(self, op): """Convert TFLite ZEROS LIKE""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index b7550f40af1e..9835bfcb46bf 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -863,7 +863,41 @@ def _test_minimum(data): def _test_greater(data): """ One iteration of greater """ return _test_elemwise(math_ops.greater, data) +####################################################################### +# Greater_equal +# ------------- + +def _test_greater_equal(data): + """ One iteration of greater_equal """ + return _test_elemwise(math_ops.greater_equal, data) +####################################################################### +# Less +# ---- + +def _test_less(data): + """ One iteration of less """ + return _test_elemwise(math_ops.less, data) +####################################################################### +# Less_equal +# ---------- + +def _test_less_equal(data): + """ One iteration of less_equal """ + return _test_elemwise(math_ops.less_equal, data) +####################################################################### +# Equal +# ----- + +def _test_equal(data): + """ One iteration of equal """ + return _test_elemwise(math_ops.equal, data) +####################################################################### +# Not_equal +# --------- +def _test_not_equal(data): + """ One iteration of not_equal""" + return _test_elemwise(math_ops.not_equal, data) ####################################################################### # Squared_difference # ------------------ @@ -915,6 +949,11 @@ def test_all_elemwise(): _test_forward_elemwise(_test_minimum) _test_forward_elemwise(_test_greater) _test_forward_elemwise(_test_squared_difference) + _test_forward_elemwise(_test_greater_equal) + _test_forward_elemwise(_test_less) + _test_forward_elemwise(_test_less_equal) + _test_forward_elemwise(_test_equal) + _test_forward_elemwise(_test_not_equal) ####################################################################### # Zeros like