Skip to content

Commit

Permalink
[TFLite] Model importer to be compatible with tflite 2.1.0 (#5497)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmoreau89 authored May 2, 2020
1 parent 360027d commit 8599f7c
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 24 deletions.
8 changes: 6 additions & 2 deletions golang/sample/gen_mobilenet_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import os
from tvm import relay
from tvm.contrib.download import download_testdata
import tflite.Model


################################################
Expand Down Expand Up @@ -49,7 +48,12 @@ def extract(path):

# get TFLite model from buffer
tflite_model_buf = open(model_file, "rb").read()
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
try:
import tflite
tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
except AttributeError:
import tflite.Model
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)


##############################
Expand Down
12 changes: 9 additions & 3 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2524,7 +2524,7 @@ def from_tflite(model, shape_dict, dtype_dict):
Parameters
----------
model:
tflite.Model.Model
tflite.Model or tflite.Model.Model (depending on tflite version)
shape_dict : dict of str to int list/tuple
Input shapes of the model.
Expand All @@ -2541,12 +2541,18 @@ def from_tflite(model, shape_dict, dtype_dict):
The parameter dict to be used by relay
"""
try:
import tflite.Model
import tflite.SubGraph
import tflite.BuiltinOperator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(model, tflite.Model.Model)

# TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
try:
import tflite
assert isinstance(model, tflite.Model)
except TypeError:
import tflite.Model
assert isinstance(model, tflite.Model.Model)

# keep the same as tflite
assert model.SubgraphsLength() == 1, "only support one subgraph (main subgraph)"
Expand Down
8 changes: 5 additions & 3 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,16 @@ def get_real_image(im_height, im_width):
def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm',
out_names=None):
""" Generic function to compile on relay and execute on tvm """
# TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
try:
import tflite.Model
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
except AttributeError:
import tflite
tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
except ImportError:
raise ImportError("The tflite package must be installed")

# get TFLite model from buffer
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)

input_data = convert_to_list(input_data)
input_node = convert_to_list(input_node)

Expand Down
19 changes: 3 additions & 16 deletions tutorials/frontend/from_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,12 @@
This article is an introductory tutorial to deploy TFLite models with Relay.
To get started, Flatbuffers and TFLite package needs to be installed as prerequisites.
A quick solution is to install Flatbuffers via pip
To get started, TFLite package needs to be installed as prerequisite.
.. code-block:: bash
pip install flatbuffers --user
To install TFlite packages, you could use our prebuilt wheel:
.. code-block:: bash
# For python3:
wget https://github.com/FrozenGene/tflite/releases/download/v1.13.1/tflite-1.13.1-py3-none-any.whl
pip3 install -U tflite-1.13.1-py3-none-any.whl --user
# For python2:
wget https://github.com/FrozenGene/tflite/releases/download/v1.13.1/tflite-1.13.1-py2-none-any.whl
pip install -U tflite-1.13.1-py2-none-any.whl --user
# install tflite
pip install tflite=2.1.0 --user
or you could generate TFLite package yourself. The steps are the following:
Expand Down

0 comments on commit 8599f7c

Please sign in to comment.