Skip to content

Commit

Permalink
reset pytest path
Browse files Browse the repository at this point in the history
  • Loading branch information
liangs6212 committed Jul 25, 2022
1 parent 6349247 commit 3313dd1
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion python/chronos/dev/test/run-pytests-tf1.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ fi
ray stop -f

echo "Running chronos tests TF1 and Deprecated API"
python -m pytest -v -m "skipif and tf1" test/bigdl/chronos/model/
python -m pytest -v -m "skipif and tf1" test/bigdl/chronos/model/tf1/

exit_status_0=$?
if [ $exit_status_0 -ne 0 ];
Expand Down
3 changes: 3 additions & 0 deletions python/chronos/test/bigdl/chronos/model/tf1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
import tensorflow as tf
skip_tf = pytest.mark.skipif(tf.__version__ > '2.0.0', reason="Run only when tf==1.15.0.")
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
from numpy.testing import assert_array_almost_equal
import pandas as pd
import numpy as np
from . import skip_tf


@pytest.mark.skipif(tf.__version__ > '2.0.0', reason="Run only when tf==1.15.0.")
@skip_tf
@pytest.mark.tf1
class TestSeq2Seq(ZooTestCase):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import tempfile
import os
import tensorflow as tf
from . import skip_tf


def create_data():
Expand All @@ -41,7 +42,8 @@ def get_x_y(num_samples):
test_data = get_x_y(num_test_samples)
return train_data, val_data, test_data

@pytest.mark.skipif(tf.__version__ > '2.0.0', reason="Run only when tf==1.15.0.")

@skip_tf
@pytest.mark.tf1
class TestVanillaLSTM(TestCase):
train_data, val_data, test_data = create_data()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import numpy as np
import tensorflow as tf
from numpy.testing import assert_array_almost_equal
from . import skip_tf


def create_data():
Expand Down Expand Up @@ -51,8 +52,7 @@ def get_data(num_samples):
tsdata.roll(lookback=lookback, horizon=horizon)
return tsdata_train, tsdata_test


@pytest.mark.skipif(tf.__version__ > '2.0.0', reason="Run only when tf==1.15.0.")
@skip_tf
@pytest.mark.tf1
class TestMTNetKeras(ZooTestCase):

Expand Down

0 comments on commit 3313dd1

Please sign in to comment.