Skip to content

Commit

Permalink
Chronos: fix ZooTestCase in tensorflow unit tests (#5800)
Browse files Browse the repository at this point in the history
  • Loading branch information
plusbang authored Sep 19, 2022
1 parent 0e31b5a commit de9fa8c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import pytest

from bigdl.orca.test_zoo_utils import ZooTestCase
from unittest import TestCase
from bigdl.chronos.model.tf2.Seq2Seq_keras import LSTMSeq2Seq, model_creator
import numpy as np

Expand All @@ -42,7 +42,7 @@ def get_x_y(num_samples):


@pytest.mark.skipif(tf.__version__ < '2.0.0', reason="Run only when tf>2.0.0.")
class TestSeq2Seq(ZooTestCase):
class TestSeq2Seq(TestCase):

train_data, test_data = create_data()
model = model_creator(config={
Expand Down
4 changes: 2 additions & 2 deletions python/chronos/test/bigdl/chronos/model/tf2/test_tcn_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pytest
import tensorflow as tf

from bigdl.orca.test_zoo_utils import ZooTestCase
from unittest import TestCase
from bigdl.chronos.model.tf2.TCN_keras import model_creator, TemporalConvNet, TemporalBlock


Expand All @@ -40,7 +40,7 @@ def get_x_y(num_samples):
return train_data, test_data

@pytest.mark.skipif(tf.__version__ < '2.0.0', reason="Run only when tf>2.0.0.")
class TestTcnKeras(ZooTestCase):
class TestTcnKeras(TestCase):

train_data, test_data = create_data()
model = model_creator(config={
Expand Down

0 comments on commit de9fa8c

Please sign in to comment.