Skip to content
This repository has been archived by the owner on Nov 21, 2022. It is now read-only.

Commit

Permalink
tests: switch/fix XSum test dataset (#310)
Browse files: browse the repository at this point in the history
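* switch the summarization smoke-test dataset from XSum to CNN/DailyMail
* fix passing the config name to `CNNDailyMailSummarizationDataModule`
* set `num_workers` and `preprocessing_num_workers` in the smoke test
* update the changelog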
  • Loading branch information
Borda authored Nov 21, 2022
1 parent a49e149 commit dd6e726
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).


## [0.2.5] - 2022-11-DD

### Fixed

- Fixed passing config name to `CNNDailyMailSummarizationDataModule` ([#310](https://github.com/Lightning-AI/lightning-transformers/pull/310))


## [0.2.4] - 2022-11-03

### Changed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@


class CNNDailyMailSummarizationDataModule(SummarizationDataModule):
def __init__(self, *args, dataset_name: str = "cnn_dailymail", **kwargs):
super().__init__(*args, dataset_name=dataset_name, **kwargs)
def __init__(self, *args, dataset_name: str = "cnn_dailymail", config_name: str = "3.0.0", **kwargs):
super().__init__(*args, dataset_name=dataset_name, dataset_config_name=config_name, **kwargs)

@property
def source_target_column_names(self) -> Tuple[str, str]:
Expand Down
17 changes: 10 additions & 7 deletions tests/task/nlp/test_summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,30 @@
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.summarization import (
CNNDailyMailSummarizationDataModule,
SummarizationDataModule,
SummarizationTransformer,
XsumSummarizationDataModule,
)

_MODEL_TINY = "patrickvonplaten/t5-tiny-random"


@pytest.mark.skipif(sys.platform == "win32", reason="Currently Windows is not supported")
def test_smoke_train(hf_cache_path):
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="patrickvonplaten/t5-tiny-random")
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=_MODEL_TINY)
model = SummarizationTransformer(
pretrained_model_name_or_path="patrickvonplaten/t5-tiny-random",
pretrained_model_name_or_path=_MODEL_TINY,
use_stemmer=True,
val_target_max_length=142,
num_beams=None,
compute_generate_metrics=True,
)
dm = XsumSummarizationDataModule(
dm = CNNDailyMailSummarizationDataModule(
limit_train_samples=64,
limit_val_samples=64,
limit_test_samples=64,
batch_size=1,
num_workers=2,
preprocessing_num_workers=2,
max_source_length=128,
max_target_length=128,
cache_dir=hf_cache_path,
Expand All @@ -41,8 +44,8 @@ def test_smoke_train(hf_cache_path):
@pytest.mark.skipif(sys.platform == "win32", reason="Currently Windows is not supported")
def test_smoke_predict():
model = SummarizationTransformer(
pretrained_model_name_or_path="patrickvonplaten/t5-tiny-random",
tokenizer=AutoTokenizer.from_pretrained(pretrained_model_name_or_path="patrickvonplaten/t5-tiny-random"),
pretrained_model_name_or_path=_MODEL_TINY,
tokenizer=AutoTokenizer.from_pretrained(pretrained_model_name_or_path=_MODEL_TINY),
)

y = model.hf_predict(
Expand Down

0 comments on commit dd6e726

Please sign in to comment.