diff --git a/tests/text/classification/test_model.py b/tests/text/classification/test_model.py
index 68ece85eb8..0d4dddbb15 100644
--- a/tests/text/classification/test_model.py
+++ b/tests/text/classification/test_model.py
@@ -11,34 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from typing import Any
 from unittest import mock
 
 import pytest
 import torch
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
-from flash import Trainer
+import flash
 from flash.core.data.io.input import DataKeys
-from flash.core.utilities.imports import _SERVE_TESTING, _TEXT_AVAILABLE, _TEXT_TESTING
+from flash.core.utilities.imports import _SERVE_TESTING, _TEXT_AVAILABLE, _TEXT_TESTING, _TORCH_ORT_AVAILABLE
 from flash.text import TextClassifier
-from tests.helpers.task_tester import TaskTester
-
-# ======== Mock functions ========
-
-
-class DummyDataset(torch.utils.data.Dataset):
-    def __getitem__(self, index):
-        return {
-            "input_ids": torch.randint(1000, size=(100,)),
-            DataKeys.TARGET: torch.randint(2, size=(1,)).item(),
-        }
-
-    def __len__(self) -> int:
-        return 100
-
-
-# ==============================
+from flash.text.ort_callback import ORTCallback
+from tests.helpers.boring_model import BoringModel
+from tests.helpers.task_tester import StaticDataset, TaskTester
 
 TEST_BACKBONE = "prajjwal1/bert-tiny"  # tiny model for testing
 
@@ -54,7 +40,21 @@ class TestTextClassifier(TaskTester):
 
     scriptable = False
 
-    marks = {"test_cli": [pytest.mark.parametrize("extra_args", ([], ["from_toxic"]))]}
+    marks = {
+        "test_fit": [
+            pytest.mark.parametrize(
+                "task_kwargs",
+                [
+                    {},
+                    pytest.param(
+                        {"enable_ort": True},
+                        marks=pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed."),
+                    ),
+                ],
+            )
+        ],
+        "test_cli": [pytest.mark.parametrize("extra_args", ([], ["from_toxic"]))],
+    }
 
     @property
     def example_forward_input(self):
@@ -64,14 +64,20 @@ def check_forward_output(self, output: Any):
         assert isinstance(output, torch.Tensor)
         assert output.shape == torch.Size([1, 2])
 
+    @property
+    def example_train_sample(self):
+        return {DataKeys.INPUT: "some text", DataKeys.TARGET: 1}
 
-@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
-@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
-def test_init_train(tmpdir):
-    model = TextClassifier(2, backbone=TEST_BACKBONE)
-    train_dl = torch.utils.data.DataLoader(DummyDataset())
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
-    trainer.fit(model, train_dl)
+    @pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.")
+    def test_ort_callback_fails_no_model(self, tmpdir):
+        dataset = StaticDataset(self.example_train_sample, 4)
+
+        model = BoringModel()
+
+        trainer = flash.Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=ORTCallback())
+
+        with pytest.raises(MisconfigurationException, match="Torch ORT requires to wrap a single model"):
+            trainer.fit(model, model.process_train_dataset(dataset, batch_size=4))
 
 
 @pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.")
