Skip to content

Commit

Permalink
[TFLite] Support quantized EQUAL op in TFLite frontend (apache#11520)
Browse files Browse the repository at this point in the history
* [TFLite] Support quantized EQUAL op in TFLite frontend

Support EQUAL quantization operation conversion as part of issue apache#9187

* [TFLite] Support quantized EQUAL op in TFLite frontend

Update elementwise quantized test for EQUAL op
Change-Id: I3897d1ac07051ebfc10356ad45397117b592f878
  • Loading branch information
dchauhan-arm authored and blackkker committed Jul 7, 2022
1 parent 1464a31 commit f18854b
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 30 deletions.
6 changes: 1 addition & 5 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,11 +1448,7 @@ def convert_less_equal(self, op):

def convert_equal(self, op):
"""Convert TFLite EQUAL"""
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
"TFlite quantized EQUAL operator is not supported yet."
)
return self._convert_elemwise(_op.equal, op)
return self._convert_elemwise(_op.equal, op, self.is_quantized(op))

def convert_not_equal(self, op):
"""Convert TFLite NOT_EQUAL"""
Expand Down
81 changes: 56 additions & 25 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2214,22 +2214,33 @@ def __test_elemwise(in_data):
if None != x[0]
}

out = math_op(inq_data[0], inq_data[1])
out = with_fused_activation_function(out, fused_activation_function)
out = tf.quantization.fake_quant_with_min_max_args(
out, min=out_min, max=out_max, name="out"
)
if math_op is math_ops.equal:
out = math_op(inq_data[0], inq_data[1])
out = with_fused_activation_function(out, fused_activation_function)

# Note same_qnn_params uses experimental_new_converter as toco failed
compare_tflite_with_tvm(
[x[1] for x in zip(in_data, data) if None != x[0]],
[x + ":0" for x in input_range.keys()],
[x[1] for x in zip(in_data, inq_data) if None != x[0]],
[out],
quantized=True,
input_range=input_range,
experimental_new_converter=same_qnn_params,
)
compare_tflite_with_tvm(
[x[1] for x in zip(in_data, data) if None != x[0]],
[x + ":0" for x in input_range.keys()],
[x[1] for x in zip(in_data, inq_data) if None != x[0]],
[out],
)
else:
out = math_op(inq_data[0], inq_data[1])
out = with_fused_activation_function(out, fused_activation_function)
out = tf.quantization.fake_quant_with_min_max_args(
out, min=out_min, max=out_max, name="out"
)

# Note same_qnn_params uses experimental_new_converter as toco failed
compare_tflite_with_tvm(
[x[1] for x in zip(in_data, data) if None != x[0]],
[x + ":0" for x in input_range.keys()],
[x[1] for x in zip(in_data, inq_data) if None != x[0]],
[out],
quantized=True,
input_range=input_range,
experimental_new_converter=same_qnn_params,
)
else:
out = math_op(
in_data[0]
Expand Down Expand Up @@ -2386,9 +2397,16 @@ def _test_less_equal(data):
# -----


def _test_equal(data):
def _test_equal(data, fused_activation_function=None, quantized=False, qnn_op=None):
"""One iteration of equal"""
return _test_elemwise(math_ops.equal, data)
return _test_elemwise(
math_ops.equal,
data,
fused_activation_function,
quantized,
qnn_op,
same_qnn_params=True,
)


#######################################################################
Expand Down Expand Up @@ -2454,14 +2472,25 @@ def _test_forward_elemwise(testop):


def _test_forward_elemwise_quantized(testop):
testop(
[
np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
],
quantized=True,
qnn_op=testop,
)
if testop is not _test_equal:
testop(
[
np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
],
quantized=True,
qnn_op=testop,
)
else:
# no need for fake_quant to hold tensors in float32 until conversion
testop(
[
np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.float32),
np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.float32),
],
quantized=True,
qnn_op=testop,
)


def _test_elemwise_qnn_out_range(qnn_op):
Expand All @@ -2472,6 +2501,7 @@ def _test_elemwise_qnn_out_range(qnn_op):
_test_mul: (-5e3, 5e3),
_test_maximum: (-112, 111),
_test_minimum: (-128, 127),
_test_equal: (-150, 150),
}

return qnn_out_range[qnn_op]
Expand Down Expand Up @@ -2506,6 +2536,7 @@ def test_all_elemwise():
_test_forward_elemwise(_test_less)
_test_forward_elemwise(_test_less_equal)
_test_forward_elemwise(_test_equal)
_test_forward_elemwise_quantized(_test_equal)
_test_forward_elemwise(_test_not_equal)
if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
_test_forward_elemwise(_test_floor_divide)
Expand Down

0 comments on commit f18854b

Please sign in to comment.