From e97c01012d3a04c1f2a39abbc810c34395c74080 Mon Sep 17 00:00:00 2001 From: Sammy Date: Mon, 24 Jun 2019 23:55:55 -0400 Subject: [PATCH] Fixing package path in tflite test (#3427) --- tests/python/frontend/tflite/test_forward.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 41147fe9e9bf..c9fd0dc7143f 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -32,7 +32,10 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables -from tensorflow import lite as interpreter_wrapper +try: + from tensorflow import lite as interpreter_wrapper +except ImportError: + from tensorflow.contrib import lite as interpreter_wrapper import tvm.relay.testing.tf as tf_testing @@ -131,7 +134,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors, if init_global_variables: sess.run(variables.global_variables_initializer()) # convert to tflite model - converter = tf.contrib.lite.TFLiteConverter.from_session( + converter = interpreter_wrapper.TFLiteConverter.from_session( sess, input_tensors, output_tensors) tflite_model_buffer = converter.convert() tflite_output = run_tflite_graph(tflite_model_buffer, in_data)