diff --git a/tests/python/frontend/tensorflow/test_bn_dynamic.py b/tests/python/frontend/tensorflow/test_bn_dynamic.py index 4be838e331ef..a2d69034a94a 100644 --- a/tests/python/frontend/tensorflow/test_bn_dynamic.py +++ b/tests/python/frontend/tensorflow/test_bn_dynamic.py @@ -22,7 +22,10 @@ """ import tvm import numpy as np -import tensorflow as tf +try: + import tensorflow.compat.v1 as tf +except ImportError: + import tensorflow as tf from tvm import relay from tensorflow.python.framework import graph_util diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index bc884bbbfa9b..93501f134d59 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1901,7 +1901,9 @@ def _get_tensorflow_output(): def test_forward_lstm(): '''test LSTM block cell''' - _test_lstm_cell(1, 2, 1, 0.5, 'float32') + if package_version.parse(tf.VERSION) < package_version.parse('2.0.0'): + #in 2.0, tf.contrib.rnn.LSTMBlockCell is removed + _test_lstm_cell(1, 2, 1, 0.5, 'float32') ####################################################################### @@ -3308,9 +3310,7 @@ def test_forward_isfinite(): test_forward_ptb() # RNN - if package_version.parse(tf.VERSION) < package_version.parse('2.0.0'): - #in 2.0, tf.contrib.rnn.LSTMBlockCell is removed - test_forward_lstm() + test_forward_lstm() # Elementwise test_forward_ceil()