From 13acf78daab7609f197cdde35b7dff92c08e34ae Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Mon, 8 Feb 2021 19:45:52 +0530 Subject: [PATCH] Replace return true with pytest.skip for Windows & HF (#85) --- tests/text/classification/test_data.py | 14 +++++--------- tests/text/classification/test_model.py | 6 ++---- tests/text/summarization/test_data.py | 14 +++++--------- tests/text/summarization/test_model.py | 5 ++--- tests/text/test_data_model_integration.py | 5 ++--- tests/text/translation/test_data.py | 14 +++++--------- tests/text/translation/test_model.py | 5 ++--- 7 files changed, 23 insertions(+), 40 deletions(-) diff --git a/tests/text/classification/test_data.py b/tests/text/classification/test_data.py index 8e5eff04e8..fdd28be60b 100644 --- a/tests/text/classification/test_data.py +++ b/tests/text/classification/test_data.py @@ -14,6 +14,8 @@ import os from pathlib import Path +import pytest + from flash.text import TextClassificationData TEST_BACKBONE = "prajjwal1/bert-tiny" # super small model for testing @@ -43,10 +45,8 @@ def json_data(tmpdir): return path +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_csv(tmpdir): - if os.name == "nt": - # TODO: huggingface stuff timing out on windows - return True csv_path = csv_data(tmpdir) dm = TextClassificationData.from_files( backbone=TEST_BACKBONE, train_file=csv_path, input="sentence", target="label", batch_size=1 @@ -56,10 +56,8 @@ def test_from_csv(tmpdir): assert "input_ids" in batch +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_test_valid(tmpdir): - if os.name == "nt": - # TODO: huggingface stuff timing out on windows - return True csv_path = csv_data(tmpdir) dm = TextClassificationData.from_files( backbone=TEST_BACKBONE, @@ -79,10 +77,8 @@ def test_test_valid(tmpdir): assert "input_ids" in batch +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_json(tmpdir): - if os.name == "nt": - # TODO: huggingface stuff timing out on windows - return True json_path = json_data(tmpdir) dm = TextClassificationData.from_files( backbone=TEST_BACKBONE, train_file=json_path, input="sentence", target="lab", filetype="json", batch_size=1 diff --git a/tests/text/classification/test_model.py b/tests/text/classification/test_model.py index f2dfab2828..411a8051f2 100644 --- a/tests/text/classification/test_model.py +++ b/tests/text/classification/test_model.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import pytest import torch from flash import Trainer @@ -38,11 +39,8 @@ def __len__(self): TEST_BACKBONE = "prajjwal1/bert-tiny" # super small model for testing +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_init_train(tmpdir): - if os.name == "nt": - # TODO: huggingface stuff timing out on windows - # - return True model = TextClassifier(2, TEST_BACKBONE) train_dl = torch.utils.data.DataLoader(DummyDataset()) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) diff --git a/tests/text/summarization/test_data.py b/tests/text/summarization/test_data.py index 15c856c72d..4399068085 100644 --- a/tests/text/summarization/test_data.py +++ b/tests/text/summarization/test_data.py @@ -14,6 +14,8 @@ import os from pathlib import Path +import pytest + from flash.text import SummarizationData TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing @@ -43,10 +45,8 @@ def json_data(tmpdir): return path +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_csv(tmpdir): - if os.name == "nt": - # TODO: huggingface stuff timing out on windows - return True csv_path = csv_data(tmpdir) dm = SummarizationData.from_files( backbone=TEST_BACKBONE, train_file=csv_path, input="input", target="target", batch_size=1 @@ -56,10 +56,8 @@ def test_from_csv(tmpdir): assert "input_ids" in batch +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_files(tmpdir): - if os.name == "nt": - # TODO: huggingface stuff timing out on windows - return True csv_path = csv_data(tmpdir) dm = SummarizationData.from_files( backbone=TEST_BACKBONE, @@ -79,10 +77,8 @@ def test_from_files(tmpdir): assert "input_ids" in batch +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_json(tmpdir): - if os.name == "nt": - # TODO: huggingface stuff timing out on windows - return True json_path = json_data(tmpdir) dm = SummarizationData.from_files( backbone=TEST_BACKBONE, train_file=json_path, input="input", target="target", filetype="json", batch_size=1 diff --git a/tests/text/summarization/test_model.py b/tests/text/summarization/test_model.py index 4c28d1b5da..283847a54e 100644 --- a/tests/text/summarization/test_model.py +++ b/tests/text/summarization/test_model.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import pytest import torch from flash import Trainer @@ -38,10 +39,8 @@ def __len__(self): TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_init_train(tmpdir): - if os.name == "nt": - # TODO: huggingface stuff timing out on windows - return True model = SummarizationTask(TEST_BACKBONE) train_dl = torch.utils.data.DataLoader(DummyDataset()) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) diff --git a/tests/text/test_data_model_integration.py b/tests/text/test_data_model_integration.py index f5fd6682d1..7aeadba7de 100644 --- a/tests/text/test_data_model_integration.py +++ b/tests/text/test_data_model_integration.py @@ -14,6 +14,7 @@ import os from pathlib import Path +import pytest from pytorch_lightning import Trainer from flash.text import TextClassificationData, TextClassifier @@ -33,10 +34,8 @@ def csv_data(tmpdir): return path +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_classification(tmpdir): - if os.name == "nt": - # TODO: huggingface stuff timing out on windows - return True csv_path = csv_data(tmpdir) diff --git a/tests/text/translation/test_data.py b/tests/text/translation/test_data.py index 9bc20265d4..6ac5eba425 100644 --- a/tests/text/translation/test_data.py +++ b/tests/text/translation/test_data.py @@ -14,6 +14,8 @@ import os from pathlib import Path +import pytest + from flash.text import TranslationData TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing @@ -43,10 +45,8 @@ def json_data(tmpdir): return path +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_csv(tmpdir): - if os.name == "nt": - # TODO: huggingface stuff timing out on windows - return True csv_path = csv_data(tmpdir) dm = TranslationData.from_files( backbone=TEST_BACKBONE, train_file=csv_path, input="input", target="target", batch_size=1 @@ -56,10 +56,8 @@ def test_from_csv(tmpdir): assert "input_ids" in batch +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_files(tmpdir): - if os.name == "nt": - # TODO: huggingface stuff timing out on windows - return True csv_path = csv_data(tmpdir) dm = TranslationData.from_files( backbone=TEST_BACKBONE, @@ -79,10 +77,8 @@ def test_from_files(tmpdir): assert "input_ids" in batch +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_from_json(tmpdir): - if os.name == "nt": - # TODO: huggingface stuff timing out on windows - return True json_path = json_data(tmpdir) dm = TranslationData.from_files( backbone=TEST_BACKBONE, train_file=json_path, input="input", target="target", filetype="json", batch_size=1 diff --git a/tests/text/translation/test_model.py b/tests/text/translation/test_model.py index 3577920b66..374d01c307 100644 --- a/tests/text/translation/test_model.py +++ b/tests/text/translation/test_model.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import pytest import torch from flash import Trainer @@ -38,10 +39,8 @@ def __len__(self): TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") def test_init_train(tmpdir): - if os.name == "nt": - # TODO: huggingface stuff timing out on windows - return True model = TranslationTask(TEST_BACKBONE) train_dl = torch.utils.data.DataLoader(DummyDataset()) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)