From ab44c004b80af7326fb349fa050fd22bbd717162 Mon Sep 17 00:00:00 2001 From: Wu Zhao Date: Tue, 29 Jan 2019 13:30:33 +0800 Subject: [PATCH] [Doc] TFLite frontend tutorial (#2508) * TFLite frontend tutorial * Modify as suggestion --- tutorials/frontend/from_tflite.py | 197 ++++++++++++++++++++++++++++++ 1 file changed, 197 insertions(+) create mode 100644 tutorials/frontend/from_tflite.py diff --git a/tutorials/frontend/from_tflite.py b/tutorials/frontend/from_tflite.py new file mode 100644 index 000000000000..96068e8a2188 --- /dev/null +++ b/tutorials/frontend/from_tflite.py @@ -0,0 +1,197 @@ +""" +Compile TFLite Models +=================== +**Author**: `Zhao Wu `_ + +This article is an introductory tutorial to deploy TFLite models with Relay. + +To get started, Flatbuffers and TFLite package needs to be installed as prerequisites. + +A quick solution is to install Flatbuffers via pip + +.. code-block:: bash + + pip install flatbuffers --user + +To install TFlite packages, you could use our prebuilt wheel: + +.. code-block:: bash + + # For python3: + wget https://github.com/dmlc/web-data/tree/master/tensorflow/tflite/whl/tflite-0.0.1-py3-none-any.whl + pip install tflite-0.0.1-py3-none-any.whl --user + + # For python2: + wget https://github.com/dmlc/web-data/tree/master/tensorflow/tflite/whl/tflite-0.0.1-py2-none-any.whl + pip install tflite-0.0.1-py2-none-any.whl --user + + +or you could generate TFLite package by yourself. The steps are as following: + +.. code-block:: bash + + # Get the flatc compiler. + # Please refer to https://github.com/google/flatbuffers for details + # and make sure it is properly installed. + flatc --version + + # Get the TFLite schema. + wget https://raw.githubusercontent.com/tensorflow/tensorflow/r1.12/tensorflow/contrib/lite/schema/schema.fbs + + # Generate TFLite package. + flatc --python schema.fbs + + # Add it to PYTHONPATH. + export PYTHONPATH=/path/to/tflite + + +Now please check if TFLite package is installed successfully, ``python -c "import tflite"`` + +Below you can find an example for how to compile TFLite model using TVM. +""" +###################################################################### +# Utils for downloading and extracting zip files +# --------------------------------------------- + +def download(url, path, overwrite=False): + import os + if os.path.isfile(path) and not overwrite: + print('File {} existed, skip.'.format(path)) + return + print('Downloading from url {} to {}'.format(url, path)) + try: + import urllib.request + urllib.request.urlretrieve(url, path) + except: + import urllib + urllib.urlretrieve(url, path) + +def extract(path): + import tarfile + if path.endswith("tgz") or path.endswith("gz"): + tar = tarfile.open(path) + tar.extractall() + tar.close() + else: + raise RuntimeError('Could not decompress the file: ' + path) + + +###################################################################### +# Load pretrained TFLite model +# --------------------------------------------- +# we load mobilenet V1 TFLite model provided by Google +model_url = "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz" + +# we download model tar file and extract, finally get mobilenet_v1_1.0_224.tflite +download(model_url, "mobilenet_v1_1.0_224.tgz", False) +extract("mobilenet_v1_1.0_224.tgz") + +# now we have mobilenet_v1_1.0_224.tflite on disk and open it +tflite_model_file = "mobilenet_v1_1.0_224.tflite" +tflite_model_buf = open(tflite_model_file, "rb").read() + +# get TFLite model from buffer +import tflite.Model +tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) + +###################################################################### +# Load a test image +# --------------------------------------------- +# A single cat dominates the examples! +from PIL import Image +from matplotlib import pyplot as plt +import numpy as np + +image_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true' +download(image_url, 'cat.png') +resized_image = Image.open('cat.png').resize((224, 224)) +plt.imshow(resized_image) +plt.show() +image_data = np.asarray(resized_image).astype("float32") + +# convert HWC to CHW +image_data = image_data.transpose((2, 0, 1)) + +# after expand_dims, we have format NCHW +image_data = np.expand_dims(image_data, axis=0) + +# preprocess image as described here: +# https://github.com/tensorflow/models/blob/edb6ed22a801665946c63d650ab9a0b23d98e1b1/research/slim/preprocessing/inception_preprocessing.py#L243 +image_data[:, 0, :, :] = 2.0 / 255.0 * image_data[:, 0, :, :] - 1 +image_data[:, 1, :, :] = 2.0 / 255.0 * image_data[:, 1, :, :] - 1 +image_data[:, 2, :, :] = 2.0 / 255.0 * image_data[:, 2, :, :] - 1 +print('input', image_data.shape) + +#################################################################### +# +# .. note:: Input layout: +# +# Currently, TVM TFLite frontend accepts ``NCHW`` as input layout. + +###################################################################### +# Compile the model with relay +# --------------------------------------------- + +# TFLite input tensor name, shape and type +input_tensor = "input" +input_shape = (1, 3, 224, 224) +input_dtype = "float32" + +# parse TFLite model and convert into Relay computation graph +from tvm import relay +func, params = relay.frontend.from_tflite(tflite_model, + shape_dict={input_tensor: input_shape}, + dtype_dict={input_tensor: input_dtype}) + +# targt x86 cpu +target = "llvm" +with relay.build_module.build_config(opt_level=3): + graph, lib, params = relay.build(func, target, params=params) + +###################################################################### +# Execute on TVM +# --------------------------------------------- +import tvm +from tvm.contrib import graph_runtime as runtime + +# create a runtime executor module +module = runtime.create(graph, lib, tvm.cpu()) + +# feed input data +module.set_input(input_tensor, tvm.nd.array(image_data)) + +# feed related params +module.set_input(**params) + +# run +module.run() + +# get output +tvm_output = module.get_output(0).asnumpy() + +###################################################################### +# Display results +# --------------------------------------------- + +# load label file +label_file_url = ''.join(['https://raw.githubusercontent.com/', + 'tensorflow/tensorflow/master/tensorflow/lite/java/demo/', + 'app/src/main/assets/', + 'labels_mobilenet_quant_v1_224.txt']) +label_file = "labels_mobilenet_quant_v1_224.txt" +download(label_file_url, label_file) + +# map id to 1001 classes +labels = dict() +with open(label_file) as f: + for id, line in enumerate(f): + labels[id] = line + +# convert result to 1D data +predictions = np.squeeze(tvm_output) + +# get top 1 prediction +prediction = np.argmax(predictions) + +# convert id to class name and show the result +print("The image prediction result is: id " + str(prediction) + " name: " + labels[prediction])