Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 committed Apr 29, 2020
1 parent e3361d1 commit d4adff4
Showing 1 changed file with 42 additions and 43 deletions.
85 changes: 42 additions & 43 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1729,23 +1729,27 @@ def test_detection_postprocess():
tflite_output = run_tflite_graph(tflite_model, [box_encodings, class_predictions])
tvm_output = run_tvm_graph(tflite_model, [box_encodings, class_predictions],
["raw_outputs/box_encodings", "raw_outputs/class_predictions"], num_output=4)
# check valid count is the same

# Check all output shapes are equal
assert all([tvm_tensor.shape == tflite_tensor.shape \
for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)])

# Check valid count is the same
assert tvm_output[3] == tflite_output[3]
# check all the output shapes are the same
assert tvm_output[0].shape == tflite_output[0].shape
assert tvm_output[1].shape == tflite_output[1].shape
assert tvm_output[2].shape == tflite_output[2].shape
valid_count = tvm_output[3][0]
# only check the valid detections are the same
# tvm has a different convention to tflite for invalid detections, it uses all -1s whereas
# tflite appears to put in nonsense data instead
tvm_boxes = tvm_output[0][0][:valid_count]
tvm_classes = tvm_output[1][0][:valid_count]
tvm_scores = tvm_output[2][0][:valid_count]
# check the output data is correct
tvm.testing.assert_allclose(np.squeeze(tvm_boxes), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5)
tvm.testing.assert_allclose(np.squeeze(tvm_classes), np.squeeze(tflite_output[1]), rtol=1e-5, atol=1e-5)
tvm.testing.assert_allclose(np.squeeze(tvm_scores), np.squeeze(tflite_output[2]), rtol=1e-5, atol=1e-5)

# For boxes that do not have any detections, TFLite puts random values. Therefore, we compare
# tflite and tvm tensors for only valid boxes.
for i in range(0, valid_count):
# Check bounding box co-ords
tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]),
rtol=1e-5, atol=1e-5)
# Check the class
tvm.testing.assert_allclose(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i]),
rtol=1e-5, atol=1e-5)
# Check the score
tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]),
rtol=1e-5, atol=1e-5)


#######################################################################
Expand Down Expand Up @@ -1933,28 +1937,6 @@ def test_forward_qnn_mobilenet_v3_net():
# SSD Mobilenet
# -------------

def test_forward_coco_ssd_mobilenet_v1_nopp():
"""Test the SSD Mobilenet V1 TF Lite model."""
# SSD MobilenetV1 with no post processing
tflite_model_file = tf_testing.get_workload_official(
"https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28_nopp.tgz",
"ssd_mobilenet_v1_coco_2018_01_28_nopp.tflite")

with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()
np.random.seed(0)
data = np.random.uniform(size=(1, 300, 300, 3)).astype('float32')
tflite_output = run_tflite_graph(tflite_model_buf, data)
tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=2)
for i in range(2):
tvm.testing.assert_allclose(np.squeeze(tvm_output[i]), np.squeeze(tflite_output[i]),
rtol=1e-5, atol=2e-5)


#######################################################################
# SSD Mobilenet
# -------------

def test_forward_coco_ssd_mobilenet_v1():
"""Test the quantized Coco SSD Mobilenet V1 TF Lite model."""
tflite_model_file = tf_testing.get_workload_official(
Expand All @@ -1963,14 +1945,32 @@ def test_forward_coco_ssd_mobilenet_v1():

with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()

np.random.seed(0)
data = np.random.uniform(size=(1, 300, 300, 3)).astype('float32')
tflite_output = run_tflite_graph(tflite_model_buf, data)
tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=2)
for i in range(2):
tvm.testing.assert_allclose(np.squeeze(tvm_output[i]), np.squeeze(tflite_output[i]),
rtol=1e-5, atol=2e-5)
tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=4)

# Check all output shapes are equal
assert all([tvm_tensor.shape == tflite_tensor.shape \
for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)])

# Check valid count is the same
assert tvm_output[3] == tflite_output[3]
valid_count = tvm_output[3][0]

# For boxes that do not have any detections, TFLite puts random values. Therefore, we compare
# tflite and tvm tensors for only valid boxes.
for i in range(0, valid_count):
# Check bounding box co-ords
tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]),
rtol=1e-5, atol=1e-5)
# Check the class
tvm.testing.assert_allclose(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i]),
rtol=1e-5, atol=1e-5)
# Check the score
tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]),
rtol=1e-5, atol=1e-5)


#######################################################################
Expand Down Expand Up @@ -2069,7 +2069,6 @@ def test_forward_mediapipe_hand_landmark():
test_forward_mobilenet_v3()
test_forward_inception_v3_net()
test_forward_inception_v4_net()
test_forward_coco_ssd_mobilenet_v1_nopp()
test_forward_coco_ssd_mobilenet_v1()
test_forward_mediapipe_hand_landmark()

Expand Down

0 comments on commit d4adff4

Please sign in to comment.