diff --git a/python/asr/nn.py b/python/asr/nn.py index 03aac568c5164..5131c9bfe23a5 100644 --- a/python/asr/nn.py +++ b/python/asr/nn.py @@ -75,10 +75,16 @@ def _load_frontend_model(path, input_tensor): input_dtype = 'float32' if info.frontend == 'tflite': - import tflite.Model + import tflite tflite_model_buf = open(path, 'rb').read() - tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) + # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 + try: + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) + except AttributeError: + tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) + except ImportError: + raise ImportError("The tflite package must be installed") mod, param = relay.frontend.from_tflite(tflite_model, shape_dict={input_tensor: input_shape},