From ac991a4632d30b892857f92167c3968ba6a67344 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Tue, 4 Feb 2020 22:17:52 +0000
Subject: [PATCH 1/5] [TFLite] Using real image for QNN testing.

---
 python/tvm/relay/frontend/tflite.py          |  6 +-
 tests/python/frontend/tflite/test_forward.py | 59 ++++++++++++++------
 2 files changed, 46 insertions(+), 19 deletions(-)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index cefd4085b67c..ba2f9851a33d 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -1185,10 +1185,14 @@ def convert_conv(self, op, conv_type):
             pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w)
             do_pad = not (pad_top == 0 and pad_bottom == 0 and pad_left == 0 and pad_right == 0)
             if do_pad:
+                pad_value = 0
+                if input_tensor.qnn_params:
+                    pad_value = get_scalar_from_constant(input_tensor.qnn_params['zero_point'])
                 in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0),
                                                               (pad_top, pad_bottom),
                                                               (pad_left, pad_right),
-                                                              (0, 0)))
+                                                              (0, 0)), pad_value=float(pad_value))
+
         else:
             raise tvm.error.OpAttributeUnImplemented(
                 'Padding format {} is not supported for operator Conv.'.format(padding))

diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index ad1abc247f7e..b9b21067e104 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -42,6 +42,9 @@
 import tvm.relay.testing.tf as tf_testing
 from packaging import version as package_version

+from PIL import Image
+import os
+
 #######################################################################
 # Generic run functions for TVM & TFLite
 # --------------------------------------
@@ -50,6 +53,20 @@ def convert_to_list(x):
         x = [x]
     return x

+
+#######################################################################
+# Get a real image for e2e testing.
+# --------------------------------------
+def get_real_image(im_height, im_width):
+    repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
+    img_name = 'elephant-299.jpg'
+    image_url = os.path.join(repo_base, img_name)
+    img_path = download_testdata(image_url, img_name, module='data')
+    image = Image.open(img_path).resize((im_height, im_width))
+    x = np.array(image).astype('uint8')
+    data = np.reshape(x, (1, im_height, im_width, 3))
+    return data
+
 def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm',
                   out_names=None):
     """ Generic function to compile on relay and execute on tvm """
@@ -1425,16 +1442,18 @@ def test_forward_qnn_inception_v1_net():
                                        "inception_v1_224_quant.tflite")
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
-    # Checking the labels because the requantize implementation is different between TFLite and
-    # Relay. This cause final output numbers to mismatch. So, testing accuracy via labels.
-    np.random.seed(0)
-    data = np.random.random_integers(low=0, high=128, size=(1, 224, 224, 3)).astype('uint8')
+
+    # Test image. Checking the labels because the requantize implementation is different between
+    # TFLite and Relay. This causes the final output numbers to mismatch. So, testing accuracy via
+    # labels. Also, giving a real image instead of random inputs.
+    data = get_real_image(224, 224)
+
     tflite_output = run_tflite_graph(tflite_model_buf, data)
     tflite_predictions = np.squeeze(tflite_output)
-    tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
+    tflite_sorted_labels = tflite_predictions.argsort()[-5:][::-1]
     tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
     tvm_predictions = np.squeeze(tvm_output)
-    tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
+    tvm_sorted_labels = tvm_predictions.argsort()[-5:][::-1]
     tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)

 def test_forward_qnn_mobilenet_v1_net():
@@ -1445,16 +1464,18 @@ def test_forward_qnn_mobilenet_v1_net():
                                        "mobilenet_v1_1.0_224_quant.tflite")
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
-    # Checking the labels because the requantize implementation is different between TFLite and
-    # Relay. This cause final output numbers to mismatch. So, testing accuracy via labels.
-    np.random.seed(0)
-    data = np.random.random_integers(low=0, high=128, size=(1, 224, 224, 3)).astype('uint8')
+
+    # Test image. Checking the labels because the requantize implementation is different between
+    # TFLite and Relay. This causes the final output numbers to mismatch. So, testing accuracy via
+    # labels. Also, giving a real image instead of random inputs.
+    data = get_real_image(224, 224)
+
     tflite_output = run_tflite_graph(tflite_model_buf, data)
     tflite_predictions = np.squeeze(tflite_output)
-    tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
+    tflite_sorted_labels = tflite_predictions.argsort()[-5:][::-1]
     tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
     tvm_predictions = np.squeeze(tvm_output)
-    tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
+    tvm_sorted_labels = tvm_predictions.argsort()[-5:][::-1]
     tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)

 def test_forward_qnn_mobilenet_v2_net():
@@ -1465,16 +1486,18 @@ def test_forward_qnn_mobilenet_v2_net():
                                        "mobilenet_v2_1.0_224_quant.tflite")
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
-    # Checking the labels because the requantize implementation is different between TFLite and
-    # Relay. This cause final output numbers to mismatch. So, testing accuracy via labels.
-    np.random.seed(0)
-    data = np.random.random_integers(low=0, high=128, size=(1, 224, 224, 3)).astype('uint8')
+
+    # Test image. Checking the labels because the requantize implementation is different between
+    # TFLite and Relay. This causes the final output numbers to mismatch. So, testing accuracy via
+    # labels. Also, giving a real image instead of random inputs.
+    data = get_real_image(224, 224)
+
     tflite_output = run_tflite_graph(tflite_model_buf, data)
     tflite_predictions = np.squeeze(tflite_output)
-    tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
+    tflite_sorted_labels = tflite_predictions.argsort()[-5:][::-1]
     tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
     tvm_predictions = np.squeeze(tvm_output)
-    tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
+    tvm_sorted_labels = tvm_predictions.argsort()[-5:][::-1]
     tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)

 #######################################################################

From 3df4b8f3bc55b848798c992ad6214d1c69f88b7b Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Wed, 5 Feb 2020 05:13:45 +0000
Subject: [PATCH 2/5] Setting seed for SSD mobilenet for fixed input.

---
 tests/python/frontend/tflite/test_forward.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index b9b21067e104..64f9a22f5c88 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -1512,6 +1512,7 @@ def test_forward_ssd_mobilenet_v1():
                                        "ssd_mobilenet_v1_coco_2018_01_28_nopp.tflite")
     with open(tflite_model_file, "rb") as f:
         tflite_model_buf = f.read()
+    np.random.seed(0)
     data = np.random.uniform(size=(1, 300, 300, 3)).astype('float32')
     tflite_output = run_tflite_graph(tflite_model_buf, data)
     tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=2)

From bc75ae25bcfb1669ee799c6d0fb5731fb29b435d Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Wed, 5 Feb 2020 19:24:16 +0000
Subject: [PATCH 3/5] Support quantized Pad op.

---
 python/tvm/relay/frontend/tflite.py          | 15 ++++++++++++--
 tests/python/frontend/tflite/test_forward.py | 21 ++++++++++++++++----
 2 files changed, 30 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index ba2f9851a33d..2b8d3f8ef418 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -1487,8 +1487,19 @@ def convert_pad(self, op):
         # convert list of lists to tuple of tuples
         paddings = tuple(tuple(l) for l in pad_list)

-        # Use default pad_value 0 because TFLite PAD does not support constant_values parameter
-        out = _op.nn.pad(in_expr, paddings)
+        # Set the pad value
+        pad_value = 0
+        if input_tensor.qnn_params:
+            # Check that input and output tensor have same qnn params.
+            output_tensors = self.get_output_tensors(op)
+            output_tensor = output_tensors[0]
+            assert self.has_same_qnn_params(input_tensor, output_tensor), \
+                    "TFLite PAD requires input and output scale and zero points to be equal"
+
+            # The pad value for quantized pad is the input zero point.
+            pad_value = float(input_tensor.qnn_params['zero_point'].data.asnumpy())
+
+        out = _op.nn.pad(in_expr, pad_width=paddings, pad_value=pad_value)
         return out

     def convert_mirror_pad(self, op):

diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 64f9a22f5c88..f86d5f225c12 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -1156,16 +1156,27 @@ def test_forward_squeeze():
 # Pad
 # ---

-def _test_pad(data, mode="CONSTANT"):
+def _test_pad(data, mode="CONSTANT", quantized=False):
     """ One iteration of PAD """

     assert len(data) == 2

     # Test with tensor and constant
     with tf.Graph().as_default():
-        in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')]
-        out = array_ops.pad(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode)
-        compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])
+        in_data = [array_ops.placeholder(shape=data[0].shape, dtype='float32', name='in')]
+
+        if quantized:
+            min_value, max_value = -100, 100
+            # fake_quant will keep the tensors in float32 until the conversion in the session
+            inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0],
+                                                                     min=-100,
+                                                                     max=100,
+                                                                     name="inq_0")]
+            out = array_ops.pad(inq_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode)
+            compare_tflite_with_tvm([data[0]], ['inq_0:0'], inq_data, [out], quantized=True)
+        else:
+            out = array_ops.pad(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode)
+            compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])


 def test_forward_pad():
@@ -1182,6 +1193,8 @@ def test_forward_pad():
                np.array([[1, 1], [2, 2]], dtype=np.int32)], mode="REFLECT")
     _test_pad([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
                np.array([[1, 1], [2, 2]], dtype=np.int32)], mode="SYMMETRIC")
+    _test_pad([np.arange(0, 256, dtype=np.uint8).reshape((1, 256)),
+               np.array([[1, 1], [2, 2]], dtype=np.int32)], quantized=True)


 #######################################################################

From 10d40eb3b671b1a8f5a72aa45a6c8421d2563c5c Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Wed, 5 Feb 2020 19:30:35 +0000
Subject: [PATCH 4/5] Remove unnecessary line.

---
 tests/python/frontend/tflite/test_forward.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index f86d5f225c12..fe10d3593a50 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -1166,7 +1166,6 @@ def _test_pad(data, mode="CONSTANT", quantized=False):
         in_data = [array_ops.placeholder(shape=data[0].shape, dtype='float32', name='in')]

         if quantized:
-            min_value, max_value = -100, 100
             # fake_quant will keep the tensors in float32 until the conversion in the session
             inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0],
                                                                      min=-100,

From 652951242e9b1bbfab7a8dbceaf706c0f8333e28 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Sun, 9 Feb 2020 02:40:57 +0000
Subject: [PATCH 5/5] Addressing Ina's review comments.

---
 tests/python/frontend/tflite/test_forward.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index fe10d3593a50..ccb8b8740d81 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -1167,12 +1167,14 @@ def _test_pad(data, mode="CONSTANT", quantized=False):

         if quantized:
             # fake_quant will keep the tensors in float32 until the conversion in the session
+            input_range = {'inq_0': (-100, 100)}
             inq_data = [tf.quantization.fake_quant_with_min_max_args(in_data[0],
                                                                      min=-100,
                                                                      max=100,
                                                                      name="inq_0")]
             out = array_ops.pad(inq_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode)
-            compare_tflite_with_tvm([data[0]], ['inq_0:0'], inq_data, [out], quantized=True)
+            compare_tflite_with_tvm([data[0]], ['inq_0:0'], inq_data, [out], quantized=True,
+                                    input_range=input_range)
         else:
             out = array_ops.pad(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode)
             compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])
@@ -1462,10 +1464,10 @@ def test_forward_qnn_inception_v1_net():

     tflite_output = run_tflite_graph(tflite_model_buf, data)
     tflite_predictions = np.squeeze(tflite_output)
-    tflite_sorted_labels = tflite_predictions.argsort()[-5:][::-1]
+    tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
     tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
     tvm_predictions = np.squeeze(tvm_output)
-    tvm_sorted_labels = tvm_predictions.argsort()[-5:][::-1]
+    tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
     tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)

 def test_forward_qnn_mobilenet_v1_net():
@@ -1484,10 +1486,10 @@ def test_forward_qnn_mobilenet_v1_net():

     tflite_output = run_tflite_graph(tflite_model_buf, data)
     tflite_predictions = np.squeeze(tflite_output)
-    tflite_sorted_labels = tflite_predictions.argsort()[-5:][::-1]
+    tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
     tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
     tvm_predictions = np.squeeze(tvm_output)
-    tvm_sorted_labels = tvm_predictions.argsort()[-5:][::-1]
+    tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
     tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)

 def test_forward_qnn_mobilenet_v2_net():
@@ -1506,10 +1508,10 @@ def test_forward_qnn_mobilenet_v2_net():

     tflite_output = run_tflite_graph(tflite_model_buf, data)
     tflite_predictions = np.squeeze(tflite_output)
-    tflite_sorted_labels = tflite_predictions.argsort()[-5:][::-1]
+    tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
     tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
     tvm_predictions = np.squeeze(tvm_output)
-    tvm_sorted_labels = tvm_predictions.argsort()[-5:][::-1]
+    tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
     tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)

 #######################################################################
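
Editor's note for readers of the series: both the implicit convolution padding in patch 1 and the quantized PAD support in patch 3 fill the border with the input tensor's zero point rather than with a literal 0, because in affine quantization the real value 0.0 maps to the zero point. The short standalone NumPy sketch below is not part of the diffs above; the scale and zero-point values are invented purely for illustration.

import numpy as np

# Hypothetical uint8 quantization parameters (illustrative only).
scale, zero_point = 0.5, 3

def dequantize(q):
    # real_value = scale * (quantized_value - zero_point)
    return scale * (q.astype(np.float32) - zero_point)

q_input = np.array([[7, 9], [11, 13]], dtype=np.uint8)

# Padding the quantized tensor with the raw value 0 injects the real value
# scale * (0 - zero_point) = -1.5 at the borders instead of 0.0.
padded_wrong = np.pad(q_input, 1, constant_values=0)

# Padding with the zero point keeps the padded region at a real value of 0.0,
# which is what the patches arrange for quantized PAD and SAME-padded convolutions.
padded_right = np.pad(q_input, 1, constant_values=zero_point)

print(dequantize(padded_wrong))   # border entries are -1.5
print(dequantize(padded_right))   # border entries are 0.0

This mirrors the reasoning stated in the patch comments; the QNN tests then compare only the top sorted labels because the differing requantize implementations in TFLite and Relay leave small numeric differences even when the predictions agree.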