Added support for tflite quantized maximum and minimum (apache#6018)
* Added support for tflite quantized maximum and minimum

* Unit test simplified

Bugfix in unit test. Unit test slightly simplified

* re-trigger CI

* renamed use_real_qnn to ignore_qnn_params
d-smirnov authored and Trevor Morris committed Aug 26, 2020
1 parent 4a03e1d commit 4e9fe27
Showing 2 changed files with 86 additions and 74 deletions.
42 changes: 23 additions & 19 deletions python/tvm/relay/frontend/tflite.py
@@ -363,12 +363,16 @@ def has_same_qnn_params(self, lhs_tensor, rhs_tensor):
rhs_scale = rhs_tensor.qnn_params['scale']
lhs_zero_point = lhs_tensor.qnn_params['zero_point']
rhs_zero_point = rhs_tensor.qnn_params['zero_point']
lhs_scale_value = get_scalar_from_constant(lhs_scale)
rhs_scale_value = get_scalar_from_constant(rhs_scale)
lhs_zero_point_value = get_scalar_from_constant(lhs_zero_point)
rhs_zero_point_value = get_scalar_from_constant(rhs_zero_point)
return lhs_scale_value == rhs_scale_value and \
lhs_zero_point_value == rhs_zero_point_value
# 0.1 + 0.2 != 0.3
return np.allclose(lhs_scale.data.asnumpy(),
rhs_scale.data.asnumpy(),
rtol=1e-5,
atol=1e-5) \
and \
np.allclose(lhs_zero_point.data.asnumpy(),
rhs_zero_point.data.asnumpy(),
rtol=1e-5,
atol=1e-5)
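
For context on the tolerance-based comparison above: exact floating-point equality can reject scale values that are equal up to rounding error, which is what the "# 0.1 + 0.2 != 0.3" comment alludes to. A minimal standalone illustration, not part of the commit:

import numpy as np

lhs_scale = np.array(0.1 + 0.2)   # 0.30000000000000004 after rounding
rhs_scale = np.array(0.3)

print(lhs_scale == rhs_scale)                                   # False: exact equality is too strict
print(np.allclose(lhs_scale, rhs_scale, rtol=1e-5, atol=1e-5))  # True: same tolerance as has_same_qnn_params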

def is_quantized(self, op):
"""Check if an input tensor is quantized."""
@@ -1109,7 +1113,7 @@ def convert_square(self, op):

return out

def _convert_elemwise(self, relay_op, op):
def _convert_elemwise(self, relay_op, op, ignore_qnn_params=False):
"""Generic method to Convert TFLite elemwise"""
try:
from tflite.AddOptions import AddOptions
@@ -1132,8 +1136,16 @@ def _convert_elemwise(self, relay_op, op):
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]

# TFLite requires the same (scale, zero_point) parameters on all tensors for some operations;
# this lets us use the non-quantized operator instead of its QNN counterpart when ignore_qnn_params=True
if ignore_qnn_params:
assert lhs_tensor.qnn_params \
and self.has_same_qnn_params(lhs_tensor, output_tensor) \
and self.has_same_qnn_params(rhs_tensor, output_tensor), \
"All tensors should be quantized with the same (scale,zero-point) tuple parameters"

# If quantized, extracts qnn params and call QNN add operator.
if lhs_tensor.qnn_params:
if not ignore_qnn_params and lhs_tensor.qnn_params:
assert rhs_tensor.qnn_params, "Both tensors should be quantized."
assert output_tensor.qnn_params, "Output tensor should be quantized."
out = relay_op(lhs=lhs_expr,
@@ -1164,7 +1176,7 @@ def _convert_elemwise(self, relay_op, op):
fused_activation_fn = options.FusedActivationFunction()

# Handle fused activations
if output_tensor.qnn_params:
if not ignore_qnn_params and output_tensor.qnn_params:
scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point'])
output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
@@ -1231,19 +1243,11 @@ def convert_pow(self, op):

def convert_maximum(self, op):
"""Convert TFLite MAXIMUM"""
# Check if the input tensor is quantized, call QNN op
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized MAXIMUM operator is not supported yet.')
return self._convert_elemwise(_op.maximum, op)
return self._convert_elemwise(_op.maximum, op, self.is_quantized(op))

def convert_minimum(self, op):
"""Convert TFLite MINIMUM"""
# Check if the input tensor is quantized, call QNN op
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized MINIMUM operator is not supported yet.')
return self._convert_elemwise(_op.minimum, op)
return self._convert_elemwise(_op.minimum, op, self.is_quantized(op))

def convert_greater(self, op):
"""Convert TFLite GREATER"""
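A note on why falling back to the plain relay op is sound for MAXIMUM/MINIMUM when all tensors share one (scale, zero_point) pair: dequantization q ↦ scale * (q - zero_point) is a monotonically increasing map, so taking max/min on the raw integer values gives the same result as taking it in the dequantized float domain. A small numpy check of that property, illustrative only and not part of the diff:

import numpy as np

scale, zero_point = 0.5, 128   # arbitrary shared qnn params

def dequantize(q):
    return scale * (q.astype(np.float32) - zero_point)

q_a = np.random.randint(0, 256, size=(3, 6), dtype=np.uint8)
q_b = np.random.randint(0, 256, size=(3, 6), dtype=np.uint8)

# maximum computed directly on the quantized integers ...
q_out = np.maximum(q_a, q_b)

# ... matches maximum computed on the dequantized floats,
# because dequantize() is monotonic and shared by inputs and output
assert np.allclose(dequantize(q_out), np.maximum(dequantize(q_a), dequantize(q_b)))
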
118 changes: 63 additions & 55 deletions tests/python/frontend/tflite/test_forward.py
@@ -29,10 +29,15 @@
from tvm import relay
try:
import tensorflow.compat.v1 as tf
# The tensorflow.python.framework.ops module is not part of
# TensorFlow's public API; its precise contents may vary
# from one version to the next
import tensorflow.compat.v1 as ops
except ImportError:
import tensorflow as tf
import tensorflow as ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops

from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import array_ops
@@ -235,7 +240,8 @@ def run_tflite_graph(tflite_model_buf, input_data):

def compare_tflite_with_tvm(in_data, in_name, input_tensors,
output_tensors, init_global_variables=False,
out_names=None, quantized=False, input_range=None, mode='graph_runtime'):
out_names=None, quantized=False, input_range=None,
mode='graph_runtime', experimental_new_converter=False):
"""Generic function to generate and compare TFLite and TVM output"""
in_data = convert_to_list(in_data)
in_name = convert_to_list(in_name)
@@ -250,7 +256,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
# convert to tflite model
converter = tf.lite.TFLiteConverter.from_session(
sess, input_tensors, output_tensors)

converter.experimental_new_converter = experimental_new_converter
if quantized:
converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
input_arrays = converter.get_input_arrays()
@@ -1300,70 +1306,68 @@ def test_all_unary_elemwise():
# Element-wise
# ------------

def _test_elemwise(math_op, data, fused_activation_function=None, quantized=False, qnn_op=None):
def _test_elemwise(math_op, data, fused_activation_function=None, quantized=False, qnn_op=None, same_qnn_params=False):
""" One iteration of elemwise """

assert len(data) == 2

# Test with two tensors
with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[0].shape, dtype='float32', name='in_0'),
array_ops.placeholder(shape=data[1].shape, dtype='float32', name='in_1')]

def __test_elemwise( in_data ):
assert 2 == len( in_data )
if quantized:
# fake_quant will keep the tensors in float32 until the conversion in the session
inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-100, max=100, name="inq_0"),
tf.quantization.fake_quant_with_min_max_args(in_data[1], min=-50, max=50, name="inq_1")]
input_range = {'inq_0': (-100, 100), 'inq_1': (-50, 50)}
out = math_op(inq_data[0], inq_data[1])
out = with_fused_activation_function(out, fused_activation_function)
# set the fp32 output range with respect to the operation
out_min, out_max = _test_elemwise_qnn_out_range(qnn_op)
out = tf.quantization.fake_quant_with_min_max_args(out, min=out_min, max=out_max, name="out")
compare_tflite_with_tvm(data, ['inq_0:0', 'inq_1:0'], inq_data, [out],
quantized=True, input_range=input_range)
else:
out = math_op(in_data[0], in_data[1])
out = with_fused_activation_function(out, fused_activation_function)
compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out])
inq0_min, inq0_max = (-100, 100)
inq1_min, inq1_max = (-50, 50)

# Test with tensor and constant
with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[0].shape, dtype='float32', name='in_0')]
# if requested, use the same quantization parameters provided by _test_elemwise_qnn_out_range
if same_qnn_params:
inq0_min, inq0_max = (out_min, out_max)
inq1_min, inq1_max = (out_min, out_max)

if quantized:
inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-100, max=100, name="inq_0")]
inq_const = tf.quantization.fake_quant_with_min_max_args(data[1], min=-50, max=50, name="const_tensor")
input_range = {'inq_0': (-100, 100)}
# the 2nd tensor is treated as constant and directly added as part of the operation
out = math_op(inq_data, ops.convert_to_tensor(inq_const, dtype='float32', name='inq_const'))
# fake_quant will keep the tensors in float32 until the conversion in the session
inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=out_min, max=out_max, name="inq_0")\
if None != in_data[0]\
else tf.quantization.fake_quant_with_min_max_args(data[0], min=out_min, max=out_max, name="const_tensor0"),
tf.quantization.fake_quant_with_min_max_args(in_data[1], min=out_min, max=out_max, name="inq_1")\
if None != in_data[1]\
else tf.quantization.fake_quant_with_min_max_args(data[1], min=out_min, max=out_max, name="const_tensor1")]

input_range = {x[1][0]:x[1][1] for x in zip(in_data, (('inq_0', (inq0_min, inq0_max)),\
('inq_1', (inq1_min, inq1_max)))) if None != x[0]}

out = math_op(inq_data[0], inq_data[1])
out = with_fused_activation_function(out, fused_activation_function)
out_min, out_max = _test_elemwise_qnn_out_range(qnn_op)
out = tf.quantization.fake_quant_with_min_max_args(out, min=out_min, max=out_max, name="out")
compare_tflite_with_tvm(data[0], ['inq_0:0'], inq_data, [out], quantized=True, input_range=input_range)

# Note: same_qnn_params uses experimental_new_converter, as TOCO failed to convert this case
compare_tflite_with_tvm([x[1] for x in zip(in_data, data) if None != x[0]],
[x + ":0" for x in input_range.keys()],
[x[1] for x in zip(in_data, inq_data) if None != x[0]],
[out],
quantized=True,
input_range=input_range,
experimental_new_converter=same_qnn_params)
else:
out = math_op(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
out = math_op(in_data[0] if None != in_data[0] else ops.convert_to_tensor(data[0], dtype=data[0].dtype),
in_data[1] if None != in_data[1] else ops.convert_to_tensor(data[1], dtype=data[1].dtype))
out = with_fused_activation_function(out, fused_activation_function)
compare_tflite_with_tvm(data[0], ['in_0:0'], in_data, [out])
compare_tflite_with_tvm([x[1] for x in zip( in_data, data ) if None != x[0]],
[x[1] for x in zip( in_data, ('in_0:0', 'in_1:0') ) if None != x[0]],
[x for x in in_data if None != x],
[out])

# Test with two tensors
with tf.Graph().as_default():
__test_elemwise( in_data = [array_ops.placeholder(shape=data[0].shape, dtype='float32', name='in_0'),
array_ops.placeholder(shape=data[1].shape, dtype='float32', name='in_1')])
# Test with tensor and constant
with tf.Graph().as_default():
__test_elemwise( in_data = [array_ops.placeholder(shape=data[0].shape, dtype='float32', name='in_0'),
None])
# Test with constant and tensor
with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[1].shape, dtype='float32', name='in_1')]

if quantized:
inq_const = tf.quantization.fake_quant_with_min_max_args(data[0], min=-100, max=100, name="const_tensor")
inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0], min=-50, max=50, name="inq_1")]
input_range = {'inq_1': (-50, 50)}
# the 1st tensor is treated as constant and directly added as part of the operation
out = math_op(ops.convert_to_tensor(inq_const, dtype='float32', name='inq_const'), inq_data)
out = with_fused_activation_function(out, fused_activation_function)
out_min, out_max = _test_elemwise_qnn_out_range(qnn_op)
out = tf.quantization.fake_quant_with_min_max_args(out, min=out_min, max=out_max, name="out")
compare_tflite_with_tvm(data[1], ['inq_1:0'], inq_data, [out], quantized=True, input_range=input_range)
else:
out = math_op(ops.convert_to_tensor(data[0], dtype=data[0].dtype), in_data[0])
out = with_fused_activation_function(out, fused_activation_function)
compare_tflite_with_tvm(data[1], ['in_1:0'], in_data, [out])
__test_elemwise( in_data = [None,
array_ops.placeholder(shape=data[1].shape, dtype='float32', name='in_1')])

#######################################################################
# Add
Expand Down Expand Up @@ -1406,16 +1410,16 @@ def _test_pow(data):
# Maximum
# -------

def _test_maximum(data):
def _test_maximum(data, fused_activation_function=None, quantized=False, qnn_op=None):
""" One iteration of maximum """
return _test_elemwise(math_ops.maximum, data)
return _test_elemwise(math_ops.maximum, data, fused_activation_function, quantized, qnn_op, same_qnn_params=True)
#######################################################################
# Minimum
# -------

def _test_minimum(data):
def _test_minimum(data, fused_activation_function=None, quantized=False, qnn_op=None):
""" One iteration of minimum """
return _test_elemwise(math_ops.minimum, data)
return _test_elemwise(math_ops.minimum, data, fused_activation_function, quantized, qnn_op, same_qnn_params=True)
#######################################################################
# Greater
# -------
@@ -1501,6 +1505,8 @@ def _test_elemwise_qnn_out_range(qnn_op):
_test_add: (-150, 150),
_test_sub: (-150, 150),
_test_mul: (-5e+3, 5e+3),
_test_maximum: (-112, 111),
_test_minimum: (-128, 127)
}

return qnn_out_range[qnn_op]
@@ -1525,7 +1531,9 @@ def test_all_elemwise():
_test_forward_elemwise(partial(_test_div, fused_activation_function="RELU6"))
_test_forward_elemwise(_test_pow)
_test_forward_elemwise(_test_maximum)
_test_forward_elemwise_quantized(_test_maximum)
_test_forward_elemwise(_test_minimum)
_test_forward_elemwise_quantized(_test_minimum)
_test_forward_elemwise(_test_greater)
_test_forward_elemwise(_test_squared_difference)
_test_forward_elemwise(_test_greater_equal)
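For reference, the new quantized cases are driven through _test_forward_elemwise_quantized; a direct call to the updated helper would look roughly like the sketch below (the exact shapes and driver used by the suite may differ):

import numpy as np

# Hypothetical direct invocation; the suite wraps this via _test_forward_elemwise_quantized
data = [np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
        np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8)]
_test_maximum(data, quantized=True, qnn_op=_test_maximum)   # qnn_op keys into _test_elemwise_qnn_out_range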
