Skip to content

Commit

Permalink
Testcases added
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel committed May 18, 2020
1 parent 65824b6 commit 13535b4
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1552,6 +1552,48 @@ def test_forward_squeeze():
_test_squeeze(np.arange(6).reshape((2, 1, 3, 1)), [1, 3])


#######################################################################
# Quantize/DeQuantize
# -------------------

def _test_quantize_dequantize(data):
""" One iteration of quantize and dequantize """

import tensorflow as tf2
# Define a dummy model
data_in = tf2.keras.layers.Input(shape=data.shape[1:])
act_func = tf2.keras.layers.Activation('linear')
keras_model = tf2.keras.models.Model(data_in, act_func(data_in))

# Load the model
converter = tf2.lite.TFLiteConverter.from_keras_model(keras_model)

# To create quantized values with dynamic range of activations, needs representative dataset
def representative_data_gen():
for i in range(100):
yield [data]

converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
converter.representative_dataset = representative_data_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

# Convert the model to TensorFlow Lite format
tflite_model_quant = converter.convert()

tflite_output = run_tflite_graph(tflite_model_quant, data)
tvm_output = run_tvm_graph(tflite_model_quant, data, 'input_1')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5)


def test_forward_quantize_dequantize():
""" Quantize Dequantize """
data = np.random.uniform(0, 1, (1, 4, 4, 3)).astype("float32")
_test_quantize_dequantize(data)


#######################################################################
# Pad
# ---
Expand Down Expand Up @@ -2252,6 +2294,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_depthtospace()
test_forward_spacetodepth()
test_forward_select()
test_forward_quantize_dequantize()

# NN
test_forward_convolution()
Expand Down

0 comments on commit 13535b4

Please sign in to comment.