forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Doc] TFLite frontend tutorial (apache#2508)
* TFLite frontend tutorial * Modify as suggestion
- Loading branch information
1 parent
db5e100
commit ab44c00
Showing
1 changed file
with
197 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
""" | ||
Compile TFLite Models | ||
=================== | ||
**Author**: `Zhao Wu <https://github.com/FrozenGene>`_ | ||
This article is an introductory tutorial to deploy TFLite models with Relay. | ||
To get started, Flatbuffers and TFLite package needs to be installed as prerequisites. | ||
A quick solution is to install Flatbuffers via pip | ||
.. code-block:: bash | ||
pip install flatbuffers --user | ||
To install TFlite packages, you could use our prebuilt wheel: | ||
.. code-block:: bash | ||
# For python3: | ||
wget https://github.com/dmlc/web-data/tree/master/tensorflow/tflite/whl/tflite-0.0.1-py3-none-any.whl | ||
pip install tflite-0.0.1-py3-none-any.whl --user | ||
# For python2: | ||
wget https://github.com/dmlc/web-data/tree/master/tensorflow/tflite/whl/tflite-0.0.1-py2-none-any.whl | ||
pip install tflite-0.0.1-py2-none-any.whl --user | ||
or you could generate TFLite package by yourself. The steps are as following: | ||
.. code-block:: bash | ||
# Get the flatc compiler. | ||
# Please refer to https://github.com/google/flatbuffers for details | ||
# and make sure it is properly installed. | ||
flatc --version | ||
# Get the TFLite schema. | ||
wget https://raw.githubusercontent.com/tensorflow/tensorflow/r1.12/tensorflow/contrib/lite/schema/schema.fbs | ||
# Generate TFLite package. | ||
flatc --python schema.fbs | ||
# Add it to PYTHONPATH. | ||
export PYTHONPATH=/path/to/tflite | ||
Now please check if TFLite package is installed successfully, ``python -c "import tflite"`` | ||
Below you can find an example for how to compile TFLite model using TVM. | ||
""" | ||
###################################################################### | ||
# Utils for downloading and extracting zip files | ||
# --------------------------------------------- | ||
|
||
def download(url, path, overwrite=False): | ||
import os | ||
if os.path.isfile(path) and not overwrite: | ||
print('File {} existed, skip.'.format(path)) | ||
return | ||
print('Downloading from url {} to {}'.format(url, path)) | ||
try: | ||
import urllib.request | ||
urllib.request.urlretrieve(url, path) | ||
except: | ||
import urllib | ||
urllib.urlretrieve(url, path) | ||
|
||
def extract(path): | ||
import tarfile | ||
if path.endswith("tgz") or path.endswith("gz"): | ||
tar = tarfile.open(path) | ||
tar.extractall() | ||
tar.close() | ||
else: | ||
raise RuntimeError('Could not decompress the file: ' + path) | ||
|
||
|
||
###################################################################### | ||
# Load pretrained TFLite model | ||
# --------------------------------------------- | ||
# we load mobilenet V1 TFLite model provided by Google | ||
model_url = "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz" | ||
|
||
# we download model tar file and extract, finally get mobilenet_v1_1.0_224.tflite | ||
download(model_url, "mobilenet_v1_1.0_224.tgz", False) | ||
extract("mobilenet_v1_1.0_224.tgz") | ||
|
||
# now we have mobilenet_v1_1.0_224.tflite on disk and open it | ||
tflite_model_file = "mobilenet_v1_1.0_224.tflite" | ||
tflite_model_buf = open(tflite_model_file, "rb").read() | ||
|
||
# get TFLite model from buffer | ||
import tflite.Model | ||
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) | ||
|
||
###################################################################### | ||
# Load a test image | ||
# --------------------------------------------- | ||
# A single cat dominates the examples! | ||
from PIL import Image | ||
from matplotlib import pyplot as plt | ||
import numpy as np | ||
|
||
image_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true' | ||
download(image_url, 'cat.png') | ||
resized_image = Image.open('cat.png').resize((224, 224)) | ||
plt.imshow(resized_image) | ||
plt.show() | ||
image_data = np.asarray(resized_image).astype("float32") | ||
|
||
# convert HWC to CHW | ||
image_data = image_data.transpose((2, 0, 1)) | ||
|
||
# after expand_dims, we have format NCHW | ||
image_data = np.expand_dims(image_data, axis=0) | ||
|
||
# preprocess image as described here: | ||
# https://github.com/tensorflow/models/blob/edb6ed22a801665946c63d650ab9a0b23d98e1b1/research/slim/preprocessing/inception_preprocessing.py#L243 | ||
image_data[:, 0, :, :] = 2.0 / 255.0 * image_data[:, 0, :, :] - 1 | ||
image_data[:, 1, :, :] = 2.0 / 255.0 * image_data[:, 1, :, :] - 1 | ||
image_data[:, 2, :, :] = 2.0 / 255.0 * image_data[:, 2, :, :] - 1 | ||
print('input', image_data.shape) | ||
|
||
#################################################################### | ||
# | ||
# .. note:: Input layout: | ||
# | ||
# Currently, TVM TFLite frontend accepts ``NCHW`` as input layout. | ||
|
||
###################################################################### | ||
# Compile the model with relay | ||
# --------------------------------------------- | ||
|
||
# TFLite input tensor name, shape and type | ||
input_tensor = "input" | ||
input_shape = (1, 3, 224, 224) | ||
input_dtype = "float32" | ||
|
||
# parse TFLite model and convert into Relay computation graph | ||
from tvm import relay | ||
func, params = relay.frontend.from_tflite(tflite_model, | ||
shape_dict={input_tensor: input_shape}, | ||
dtype_dict={input_tensor: input_dtype}) | ||
|
||
# targt x86 cpu | ||
target = "llvm" | ||
with relay.build_module.build_config(opt_level=3): | ||
graph, lib, params = relay.build(func, target, params=params) | ||
|
||
###################################################################### | ||
# Execute on TVM | ||
# --------------------------------------------- | ||
import tvm | ||
from tvm.contrib import graph_runtime as runtime | ||
|
||
# create a runtime executor module | ||
module = runtime.create(graph, lib, tvm.cpu()) | ||
|
||
# feed input data | ||
module.set_input(input_tensor, tvm.nd.array(image_data)) | ||
|
||
# feed related params | ||
module.set_input(**params) | ||
|
||
# run | ||
module.run() | ||
|
||
# get output | ||
tvm_output = module.get_output(0).asnumpy() | ||
|
||
###################################################################### | ||
# Display results | ||
# --------------------------------------------- | ||
|
||
# load label file | ||
label_file_url = ''.join(['https://raw.githubusercontent.com/', | ||
'tensorflow/tensorflow/master/tensorflow/lite/java/demo/', | ||
'app/src/main/assets/', | ||
'labels_mobilenet_quant_v1_224.txt']) | ||
label_file = "labels_mobilenet_quant_v1_224.txt" | ||
download(label_file_url, label_file) | ||
|
||
# map id to 1001 classes | ||
labels = dict() | ||
with open(label_file) as f: | ||
for id, line in enumerate(f): | ||
labels[id] = line | ||
|
||
# convert result to 1D data | ||
predictions = np.squeeze(tvm_output) | ||
|
||
# get top 1 prediction | ||
prediction = np.argmax(predictions) | ||
|
||
# convert id to class name and show the result | ||
print("The image prediction result is: id " + str(prediction) + " name: " + labels[prediction]) |