[TFLITE] Quantize & Dequantize op #5394

Merged: 3 commits, May 28, 2020
38 changes: 38 additions & 0 deletions python/tvm/relay/frontend/tflite.py
@@ -75,6 +75,7 @@ def __init__(self, model, subgraph, exp_tab):
            'COS': self.convert_cos,
            'DEPTH_TO_SPACE': self.convert_depth_to_space,
            'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d,
            'DEQUANTIZE': self.convert_dequantize,
            'DETECTION_POSTPROCESS': self.convert_detection_postprocess,
            'DIV': self.convert_div,
            'ELU': self.convert_elu,
@@ -112,6 +113,7 @@ def __init__(self, model, subgraph, exp_tab):
            'PAD': self.convert_pad,
            'POW': self.convert_pow,
            'PRELU': self.convert_prelu,
            'QUANTIZE': self.convert_quantize,
            'REDUCE_ANY': self.convert_reduce_any,
            'REDUCE_MAX': self.convert_reduce_max,
            'REDUCE_MIN': self.convert_reduce_min,
@@ -277,6 +279,8 @@ def get_tensor_type_str(self, tensor_type):
        except ImportError:
            raise ImportError("The tflite package must be installed")

        if tensor_type == TensorType.INT8:
            return "int8"
        if tensor_type == TensorType.UINT8:
            return "uint8"
        if tensor_type == TensorType.FLOAT32:
@@ -2355,6 +2359,40 @@ def convert_transpose_conv(self, op):

        return out

    def convert_quantize(self, op):
        """Convert TFLite Quantize"""

        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 1, "input tensors length should be 1"
        input_tensor = input_tensors[0]
        in_expr = self.get_expr(input_tensor.tensor_idx)

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors length should be 1"
        output_tensor = output_tensors[0]

        # The output must be quantized
        assert output_tensor.qnn_params
        # Quantize the input
        out = self.quantize(in_expr, output_tensor)

        return out

    def convert_dequantize(self, op):
        """Convert TFLite Dequantize"""

        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 1, "input tensors length should be 1"
        input_tensor = input_tensors[0]
        in_expr = self.get_expr(input_tensor.tensor_idx)

        # The input must be quantized
        assert input_tensor.qnn_params
        # Dequantize the input.
        out = self.dequantize(in_expr, input_tensor)

        return out

    def convert_detection_postprocess(self, op):
        """Convert TFLite_Detection_PostProcess"""
        flexbuffer = op.CustomOptionsAsNumpy().tobytes()
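For context on what the new handlers produce: both delegate to `quantize`/`dequantize` helper methods on the converter, which lower to Relay QNN ops using the scale and zero point recorded in the tensor's `qnn_params`. Below is a minimal sketch of that lowering, assuming the helpers are thin wrappers over `relay.qnn.op.quantize` and `relay.qnn.op.dequantize` — a sketch only, not the exact code in this PR:

```python
from tvm import relay

# Sketch of the converter's helper methods (assumed shape, not verbatim).
# `qnn_params` is assumed to hold Relay constants under 'scale' and
# 'zero_point', as used by the call sites in this diff.
def quantize(self, expr, tensor_to_quantize):
    """Quantize a float32 expr into the tensor's quantized dtype."""
    tensor_type_str = self.get_tensor_type_str(tensor_to_quantize.tensor.Type())
    return relay.qnn.op.quantize(
        data=expr,
        output_scale=tensor_to_quantize.qnn_params['scale'],
        output_zero_point=tensor_to_quantize.qnn_params['zero_point'],
        out_dtype=tensor_type_str)  # e.g. "uint8", or the newly handled "int8"

def dequantize(self, expr, tensor):
    """Dequantize a quantized expr back to float32."""
    return relay.qnn.op.dequantize(
        data=expr,
        input_scale=tensor.qnn_params['scale'],
        input_zero_point=tensor.qnn_params['zero_point'])
```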
43 changes: 43 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
@@ -1552,6 +1552,48 @@ def test_forward_squeeze():
    _test_squeeze(np.arange(6).reshape((2, 1, 3, 1)), [1, 3])


#######################################################################
# Quantize/DeQuantize
# -------------------

def _test_quantize_dequantize(data):
    """ One iteration of quantize and dequantize """

    # Define a dummy model
    data_in = tf.keras.layers.Input(shape=data.shape[1:])
    act_func = tf.keras.layers.Activation('linear')
    keras_model = tf.keras.models.Model(data_in, act_func(data_in))

    # Create a converter from the Keras model
    converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model)

    # To create quantized values with a dynamic range of activations,
    # a representative dataset is needed
    def representative_data_gen():
        for i in range(100):
            yield [data]

    converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
    converter.representative_dataset = representative_data_gen
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.uint8
    converter.inference_output_type = tf.uint8

    # Convert the model to TensorFlow Lite format
    tflite_model_quant = converter.convert()

    tflite_output = run_tflite_graph(tflite_model_quant, data)
    tvm_output = run_tvm_graph(tflite_model_quant, data, 'input_1')
    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
                                rtol=1e-5, atol=1e-5)


def test_forward_quantize_dequantize():
    """ Quantize Dequantize """
    data = np.random.uniform(0, 1, (1, 4, 4, 3)).astype("float32")
    if package_version.parse(tf.VERSION) >= package_version.parse('2.0.0'):
        _test_quantize_dequantize(data)


#######################################################################
# Pad
# ---
@@ -2252,6 +2294,7 @@ def test_forward_mediapipe_hand_landmark():
    test_forward_depthtospace()
    test_forward_spacetodepth()
    test_forward_select()
    test_forward_quantize_dequantize()

    # NN
    test_forward_convolution()
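The numerical behaviour the test exercises is plain affine quantization: QUANTIZE computes `q = clip(round(x / scale) + zero_point)` in the integer dtype's range, and DEQUANTIZE recovers `x ≈ scale * (q - zero_point)`. A small self-contained NumPy check of that round trip — the scale and zero point here are made-up illustrative values, not ones produced by the converter:

```python
import numpy as np

# Hypothetical uint8 quantization parameters; a real converter derives
# them from the representative dataset.
scale, zero_point = 1.0 / 255.0, 0

x = np.random.uniform(0, 1, (1, 4, 4, 3)).astype("float32")

# QUANTIZE: float32 -> uint8
q = np.clip(np.round(x / scale) + zero_point, 0, 255).astype("uint8")

# DEQUANTIZE: uint8 -> float32
x_hat = (scale * (q.astype("float32") - zero_point)).astype("float32")

# Round-trip error is bounded by half a quantization step
assert np.max(np.abs(x - x_hat)) <= scale / 2 + 1e-6
```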