Skip to content

Commit

Permalink
[TFLITE]Quantize & Dequantize op (apache#5394)
Browse files Browse the repository at this point in the history
* [TFLITE]Quantize & Dequantize op

* Testcases added

* Review comment fixed
  • Loading branch information
siju-samuel authored and Trevor Morris committed Jun 9, 2020
1 parent 1eb0677 commit a7ab91e
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
38 changes: 38 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(self, model, subgraph, exp_tab):
'COS': self.convert_cos,
'DEPTH_TO_SPACE': self.convert_depth_to_space,
'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d,
'DEQUANTIZE': self.convert_dequantize,
'DETECTION_POSTPROCESS': self.convert_detection_postprocess,
'DIV': self.convert_div,
'ELU': self.convert_elu,
Expand Down Expand Up @@ -112,6 +113,7 @@ def __init__(self, model, subgraph, exp_tab):
'PAD': self.convert_pad,
'POW': self.convert_pow,
'PRELU': self.convert_prelu,
'QUANTIZE': self.convert_quantize,
'REDUCE_ANY': self.convert_reduce_any,
'REDUCE_MAX': self.convert_reduce_max,
'REDUCE_MIN': self.convert_reduce_min,
Expand Down Expand Up @@ -277,6 +279,8 @@ def get_tensor_type_str(self, tensor_type):
except ImportError:
raise ImportError("The tflite package must be installed")

if tensor_type == TensorType.INT8:
return "int8"
if tensor_type == TensorType.UINT8:
return "uint8"
if tensor_type == TensorType.FLOAT32:
Expand Down Expand Up @@ -2355,6 +2359,40 @@ def convert_transpose_conv(self, op):

return out

def convert_quantize(self, op):
"""Convert TFLite Quantize"""

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]

# The output must be quantized
assert output_tensor.qnn_params
# Quantize the input
out = self.quantize(in_expr, output_tensor)

return out

def convert_dequantize(self, op):
"""Convert TFLite Dequantize"""

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)

# The input must be quantized
assert input_tensor.qnn_params
# Dequantize the input.
out = self.dequantize(in_expr, input_tensor)

return out

def convert_detection_postprocess(self, op):
"""Convert TFLite_Detection_PostProcess"""
flexbuffer = op.CustomOptionsAsNumpy().tobytes()
Expand Down
43 changes: 43 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1564,6 +1564,48 @@ def test_forward_squeeze():
_test_squeeze(np.arange(6).reshape((2, 1, 3, 1)), [1, 3])


#######################################################################
# Quantize/DeQuantize
# -------------------

def _test_quantize_dequantize(data):
""" One iteration of quantize and dequantize """

# Define a dummy model
data_in = tf.keras.layers.Input(shape=data.shape[1:])
act_func = tf.keras.layers.Activation('linear')
keras_model = tf.keras.models.Model(data_in, act_func(data_in))

# Load the model
converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model)

# To create quantized values with dynamic range of activations, needs representative dataset
def representative_data_gen():
for i in range(100):
yield [data]

converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
converter.representative_dataset = representative_data_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

# Convert the model to TensorFlow Lite format
tflite_model_quant = converter.convert()

tflite_output = run_tflite_graph(tflite_model_quant, data)
tvm_output = run_tvm_graph(tflite_model_quant, data, 'input_1')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5)


def test_forward_quantize_dequantize():
""" Quantize Dequantize """
data = np.random.uniform(0, 1, (1, 4, 4, 3)).astype("float32")
if package_version.parse(tf.VERSION) >= package_version.parse('2.0.0'):
_test_quantize_dequantize(data)


#######################################################################
# Pad
# ---
Expand Down Expand Up @@ -2264,6 +2306,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_depthtospace()
test_forward_spacetodepth()
test_forward_select()
test_forward_quantize_dequantize()

# NN
test_forward_convolution()
Expand Down

0 comments on commit a7ab91e

Please sign in to comment.