Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge branch 'master' into vision/fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Feb 8, 2021
2 parents 5f2936b + 13acf78 commit 1030303
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 40 deletions.
14 changes: 5 additions & 9 deletions tests/text/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import os
from pathlib import Path

import pytest

from flash.text import TextClassificationData

TEST_BACKBONE = "prajjwal1/bert-tiny" # super small model for testing
Expand Down Expand Up @@ -43,10 +45,8 @@ def json_data(tmpdir):
return path


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
def test_from_csv(tmpdir):
if os.name == "nt":
# TODO: huggingface stuff timing out on windows
return True
csv_path = csv_data(tmpdir)
dm = TextClassificationData.from_files(
backbone=TEST_BACKBONE, train_file=csv_path, input="sentence", target="label", batch_size=1
Expand All @@ -56,10 +56,8 @@ def test_from_csv(tmpdir):
assert "input_ids" in batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
def test_test_valid(tmpdir):
if os.name == "nt":
# TODO: huggingface stuff timing out on windows
return True
csv_path = csv_data(tmpdir)
dm = TextClassificationData.from_files(
backbone=TEST_BACKBONE,
Expand All @@ -79,10 +77,8 @@ def test_test_valid(tmpdir):
assert "input_ids" in batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
def test_from_json(tmpdir):
if os.name == "nt":
# TODO: huggingface stuff timing out on windows
return True
json_path = json_data(tmpdir)
dm = TextClassificationData.from_files(
backbone=TEST_BACKBONE, train_file=json_path, input="sentence", target="lab", filetype="json", batch_size=1
Expand Down
6 changes: 2 additions & 4 deletions tests/text/classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import os

import pytest
import torch

from flash import Trainer
Expand All @@ -38,11 +39,8 @@ def __len__(self):
TEST_BACKBONE = "prajjwal1/bert-tiny" # super small model for testing


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
def test_init_train(tmpdir):
if os.name == "nt":
# TODO: huggingface stuff timing out on windows
#
return True
model = TextClassifier(2, TEST_BACKBONE)
train_dl = torch.utils.data.DataLoader(DummyDataset())
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
Expand Down
14 changes: 5 additions & 9 deletions tests/text/summarization/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import os
from pathlib import Path

import pytest

from flash.text import SummarizationData

TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing
Expand Down Expand Up @@ -43,10 +45,8 @@ def json_data(tmpdir):
return path


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
def test_from_csv(tmpdir):
if os.name == "nt":
# TODO: huggingface stuff timing out on windows
return True
csv_path = csv_data(tmpdir)
dm = SummarizationData.from_files(
backbone=TEST_BACKBONE, train_file=csv_path, input="input", target="target", batch_size=1
Expand All @@ -56,10 +56,8 @@ def test_from_csv(tmpdir):
assert "input_ids" in batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
def test_from_files(tmpdir):
if os.name == "nt":
# TODO: huggingface stuff timing out on windows
return True
csv_path = csv_data(tmpdir)
dm = SummarizationData.from_files(
backbone=TEST_BACKBONE,
Expand All @@ -79,10 +77,8 @@ def test_from_files(tmpdir):
assert "input_ids" in batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
def test_from_json(tmpdir):
if os.name == "nt":
# TODO: huggingface stuff timing out on windows
return True
json_path = json_data(tmpdir)
dm = SummarizationData.from_files(
backbone=TEST_BACKBONE, train_file=json_path, input="input", target="target", filetype="json", batch_size=1
Expand Down
5 changes: 2 additions & 3 deletions tests/text/summarization/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import os

import pytest
import torch

from flash import Trainer
Expand All @@ -38,10 +39,8 @@ def __len__(self):
TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
def test_init_train(tmpdir):
if os.name == "nt":
# TODO: huggingface stuff timing out on windows
return True
model = SummarizationTask(TEST_BACKBONE)
train_dl = torch.utils.data.DataLoader(DummyDataset())
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
Expand Down
5 changes: 2 additions & 3 deletions tests/text/test_data_model_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
from pathlib import Path

import pytest
from pytorch_lightning import Trainer

from flash.text import TextClassificationData, TextClassifier
Expand All @@ -33,10 +34,8 @@ def csv_data(tmpdir):
return path


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
def test_classification(tmpdir):
if os.name == "nt":
# TODO: huggingface stuff timing out on windows
return True

csv_path = csv_data(tmpdir)

Expand Down
14 changes: 5 additions & 9 deletions tests/text/translation/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import os
from pathlib import Path

import pytest

from flash.text import TranslationData

TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing
Expand Down Expand Up @@ -43,10 +45,8 @@ def json_data(tmpdir):
return path


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
def test_from_csv(tmpdir):
if os.name == "nt":
# TODO: huggingface stuff timing out on windows
return True
csv_path = csv_data(tmpdir)
dm = TranslationData.from_files(
backbone=TEST_BACKBONE, train_file=csv_path, input="input", target="target", batch_size=1
Expand All @@ -56,10 +56,8 @@ def test_from_csv(tmpdir):
assert "input_ids" in batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
def test_from_files(tmpdir):
if os.name == "nt":
# TODO: huggingface stuff timing out on windows
return True
csv_path = csv_data(tmpdir)
dm = TranslationData.from_files(
backbone=TEST_BACKBONE,
Expand All @@ -79,10 +77,8 @@ def test_from_files(tmpdir):
assert "input_ids" in batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
def test_from_json(tmpdir):
if os.name == "nt":
# TODO: huggingface stuff timing out on windows
return True
json_path = json_data(tmpdir)
dm = TranslationData.from_files(
backbone=TEST_BACKBONE, train_file=json_path, input="input", target="target", filetype="json", batch_size=1
Expand Down
5 changes: 2 additions & 3 deletions tests/text/translation/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import os

import pytest
import torch

from flash import Trainer
Expand All @@ -38,10 +39,8 @@ def __len__(self):
TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
def test_init_train(tmpdir):
if os.name == "nt":
# TODO: huggingface stuff timing out on windows
return True
model = TranslationTask(TEST_BACKBONE)
train_dl = torch.utils.data.DataLoader(DummyDataset())
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
Expand Down

0 comments on commit 1030303

Please sign in to comment.