diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 41147fe9e9bf..c9fd0dc7143f 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -32,7 +32,10 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables -from tensorflow import lite as interpreter_wrapper +try: + from tensorflow import lite as interpreter_wrapper +except ImportError: + from tensorflow.contrib import lite as interpreter_wrapper import tvm.relay.testing.tf as tf_testing @@ -131,7 +134,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors, if init_global_variables: sess.run(variables.global_variables_initializer()) # convert to tflite model - converter = tf.contrib.lite.TFLiteConverter.from_session( + converter = interpreter_wrapper.TFLiteConverter.from_session( sess, input_tensors, output_tensors) tflite_model_buffer = converter.convert() tflite_output = run_tflite_graph(tflite_model_buffer, in_data)