Skip to content

Commit

Permalink
Ina comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 committed Feb 9, 2020
1 parent 10d40eb commit 6529512
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,12 +1167,14 @@ def _test_pad(data, mode="CONSTANT", quantized=False):

if quantized:
# fake_quant will keep the tensors in float32 until the conversion in the session
input_range = {'inq_0': (-100, 100)}
inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0],
min=-100,
max=100,
name="inq_0")]
out = array_ops.pad(inq_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode)
compare_tflite_with_tvm([data[0]], ['inq_0:0'], inq_data, [out], quantized=True)
compare_tflite_with_tvm([data[0]], ['inq_0:0'], inq_data, [out], quantized=True,
input_range=input_range)
else:
out = array_ops.pad(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode)
compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])
Expand Down Expand Up @@ -1462,10 +1464,10 @@ def test_forward_qnn_inception_v1_net():

tflite_output = run_tflite_graph(tflite_model_buf, data)
tflite_predictions = np.squeeze(tflite_output)
tflite_sorted_labels = tflite_predictions.argsort()[-5:][::-1]
tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
tvm_predictions = np.squeeze(tvm_output)
tvm_sorted_labels = tvm_predictions.argsort()[-5:][::-1]
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)

def test_forward_qnn_mobilenet_v1_net():
Expand All @@ -1484,10 +1486,10 @@ def test_forward_qnn_mobilenet_v1_net():

tflite_output = run_tflite_graph(tflite_model_buf, data)
tflite_predictions = np.squeeze(tflite_output)
tflite_sorted_labels = tflite_predictions.argsort()[-5:][::-1]
tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
tvm_predictions = np.squeeze(tvm_output)
tvm_sorted_labels = tvm_predictions.argsort()[-5:][::-1]
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)

def test_forward_qnn_mobilenet_v2_net():
Expand All @@ -1506,10 +1508,10 @@ def test_forward_qnn_mobilenet_v2_net():

tflite_output = run_tflite_graph(tflite_model_buf, data)
tflite_predictions = np.squeeze(tflite_output)
tflite_sorted_labels = tflite_predictions.argsort()[-5:][::-1]
tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
tvm_predictions = np.squeeze(tvm_output)
tvm_sorted_labels = tvm_predictions.argsort()[-5:][::-1]
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)

#######################################################################
Expand Down

0 comments on commit 6529512

Please sign in to comment.