diff --git a/tests/text/classification/test_ort.py b/tests/text/classification/test_ort.py
deleted file mode 100644
index bbcf5e94d4..0000000000
--- a/tests/text/classification/test_ort.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os
-
-import pytest
-import torch
-from pytorch_lightning import Callback
-from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-
-from flash import Trainer
-from flash.core.utilities.imports import _TEXT_TESTING, _TORCH_ORT_AVAILABLE
-from flash.text import TextClassifier
-from flash.text.ort_callback import ORTCallback
-from tests.helpers.boring_model import BoringModel
-from tests.text.classification.test_model import DummyDataset, TEST_BACKBONE
-
-if _TORCH_ORT_AVAILABLE:
-    from torch_ort import ORTModule
-
-
-@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
-@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
-@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.")
-def test_init_train_enable_ort(tmpdir):
-    class TestCallback(Callback):
-        def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
-            assert isinstance(pl_module.model, ORTModule)
-
-    model = TextClassifier(2, TEST_BACKBONE, enable_ort=True)
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=TestCallback())
-    trainer.fit(
-        model,
-        train_dataloader=torch.utils.data.DataLoader(DummyDataset()),
-        val_dataloaders=torch.utils.data.DataLoader(DummyDataset()),
-    )
-    trainer.test(model, test_dataloaders=torch.utils.data.DataLoader(DummyDataset()))
-
-
-@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
-@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.")
-def test_ort_callback_fails_no_model(tmpdir):
-    model = BoringModel()
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=ORTCallback())
-    with pytest.raises(MisconfigurationException, match="Torch ORT requires to wrap a single model"):
-        trainer.fit(
-            model,
-            train_dataloader=torch.utils.data.DataLoader(DummyDataset()),
-            val_dataloaders=torch.utils.data.DataLoader(DummyDataset()),
-        )
diff --git a/tests/text/question_answering/test_model.py b/tests/text/question_answering/test_model.py
index 9eec93ea2f..5ba1a77496 100644
--- a/tests/text/question_answering/test_model.py
+++ b/tests/text/question_answering/test_model.py
@@ -12,37 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import collections
-import os
 from typing import Any
 
 import pytest
 import torch
 
-from flash import Trainer
 from flash.core.utilities.imports import _TEXT_AVAILABLE, _TEXT_TESTING
 from flash.text import QuestionAnsweringTask
 from tests.helpers.task_tester import TaskTester
 
-# ======== Mock functions ========
-
-SEQUENCE_LENGTH = 384
-
-
-class DummyDataset(torch.utils.data.Dataset):
-    def __getitem__(self, index):
-        return {
-            "input_ids": torch.randint(1000, size=(SEQUENCE_LENGTH,)),
-            "attention_mask": torch.randint(1, size=(SEQUENCE_LENGTH,)),
-            "start_positions": torch.randint(1000, size=(1,)),
-            "end_positions": torch.randint(1000, size=(1,)),
-        }
-
-    def __len__(self) -> int:
-        return 100
-
-
-# ==============================
-
 TEST_BACKBONE = "distilbert-base-uncased"
 
 
@@ -70,17 +48,16 @@ def check_forward_output(self, output: Any):
         assert isinstance(output[0], torch.Tensor)
         assert isinstance(output[1], collections.OrderedDict)
 
+    @property
+    def example_train_sample(self):
+        return {
+            "question": "A question",
+            "answer": {"text": ["The answer"], "answer_start": [0]},
+            "context": "The paragraph of text which contains the answer to the question",
+        }
+
 
 @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
 def test_modules_to_freeze():
     model = QuestionAnsweringTask(backbone=TEST_BACKBONE)
     assert model.modules_to_freeze() is model.model.distilbert
-
-
-@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
-@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
-def test_init_train(tmpdir):
-    model = QuestionAnsweringTask(TEST_BACKBONE)
-    train_dl = torch.utils.data.DataLoader(DummyDataset())
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
-    trainer.fit(model, train_dl)
diff --git a/tests/text/seq2seq/summarization/test_model.py b/tests/text/seq2seq/summarization/test_model.py
index 230246c7ab..829621ba8b 100644
--- a/tests/text/seq2seq/summarization/test_model.py
+++ b/tests/text/seq2seq/summarization/test_model.py
@@ -11,41 +11,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from typing import Any
 from unittest import mock
 
 import pytest
 import torch
 
-from flash import DataKeys, Trainer
+from flash import DataKeys
 from flash.core.utilities.imports import _SERVE_TESTING, _TEXT_AVAILABLE, _TEXT_TESTING
 from flash.text import SummarizationTask
 from tests.helpers.task_tester import TaskTester
 
