[TFLite] TFLite 2.x parser quantization support.
anijain2305 committed Jun 24, 2020
1 parent 8931cfa commit bd1cc7c
Showing 2 changed files with 223 additions and 23 deletions.
109 changes: 87 additions & 22 deletions python/tvm/relay/frontend/tflite.py
@@ -243,10 +243,46 @@ def get_tensors(self, tensors_idx_list):
qnn_params = None
tflite_qnn_params = tensor.Quantization()
if tflite_qnn_params is not None:
scale = float(tflite_qnn_params.ScaleAsNumpy())
zero_point = int(tflite_qnn_params.ZeroPointAsNumpy())
# Params might be per-tensor or per-axis quantized. For per-tensor, scale and zero
# points are scalars. For per-axis, scale and zero points are tensors. But as per
# the TFLite quantization spec, the restrictions on ops imply that for per-axis, even
# if the zero point is a tensor, all the zero points are identical. More information
# here - https://www.tensorflow.org/lite/performance/quantization_spec

tflite_scale = tflite_qnn_params.ScaleAsNumpy()
tflite_zero_point = tflite_qnn_params.ZeroPointAsNumpy()
is_qnn_params_valid = True

# Handle Per-axis and per-tensor cases
if isinstance(tflite_scale, np.ndarray):
assert isinstance(tflite_zero_point, np.ndarray)

# Tensor - Per-axis quantization
if tflite_scale.shape != (1,) and tflite_zero_point.shape != (1,):
scale = tflite_scale
# Ensure that all zero points are identical
zero_point = tflite_zero_point
assert all(x == zero_point[0] for x in zero_point)
zero_point = int(zero_point[0])

# Scalar - Per-tensor quantization
elif tflite_scale.shape == (1,) and tflite_zero_point.shape == (1,):
scale = float(tflite_scale[0])
zero_point = int(tflite_zero_point[0])

else:
raise NotImplementedError("Quantized type {} not supported"
.format(type(tflite_scale)))
elif tflite_scale == 0 and tflite_zero_point == 0:
# Handle the corner case for ops like quantized reshape, whose second operand (shape)
# has a zero scale and a zero-valued zero point. These params are not used.
is_qnn_params_valid = False
else:
raise NotImplementedError("Quantized type {} not supported"
.format(type(tflite_scale)))

# Check that the scale and zero points are valid.
if scale != 0 or zero_point != 0:
if is_qnn_params_valid:
qnn_params = dict()
qnn_params['scale'] = relay.const(scale, 'float32')
qnn_params['zero_point'] = relay.const(zero_point, 'int32')
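
As an aside, a minimal sketch of what the two branches above produce; the numeric values are illustrative assumptions, not taken from a real model:

import numpy as np

# Per-axis: the scale stays a vector; the identical zero points collapse to one int.
tflite_scale = np.array([0.05, 0.1, 0.2], dtype=np.float32)
tflite_zero_point = np.array([0, 0, 0], dtype=np.int32)
assert all(x == tflite_zero_point[0] for x in tflite_zero_point)
scale, zero_point = tflite_scale, int(tflite_zero_point[0])

# Per-tensor: both collapse to Python scalars.
tflite_scale = np.array([0.05], dtype=np.float32)
tflite_zero_point = np.array([128], dtype=np.int32)
scale, zero_point = float(tflite_scale[0]), int(tflite_zero_point[0])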
@@ -262,21 +298,25 @@ def get_tensor_value(self, tensor_wrapper):
except ImportError:
raise ImportError("The tflite package must be installed")

data = tensor_wrapper.buffer.DataAsNumpy()
shape = tensor_wrapper.tensor.ShapeAsNumpy()

# Set shape to (1,) if the data is a scalar
if data.shape == (1,) and isinstance(shape, int) and shape == 0:
shape = (1,)

if tensor_wrapper.tensor.Type() == TensorType.INT8:
return np.frombuffer(data, dtype=np.int8).reshape(shape)
if tensor_wrapper.tensor.Type() == TensorType.UINT8:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.uint8).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
if tensor_wrapper.tensor.Type() == TensorType.FLOAT32:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.float32).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
return np.frombuffer(data, dtype=np.uint8).reshape(shape)
if tensor_wrapper.tensor.Type() == TensorType.INT32:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int32).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
return np.frombuffer(data, dtype=np.int32).reshape(shape)
if tensor_wrapper.tensor.Type() == TensorType.INT64:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int64).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
return np.frombuffer(data, dtype=np.int64).reshape(shape)
if tensor_wrapper.tensor.Type() == TensorType.FLOAT32:
return np.frombuffer(data, dtype=np.float32).reshape(shape)
if tensor_wrapper.tensor.Type() == TensorType.BOOL:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.bool_).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
return np.frombuffer(data, dtype=np.bool_).reshape(shape)
raise NotImplementedError("Tensor type {} is currently not supported"
.format(str(tensor_wrapper.tensor.Type())))
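
A standalone illustration of the frombuffer-and-reshape pattern that the refactor above consolidates; the buffer bytes and shape are assumed:

import numpy as np

data = np.array([1, 2, 3, 4], dtype=np.uint8).tobytes()  # stand-in for buffer.DataAsNumpy()
shape = (2, 2)
value = np.frombuffer(data, dtype=np.uint8).reshape(shape)  # [[1, 2], [3, 4]]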

@@ -1605,7 +1645,7 @@ def convert_fully_connected(self, op):

# weight tensor type should be INT8/UINT8 (quantization) or FLOAT32
weight_tensor_type = weight_tensor.tensor.Type()
assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
assert weight_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32)
weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)

if self.has_expr(weight_tensor.tensor_idx):
@@ -1796,7 +1836,7 @@ def convert_conv(self, op, conv_type):

# weight tensor type should be INT8/UINT8 (quantization) or FLOAT32
weight_tensor_type = weight_tensor.tensor.Type()
assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
assert weight_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32)
weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)

in_expr = self.get_expr(input_tensor_idx)
@@ -1855,9 +1895,15 @@ def convert_conv(self, op, conv_type):
if output_tensor.qnn_params:
# Calculate the intermediate scale and zero point of the int32 output.
data_scale = input_tensor.qnn_params['scale']
weight_scale = weight_tensor.qnn_params['scale']
data_scale_val = get_scalar_from_constant(data_scale)
weight_scale_val = get_scalar_from_constant(weight_scale)

weight_scale = weight_tensor.qnn_params['scale']
# If the weight scale is a scalar constant, it is per-tensor quantization.
# Otherwise, it is per-axis (per-channel) quantization.
if len(weight_scale.data.shape) == 0:
weight_scale_val = get_scalar_from_constant(weight_scale)
else:
weight_scale_val = get_vector_from_constant(weight_scale)

new_input_scale_val = data_scale_val * weight_scale_val
new_input_scale = relay.const(new_input_scale_val, 'float32')
new_input_zero_point = relay.const(0, 'int32')
@@ -1868,7 +1914,8 @@
input_zero_point=new_input_zero_point,
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
out_dtype=output_tensor_type_str)
out_dtype=output_tensor_type_str,
axis=3)

# Call activation function
output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
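
For intuition: the int32 accumulator's effective scale is the elementwise product of the input scale and the (possibly per-channel) weight scale. A small numeric sketch with assumed values:

import numpy as np

data_scale_val = np.float32(0.5)
weight_scale_val = np.array([0.1, 0.2, 0.4], dtype=np.float32)  # one scale per output channel
new_input_scale_val = data_scale_val * weight_scale_val         # array([0.05, 0.1, 0.2])
# requantize is then told this scale varies along the channel axis (axis=3 for NHWC).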
@@ -2566,17 +2613,27 @@ def convert_quantize(self, op):
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type())
in_expr = self.get_expr(input_tensor.tensor_idx)

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]
output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())

# The output must be quantized
assert output_tensor.qnn_params
# Quantize the input
out = self.quantize(in_expr, output_tensor)

# TFLite Quantize op can also act as Requantize op
if input_tensor_type_str == "float32":
out = self.quantize(in_expr, output_tensor)
else:
out = _qnn.op.requantize(in_expr,
input_scale=input_tensor.qnn_params['scale'],
input_zero_point=input_tensor.qnn_params['zero_point'],
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
out_dtype=output_tensor_type_str)
return out
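
Conceptually, requantize re-expresses the same real value under new quantization parameters. A worked example with assumed parameters:

# real = (q_in - z_in) * s_in  ==  (q_out - z_out) * s_out
s_in, z_in = 0.5, 0      # assumed input scale / zero point
s_out, z_out = 0.25, 10  # assumed output scale / zero point
q_in = 7
q_out = int(round((q_in - z_in) * s_in / s_out)) + z_out  # -> 24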

def convert_dequantize(self, op):
@@ -2710,7 +2767,6 @@ def get_tensor_expr(self, tensor):
# we can receive as constant.
type_str = self.get_tensor_type_str(tensor.tensor.Type())
expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str)

return expr


@@ -2723,6 +2779,15 @@ def get_scalar_from_constant(expr):
"value must be float32/int32"
return np.asscalar(value)

def get_vector_from_constant(expr):
""" Returns scalar value from Relay constant scalar. """
assert isinstance(expr, _expr.Constant)
value = expr.data.asnumpy()
assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \
"value must be float32/int32"
return value
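
A usage sketch for the new helper, with assumed values:

import numpy as np
from tvm import relay

scale_const = relay.const(np.array([0.1, 0.2], dtype=np.float32), 'float32')
scale_vec = get_vector_from_constant(scale_const)  # -> array([0.1, 0.2], dtype=float32)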



def build_str_map(obj):
"""Build string map of TFLite enum int value
137 changes: 136 additions & 1 deletion tests/python/frontend/tflite/test_forward.py
@@ -39,6 +39,7 @@
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import variables
import tensorflow_hub as hub
try:
from tensorflow import lite as interpreter_wrapper
except ImportError:
@@ -73,6 +74,28 @@ def get_real_image(im_height, im_width):
data = np.reshape(x, (1, im_height, im_width, 3))
return data


def pre_processed_image(height, width):
repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
img_name = 'elephant-299.jpg'
image_url = os.path.join(repo_base, img_name)
img_path = download_testdata(image_url, img_name, module='data')
image = tf.io.read_file(img_path)
image = tf.image.decode_jpeg(image, channels=3)
with tf.name_scope('eval_image'):
if image.dtype != tf.float32:
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image = tf.image.central_crop(image, central_fraction=0.875)
# Resize the image to the specified height and width.
image = tf.image.resize(image, [height, width])
image = tf.expand_dims(image, axis=0)
return image


def get_real_image_object_detection(im_height, im_width):
repo_base = 'https://github.com/dmlc/web-data/raw/master/gluoncv/detection/'
img_name = 'street_small.jpg'
@@ -1707,7 +1730,6 @@ def representative_data_gen():

# Convert the model to TensorFlow Lite format
tflite_model_quant = converter.convert()

tflite_output = run_tflite_graph(tflite_model_quant, data)
tvm_output = run_tvm_graph(tflite_model_quant, data, 'input_1')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
@@ -2445,6 +2467,112 @@ def test_forward_qnn_mobilenet_v3_net():
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)



def _quantize_tf_hub_keras_model(url, height, width):
keras_model = tf.keras.Sequential([hub.KerasLayer(url, output_shape=[1001])])
data = pre_processed_image(height, width)

# Set the input shapes of the keras model
keras_model._set_inputs(data)

# Get the converter
converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model)

# To quantize the activations with proper dynamic ranges, the converter needs a representative dataset
def representative_data_gen():
for i in range(1):
yield [data]

converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
converter.representative_dataset = representative_data_gen
return converter.convert()
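
The returned flatbuffer can be sanity-checked directly with the TFLite interpreter; a minimal sketch, where tflite_model_buf comes from the helper above and the zero input is only a placeholder:

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_content=tflite_model_buf)
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
interpreter.set_tensor(inp['index'], np.zeros(inp['shape'], dtype=inp['dtype']))
interpreter.invoke()
out = interpreter.get_tensor(interpreter.get_output_details()[0]['index'])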


def test_forward_tflite2_qnn_resnet50():
"""Test the Quantized TFLite version 2.1.0 Resnet50 model."""
if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
# Quantize the model
url = "https://tfhub.dev/tensorflow/resnet_50/classification/1"
tflite_model_buf = _quantize_tf_hub_keras_model(url, 224, 224)
data = pre_processed_image(224, 224)

tflite_output = run_tflite_graph(tflite_model_buf, data)
tflite_predictions = np.squeeze(tflite_output)
tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1')
tvm_predictions = np.squeeze(tvm_output)
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)


def test_forward_tflite2_qnn_inception_v1():
"""Test the Quantized TFLite version 2.1.0 Inception V1 model."""
if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
# Quantize the model
url = "https://tfhub.dev/google/imagenet/inception_v1/classification/4"
tflite_model_buf = _quantize_tf_hub_keras_model(url, 224, 224)
data = pre_processed_image(224, 224)

tflite_output = run_tflite_graph(tflite_model_buf, data)
tflite_predictions = np.squeeze(tflite_output)
tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1')
tvm_predictions = np.squeeze(tvm_output)
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)


def test_forward_tflite2_qnn_inception_v3():
"""Test the Quantized TFLite version 2.1.0 Inception V3 model."""
if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
# Quantize the model
url = "https://tfhub.dev/google/imagenet/inception_v3/classification/4"
tflite_model_buf = _quantize_tf_hub_keras_model(url, 299, 299)
data = pre_processed_image(299, 299)

tflite_output = run_tflite_graph(tflite_model_buf, data)
tflite_predictions = np.squeeze(tflite_output)
tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1')
tvm_predictions = np.squeeze(tvm_output)
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)


def test_forward_tflite2_qnn_mobilenet_v1():
"""Test the Quantized TFLite version 2.1.0 Mobilenet V1 model."""
if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
# Quantize the model
url = "https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/4"
tflite_model_buf = _quantize_tf_hub_keras_model(url, 224, 224)
data = pre_processed_image(224, 224)

tflite_output = run_tflite_graph(tflite_model_buf, data)
tflite_predictions = np.squeeze(tflite_output)
tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1')
tvm_predictions = np.squeeze(tvm_output)
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)


def test_forward_tflite2_qnn_mobilenet_v2():
"""Test the Quantized TFLite version 2.1.0 Mobilenet V2 model."""
if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
# Quantize the model
url = "https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/classification/4"
tflite_model_buf = _quantize_tf_hub_keras_model(url, 224, 224)
data = pre_processed_image(224, 224)

tflite_output = run_tflite_graph(tflite_model_buf, data)
tflite_predictions = np.squeeze(tflite_output)
tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1')
tvm_predictions = np.squeeze(tvm_output)
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)


#######################################################################
# Quantized SSD Mobilenet
# -----------------------
@@ -2662,3 +2790,10 @@ def test_forward_mediapipe_hand_landmark():
# with TFLite 1.15.2
test_forward_qnn_mobilenet_v3_net()
test_forward_qnn_coco_ssd_mobilenet_v1()

# TFLite 2.1.0 quantized tests
test_forward_tflite2_qnn_resnet50()
test_forward_tflite2_qnn_inception_v1()
test_forward_tflite2_qnn_inception_v3()
test_forward_tflite2_qnn_mobilenet_v1()
test_forward_tflite2_qnn_mobilenet_v2()
