This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Merge branch 'master' into tabular_classification/from_dict
krshrimali authored May 9, 2022
2 parents dcfaefd + 4a128ce commit 730f4f2
Showing 5 changed files with 58 additions and 174 deletions.
62 changes: 34 additions & 28 deletions tests/text/classification/test_model.py
@@ -11,34 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from typing import Any
 from unittest import mock

 import pytest
 import torch
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

-from flash import Trainer
+import flash
 from flash.core.data.io.input import DataKeys
-from flash.core.utilities.imports import _SERVE_TESTING, _TEXT_AVAILABLE, _TEXT_TESTING
+from flash.core.utilities.imports import _SERVE_TESTING, _TEXT_AVAILABLE, _TEXT_TESTING, _TORCH_ORT_AVAILABLE
 from flash.text import TextClassifier
-from tests.helpers.task_tester import TaskTester
-
-# ======== Mock functions ========
-
-
-class DummyDataset(torch.utils.data.Dataset):
-    def __getitem__(self, index):
-        return {
-            "input_ids": torch.randint(1000, size=(100,)),
-            DataKeys.TARGET: torch.randint(2, size=(1,)).item(),
-        }
-
-    def __len__(self) -> int:
-        return 100
-
-
-# ==============================
+from flash.text.ort_callback import ORTCallback
+from tests.helpers.boring_model import BoringModel
+from tests.helpers.task_tester import StaticDataset, TaskTester

 TEST_BACKBONE = "prajjwal1/bert-tiny"  # tiny model for testing

@@ -54,7 +40,21 @@ class TestTextClassifier(TaskTester):

     scriptable = False

-    marks = {"test_cli": [pytest.mark.parametrize("extra_args", ([], ["from_toxic"]))]}
+    marks = {
+        "test_fit": [
+            pytest.mark.parametrize(
+                "task_kwargs",
+                [
+                    {},
+                    pytest.param(
+                        {"enable_ort": True},
+                        marks=pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed."),
+                    ),
+                ],
+            )
+        ],
+        "test_cli": [pytest.mark.parametrize("extra_args", ([], ["from_toxic"]))],
+    }

     @property
     def example_forward_input(self):
@@ -64,14 +64,20 @@ def check_forward_output(self, output: Any):
         assert isinstance(output, torch.Tensor)
         assert output.shape == torch.Size([1, 2])

+    @property
+    def example_train_sample(self):
+        return {DataKeys.INPUT: "some text", DataKeys.TARGET: 1}

-@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
-@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
-def test_init_train(tmpdir):
-    model = TextClassifier(2, backbone=TEST_BACKBONE)
-    train_dl = torch.utils.data.DataLoader(DummyDataset())
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
-    trainer.fit(model, train_dl)
+    @pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.")
+    def test_ort_callback_fails_no_model(self, tmpdir):
+        dataset = StaticDataset(self.example_train_sample, 4)
+
+        model = BoringModel()
+
+        trainer = flash.Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=ORTCallback())
+
+        with pytest.raises(MisconfigurationException, match="Torch ORT requires to wrap a single model"):
+            trainer.fit(model, model.process_train_dataset(dataset, batch_size=4))


 @pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.")
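
Note on the test added above: StaticDataset(self.example_train_sample, 4) replaces the deleted DummyDataset by turning the task's single example_train_sample into a tiny dataset. The real StaticDataset is defined in tests/helpers/task_tester.py and is not part of this diff; the sketch below is a hypothetical stand-in that only assumes it repeats one sample a fixed number of times.

# Hypothetical stand-in for StaticDataset (the real helper lives in
# tests/helpers/task_tester.py and is not shown in this diff): a map-style
# dataset that returns the same sample for every index.
import torch


class FakeStaticDataset(torch.utils.data.Dataset):
    def __init__(self, sample, length):
        self.sample = sample
        self.length = length

    def __getitem__(self, index):
        return self.sample  # every index yields the same dict

    def __len__(self):
        return self.length


dataset = FakeStaticDataset({"input": "some text", "target": 1}, 4)
assert len(dataset) == 4
assert dataset[0] == dataset[3]
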
61 changes: 0 additions & 61 deletions tests/text/classification/test_ort.py

This file was deleted.

39 changes: 8 additions & 31 deletions tests/text/question_answering/test_model.py
@@ -12,37 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import collections
-import os
 from typing import Any

 import pytest
 import torch

-from flash import Trainer
 from flash.core.utilities.imports import _TEXT_AVAILABLE, _TEXT_TESTING
 from flash.text import QuestionAnsweringTask
 from tests.helpers.task_tester import TaskTester

-# ======== Mock functions ========
-
-SEQUENCE_LENGTH = 384
-
-
-class DummyDataset(torch.utils.data.Dataset):
-    def __getitem__(self, index):
-        return {
-            "input_ids": torch.randint(1000, size=(SEQUENCE_LENGTH,)),
-            "attention_mask": torch.randint(1, size=(SEQUENCE_LENGTH,)),
-            "start_positions": torch.randint(1000, size=(1,)),
-            "end_positions": torch.randint(1000, size=(1,)),
-        }
-
-    def __len__(self) -> int:
-        return 100
-
-
-# ==============================
-
 TEST_BACKBONE = "distilbert-base-uncased"

@@ -70,17 +48,16 @@ def check_forward_output(self, output: Any):
         assert isinstance(output[0], torch.Tensor)
         assert isinstance(output[1], collections.OrderedDict)

+    @property
+    def example_train_sample(self):
+        return {
+            "question": "A question",
+            "answer": {"text": ["The answer"], "answer_start": [0]},
+            "context": "The paragraph of text which contains the answer to the question",
+        }
+

 @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
 def test_modules_to_freeze():
     model = QuestionAnsweringTask(backbone=TEST_BACKBONE)
     assert model.modules_to_freeze() is model.model.distilbert
-
-
-@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
-@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
-def test_init_train(tmpdir):
-    model = QuestionAnsweringTask(TEST_BACKBONE)
-    train_dl = torch.utils.data.DataLoader(DummyDataset())
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
-    trainer.fit(model, train_dl)
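
The example_train_sample added above follows the SQuAD convention: answer["text"] is a list of reference answers and answer["answer_start"] a list of character offsets into the context. The values in the diff are placeholders (offset 0 does not actually point at "The answer"), so the following hypothetical, self-consistent sample illustrates the intended layout.

# Hypothetical, self-consistent SQuAD-style sample; assumes "answer_start"
# is a character offset into "context" (the SQuAD convention).
context = "Flash was a high-level library built on top of PyTorch Lightning."
answer_text = "a high-level library"
answer_start = context.index(answer_text)

sample = {
    "question": "What was Flash?",
    "answer": {"text": [answer_text], "answer_start": [answer_start]},
    "context": context,
}

# The offset and the answer text line up inside the context.
assert sample["context"][answer_start : answer_start + len(answer_text)] == answer_text
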
35 changes: 8 additions & 27 deletions tests/text/seq2seq/summarization/test_model.py
@@ -11,41 +11,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from typing import Any
 from unittest import mock

 import pytest
 import torch

-from flash import DataKeys, Trainer
+from flash import DataKeys
 from flash.core.utilities.imports import _SERVE_TESTING, _TEXT_AVAILABLE, _TEXT_TESTING
 from flash.text import SummarizationTask
 from tests.helpers.task_tester import TaskTester

-# ======== Mock functions ========
-
-
-class DummyDataset(torch.utils.data.Dataset):
-    def __getitem__(self, index):
-        return {
-            "input_ids": torch.randint(1000, size=(128,)),
-            DataKeys.TARGET: torch.randint(1000, size=(128,)),
-        }
-
-    def __len__(self) -> int:
-        return 100
-
-
-# ==============================
-
 TEST_BACKBONE = "sshleifer/tiny-mbart"  # tiny model for testing


 class TestSummarizationTask(TaskTester):

     task = SummarizationTask
-    task_kwargs = {"backbone": TEST_BACKBONE}
+    task_kwargs = {
+        "backbone": TEST_BACKBONE,
+        "tokenizer_kwargs": {"src_lang": "en_XX", "tgt_lang": "en_XX"},
+    }
     cli_command = "summarization"
     is_testing = _TEXT_TESTING
     is_available = _TEXT_AVAILABLE
@@ -62,14 +48,9 @@ def check_forward_output(self, output: Any):
         assert isinstance(output, torch.Tensor)
         assert output.shape == torch.Size([1, 128])

-
-@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
-@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
-def test_init_train(tmpdir):
-    model = SummarizationTask(TEST_BACKBONE)
-    train_dl = torch.utils.data.DataLoader(DummyDataset())
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
-    trainer.fit(model, train_dl)
+    @property
+    def example_train_sample(self):
+        return {DataKeys.INPUT: "Some long passage of text", DataKeys.TARGET: "A summary"}


 @pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.")
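
Besides dropping the dummy dataset, the substantive change above is the new tokenizer_kwargs: sshleifer/tiny-mbart is a multilingual mBART checkpoint whose tokenizer expects source and target language codes, and the translation diff below passes tgt_lang="ro_RO" in the same way. How Flash forwards these kwargs is not shown in this diff; the sketch below, assuming a recent transformers release is installed and the tiny checkpoint can be downloaded, only illustrates the underlying Hugging Face tokenizer behaviour that they configure.

# Sketch of the Hugging Face tokenizer behaviour configured by tokenizer_kwargs.
# Flash's own plumbing of these kwargs is not shown in this commit.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "sshleifer/tiny-mbart", src_lang="en_XX", tgt_lang="en_XX"
)

batch = tokenizer(
    ["Some long passage of text"],
    text_target=["A summary"],
    return_tensors="pt",
)
print(batch["input_ids"].shape, batch["labels"].shape)
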
35 changes: 8 additions & 27 deletions tests/text/seq2seq/translation/test_model.py
@@ -11,41 +11,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from typing import Any
 from unittest import mock

 import pytest
 import torch

-from flash import DataKeys, Trainer
+from flash import DataKeys
 from flash.core.utilities.imports import _SERVE_TESTING, _TEXT_AVAILABLE, _TEXT_TESTING
 from flash.text import TranslationTask
 from tests.helpers.task_tester import TaskTester

-# ======== Mock functions ========
-
-
-class DummyDataset(torch.utils.data.Dataset):
-    def __getitem__(self, index):
-        return {
-            "input_ids": torch.randint(1000, size=(128,)),
-            DataKeys.TARGET: torch.randint(1000, size=(128,)),
-        }
-
-    def __len__(self) -> int:
-        return 100
-
-
-# ==============================
-
 TEST_BACKBONE = "sshleifer/tiny-mbart"  # tiny model for testing


 class TestTranslationTask(TaskTester):

     task = TranslationTask
-    task_kwargs = {"backbone": TEST_BACKBONE}
+    task_kwargs = {
+        "backbone": TEST_BACKBONE,
+        "tokenizer_kwargs": {"src_lang": "en_XX", "tgt_lang": "ro_RO"},
+    }
     cli_command = "translation"
     is_testing = _TEXT_TESTING
     is_available = _TEXT_AVAILABLE
@@ -62,14 +48,9 @@ def check_forward_output(self, output: Any):
         assert isinstance(output, torch.Tensor)
         assert output.shape == torch.Size([1, 128])

-
-@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
-@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
-def test_init_train(tmpdir):
-    model = TranslationTask(TEST_BACKBONE)
-    train_dl = torch.utils.data.DataLoader(DummyDataset())
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
-    trainer.fit(model, train_dl)
+    @property
+    def example_train_sample(self):
+        return {DataKeys.INPUT: "Some text", DataKeys.TARGET: "Some translated text"}


 @pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.")

