From dd6e7262242cd64da48e0f31bfbc9233813933e6 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Mon, 21 Nov 2022 14:13:51 +0100 Subject: [PATCH] tests: switch/fix XSum test dataset (#310) * switch test dataset * fix CNN * workers * chlog --- CHANGELOG.md | 7 +++++++ .../nlp/summarization/datasets/cnn_dailymail.py | 4 ++-- tests/task/nlp/test_summarization.py | 17 ++++++++++------- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d618ea88..5871ca69 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.2.5] - 2022-11-DD + +### Fixed + +- Fixed passing config name to `CNNDailyMailSummarizationDataModule` ([#310](https://github.com/Lightning-AI/lightning-transformers/pull/310)) + + ## [0.2.4] - 2022-11-03 ### Changed diff --git a/lightning_transformers/task/nlp/summarization/datasets/cnn_dailymail.py b/lightning_transformers/task/nlp/summarization/datasets/cnn_dailymail.py index 32eece09..a443b774 100644 --- a/lightning_transformers/task/nlp/summarization/datasets/cnn_dailymail.py +++ b/lightning_transformers/task/nlp/summarization/datasets/cnn_dailymail.py @@ -17,8 +17,8 @@ class CNNDailyMailSummarizationDataModule(SummarizationDataModule): - def __init__(self, *args, dataset_name: str = "cnn_dailymail", **kwargs): - super().__init__(*args, dataset_name=dataset_name, **kwargs) + def __init__(self, *args, dataset_name: str = "cnn_dailymail", config_name: str = "3.0.0", **kwargs): + super().__init__(*args, dataset_name=dataset_name, dataset_config_name=config_name, **kwargs) @property def source_target_column_names(self) -> Tuple[str, str]: diff --git a/tests/task/nlp/test_summarization.py b/tests/task/nlp/test_summarization.py index 39209f68..241f35d4 100644 --- a/tests/task/nlp/test_summarization.py +++ b/tests/task/nlp/test_summarization.py @@ -7,27 +7,30 @@ from transformers import AutoTokenizer from lightning_transformers.task.nlp.summarization import ( + CNNDailyMailSummarizationDataModule, SummarizationDataModule, SummarizationTransformer, - XsumSummarizationDataModule, ) +_MODEL_TINY = "patrickvonplaten/t5-tiny-random" + @pytest.mark.skipif(sys.platform == "win32", reason="Currently Windows is not supported") def test_smoke_train(hf_cache_path): - tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="patrickvonplaten/t5-tiny-random") + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=_MODEL_TINY) model = SummarizationTransformer( - pretrained_model_name_or_path="patrickvonplaten/t5-tiny-random", + pretrained_model_name_or_path=_MODEL_TINY, use_stemmer=True, val_target_max_length=142, num_beams=None, compute_generate_metrics=True, ) - dm = XsumSummarizationDataModule( + dm = CNNDailyMailSummarizationDataModule( limit_train_samples=64, limit_val_samples=64, limit_test_samples=64, - batch_size=1, + num_workers=2, + preprocessing_num_workers=2, max_source_length=128, max_target_length=128, cache_dir=hf_cache_path, @@ -41,8 +44,8 @@ def test_smoke_train(hf_cache_path): @pytest.mark.skipif(sys.platform == "win32", reason="Currently Windows is not supported") def test_smoke_predict(): model = SummarizationTransformer( - pretrained_model_name_or_path="patrickvonplaten/t5-tiny-random", - tokenizer=AutoTokenizer.from_pretrained(pretrained_model_name_or_path="patrickvonplaten/t5-tiny-random"), + pretrained_model_name_or_path=_MODEL_TINY, + tokenizer=AutoTokenizer.from_pretrained(pretrained_model_name_or_path=_MODEL_TINY), ) y = model.hf_predict(