-# ======== Mock functions ========
-
-
-class DummyDataset(torch.utils.data.Dataset):
-    def __getitem__(self, index):
-        return {
-            "input_ids": torch.randint(1000, size=(128,)),
-            DataKeys.TARGET: torch.randint(1000, size=(128,)),
-        }
-
-    def __len__(self) -> int:
-        return 100
-
-
-# ==============================
-
 TEST_BACKBONE = "sshleifer/tiny-mbart"  # tiny model for testing
 
 
 class TestSummarizationTask(TaskTester):
 
     task = SummarizationTask
-    task_kwargs = {"backbone": TEST_BACKBONE}
+    task_kwargs = {
+        "backbone": TEST_BACKBONE,
+        "tokenizer_kwargs": {"src_lang": "en_XX", "tgt_lang": "en_XX"},
+    }
     cli_command = "summarization"
     is_testing = _TEXT_TESTING
     is_available = _TEXT_AVAILABLE
@@ -62,14 +48,9 @@ def check_forward_output(self, output: Any):
         assert isinstance(output, torch.Tensor)
         assert output.shape == torch.Size([1, 128])
 
-
-@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
-@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
-def test_init_train(tmpdir):
-    model = SummarizationTask(TEST_BACKBONE)
-    train_dl = torch.utils.data.DataLoader(DummyDataset())
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
-    trainer.fit(model, train_dl)
+    @property
+    def example_train_sample(self):
+        return {DataKeys.INPUT: "Some long passage of text", DataKeys.TARGET: "A summary"}
 
 
 @pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.")
diff --git a/tests/text/seq2seq/translation/test_model.py b/tests/text/seq2seq/translation/test_model.py
index 6aba2a8231..53e9b0b072 100644
--- a/tests/text/seq2seq/translation/test_model.py
+++ b/tests/text/seq2seq/translation/test_model.py
@@ -11,41 +11,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from typing import Any
 from unittest import mock
 
 import pytest
 import torch
 
-from flash import DataKeys, Trainer
+from flash import DataKeys
 from flash.core.utilities.imports import _SERVE_TESTING, _TEXT_AVAILABLE, _TEXT_TESTING
 from flash.text import TranslationTask
 from tests.helpers.task_tester import TaskTester
 
-# ======== Mock functions ========
-
-
-class DummyDataset(torch.utils.data.Dataset):
-    def __getitem__(self, index):
-        return {
-            "input_ids": torch.randint(1000, size=(128,)),
-            DataKeys.TARGET: torch.randint(1000, size=(128,)),
-        }
-
-    def __len__(self) -> int:
-        return 100
-
-
-# ==============================
-
 TEST_BACKBONE = "sshleifer/tiny-mbart"  # tiny model for testing
 
 
 class TestTranslationTask(TaskTester):
 
     task = TranslationTask
-    task_kwargs = {"backbone": TEST_BACKBONE}
+    task_kwargs = {
+        "backbone": TEST_BACKBONE,
+        "tokenizer_kwargs": {"src_lang": "en_XX", "tgt_lang": "ro_RO"},
+    }
     cli_command = "translation"
     is_testing = _TEXT_TESTING
     is_available = _TEXT_AVAILABLE
@@ -62,14 +48,9 @@ def check_forward_output(self, output: Any):
         assert isinstance(output, torch.Tensor)
         assert output.shape == torch.Size([1, 128])
 
-
-@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
-@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
-def test_init_train(tmpdir):
-    model = TranslationTask(TEST_BACKBONE)
-    train_dl = torch.utils.data.DataLoader(DummyDataset())
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
-    trainer.fit(model, train_dl)
+    @property
+    def example_train_sample(self):
+        return {DataKeys.INPUT: "Some text", DataKeys.TARGET: "Some translated text"}
 
 
 @pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.")