diff --git a/python/chronos/dev/test/run-pytests-tf1.sh b/python/chronos/dev/test/run-pytests-tf1.sh index 72bc242f7432..55fd5fadb56d 100755 --- a/python/chronos/dev/test/run-pytests-tf1.sh +++ b/python/chronos/dev/test/run-pytests-tf1.sh @@ -28,7 +28,7 @@ fi ray stop -f echo "Running chronos tests TF1 and Deprecated API" -python -m pytest -v -m "skipif and tf1" test/bigdl/chronos/model/ +python -m pytest -v -m "skipif and tf1" test/bigdl/chronos/model/tf1/ exit_status_0=$? if [ $exit_status_0 -ne 0 ]; diff --git a/python/chronos/test/bigdl/chronos/model/tf1/__init__.py b/python/chronos/test/bigdl/chronos/model/tf1/__init__.py index 2151a805423a..c01e6f73e9b7 100644 --- a/python/chronos/test/bigdl/chronos/model/tf1/__init__.py +++ b/python/chronos/test/bigdl/chronos/model/tf1/__init__.py @@ -13,3 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import pytest +import tensorflow as tf +skip_tf = pytest.mark.skipif(tf.__version__ > '2.0.0', reason="Run only when tf==1.15.0.") \ No newline at end of file diff --git a/python/chronos/test/bigdl/chronos/model/tf1/test_Seq2Seq_keras.py b/python/chronos/test/bigdl/chronos/model/tf1/test_Seq2Seq_keras.py index 3f4450260b29..d61c56ae34e7 100644 --- a/python/chronos/test/bigdl/chronos/model/tf1/test_Seq2Seq_keras.py +++ b/python/chronos/test/bigdl/chronos/model/tf1/test_Seq2Seq_keras.py @@ -23,9 +23,10 @@ from numpy.testing import assert_array_almost_equal import pandas as pd import numpy as np +from . import skip_tf -@pytest.mark.skipif(tf.__version__ > '2.0.0', reason="Run only when tf==1.15.0.") +@skip_tf @pytest.mark.tf1 class TestSeq2Seq(ZooTestCase): diff --git a/python/chronos/test/bigdl/chronos/model/tf1/test_VanillaLSTM_keras.py b/python/chronos/test/bigdl/chronos/model/tf1/test_VanillaLSTM_keras.py index d38a21c4af1c..24c11d46cd68 100644 --- a/python/chronos/test/bigdl/chronos/model/tf1/test_VanillaLSTM_keras.py +++ b/python/chronos/test/bigdl/chronos/model/tf1/test_VanillaLSTM_keras.py @@ -21,6 +21,7 @@ import tempfile import os import tensorflow as tf +from . import skip_tf def create_data(): @@ -41,7 +42,8 @@ def get_x_y(num_samples): test_data = get_x_y(num_test_samples) return train_data, val_data, test_data -@pytest.mark.skipif(tf.__version__ > '2.0.0', reason="Run only when tf==1.15.0.") + +@skip_tf @pytest.mark.tf1 class TestVanillaLSTM(TestCase): train_data, val_data, test_data = create_data() diff --git a/python/chronos/test/bigdl/chronos/model/tf1/test_mtnet_keras.py b/python/chronos/test/bigdl/chronos/model/tf1/test_mtnet_keras.py index 7c9c9240a6de..b39936dcfee8 100644 --- a/python/chronos/test/bigdl/chronos/model/tf1/test_mtnet_keras.py +++ b/python/chronos/test/bigdl/chronos/model/tf1/test_mtnet_keras.py @@ -24,6 +24,7 @@ import numpy as np import tensorflow as tf from numpy.testing import assert_array_almost_equal +from . import skip_tf def create_data(): @@ -51,8 +52,7 @@ def get_data(num_samples): tsdata.roll(lookback=lookback, horizon=horizon) return tsdata_train, tsdata_test - -@pytest.mark.skipif(tf.__version__ > '2.0.0', reason="Run only when tf==1.15.0.") +@skip_tf @pytest.mark.tf1 class TestMTNetKeras(ZooTestCase):