diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 981074b6adb2..2a9d66acff07 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -1448,11 +1448,7 @@ def convert_less_equal(self, op): def convert_equal(self, op): """Convert TFLite EQUAL""" - if self.is_quantized(op): - raise tvm.error.OpNotImplemented( - "TFlite quantized EQUAL operator is not supported yet." - ) - return self._convert_elemwise(_op.equal, op) + return self._convert_elemwise(_op.equal, op, self.is_quantized(op)) def convert_not_equal(self, op): """Convert TFLite NOT_EQUAL""" diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 76b0766dae28..23b5a03ffb5f 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -2214,22 +2214,33 @@ def __test_elemwise(in_data): if None != x[0] } - out = math_op(inq_data[0], inq_data[1]) - out = with_fused_activation_function(out, fused_activation_function) - out = tf.quantization.fake_quant_with_min_max_args( - out, min=out_min, max=out_max, name="out" - ) + if math_op is math_ops.equal: + out = math_op(inq_data[0], inq_data[1]) + out = with_fused_activation_function(out, fused_activation_function) - # Note same_qnn_params uses experimental_new_converter as toco failed - compare_tflite_with_tvm( - [x[1] for x in zip(in_data, data) if None != x[0]], - [x + ":0" for x in input_range.keys()], - [x[1] for x in zip(in_data, inq_data) if None != x[0]], - [out], - quantized=True, - input_range=input_range, - experimental_new_converter=same_qnn_params, - ) + compare_tflite_with_tvm( + [x[1] for x in zip(in_data, data) if None != x[0]], + [x + ":0" for x in input_range.keys()], + [x[1] for x in zip(in_data, inq_data) if None != x[0]], + [out], + ) + else: + out = math_op(inq_data[0], inq_data[1]) + out = with_fused_activation_function(out, fused_activation_function) + out = tf.quantization.fake_quant_with_min_max_args( + out, min=out_min, max=out_max, name="out" + ) + + # Note same_qnn_params uses experimental_new_converter as toco failed + compare_tflite_with_tvm( + [x[1] for x in zip(in_data, data) if None != x[0]], + [x + ":0" for x in input_range.keys()], + [x[1] for x in zip(in_data, inq_data) if None != x[0]], + [out], + quantized=True, + input_range=input_range, + experimental_new_converter=same_qnn_params, + ) else: out = math_op( in_data[0] @@ -2386,9 +2397,16 @@ def _test_less_equal(data): # ----- -def _test_equal(data): +def _test_equal(data, fused_activation_function=None, quantized=False, qnn_op=None): """One iteration of equal""" - return _test_elemwise(math_ops.equal, data) + return _test_elemwise( + math_ops.equal, + data, + fused_activation_function, + quantized, + qnn_op, + same_qnn_params=True, + ) ####################################################################### @@ -2454,14 +2472,25 @@ def _test_forward_elemwise(testop): def _test_forward_elemwise_quantized(testop): - testop( - [ - np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8), - np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8), - ], - quantized=True, - qnn_op=testop, - ) + if testop is not _test_equal: + testop( + [ + np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8), + np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8), + ], + quantized=True, + qnn_op=testop, + ) + else: + # no need for fake_quant to hold tensors in float32 until conversion + testop( + [ + np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.float32), + np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.float32), + ], + quantized=True, + qnn_op=testop, + ) def _test_elemwise_qnn_out_range(qnn_op): @@ -2472,6 +2501,7 @@ def _test_elemwise_qnn_out_range(qnn_op): _test_mul: (-5e3, 5e3), _test_maximum: (-112, 111), _test_minimum: (-128, 127), + _test_equal: (-150, 150), } return qnn_out_range[qnn_op] @@ -2506,6 +2536,7 @@ def test_all_elemwise(): _test_forward_elemwise(_test_less) _test_forward_elemwise(_test_less_equal) _test_forward_elemwise(_test_equal) + _test_forward_elemwise_quantized(_test_equal) _test_forward_elemwise(_test_not_equal) if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"): _test_forward_elemwise(_test_floor_divide)