diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 3c4a6392424da..c20e96cbdd97c 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -73,6 +73,16 @@ def get_real_image(im_height, im_width): data = np.reshape(x, (1, im_height, im_width, 3)) return data +def get_real_image_object_detection(im_height, im_width): + repo_base = 'https://github.com/dmlc/web-data/raw/master/gluoncv/detection/' + img_name = 'street_small.jpg' + image_url = os.path.join(repo_base, img_name) + img_path = download_testdata(image_url, img_name, module='data') + image = Image.open(img_path).resize((im_height, im_width)) + x = np.array(image).astype('uint8') + data = np.reshape(x, (1, im_height, im_width, 3)) + return data + def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm', out_names=None): """ Generic function to compile on relay and execute on tvm """