Skip to content

Commit

Permalink
Handle TFLite input layer naming.
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 committed Jun 29, 2020
1 parent e0cc5c7 commit cef3a85
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1820,6 +1820,7 @@ def _test_quantize_dequantize(data):
add = tf.keras.layers.Add()([data_in, relu])
concat = tf.keras.layers.Concatenate(axis=0)([relu, add])
keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat)
input_name = data_in.name.split(":")[0]

# To create quantized values with dynamic range of activations, needs representative dataset
def representative_data_gen():
Expand All @@ -1829,7 +1830,7 @@ def representative_data_gen():
tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)

tflite_output = run_tflite_graph(tflite_model_quant, data)
tvm_output = run_tvm_graph(tflite_model_quant, data, 'input_1')
tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-2)

Expand Down Expand Up @@ -2074,6 +2075,7 @@ def _test_relu(data, quantized=False):
data_in = tf.keras.layers.Input(shape=data.shape[1:])
relu = tf.keras.layers.ReLU()(data_in)
keras_model = tf.keras.models.Model(inputs=data_in, outputs=relu)
input_name = data_in.name.split(":")[0]

# To create quantized values with dynamic range of activations, needs representative dataset
def representative_data_gen():
Expand All @@ -2083,7 +2085,7 @@ def representative_data_gen():
tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)

tflite_output = run_tflite_graph(tflite_model_quant, data)
tvm_output = run_tvm_graph(tflite_model_quant, data, 'input_1')
tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5)
else:
Expand Down

0 comments on commit cef3a85

Please sign in to comment.