From b71401e250fe0f7966d9d0116629a8fd03f698fd Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 11 Aug 2021 12:08:33 +0200 Subject: [PATCH 01/10] deprecate --- flash/core/utilities/imports.py | 4 +-- flash/text/seq2seq/core/metrics.py | 21 +++++++++++++- flash/text/seq2seq/summarization/model.py | 7 ++--- flash/text/seq2seq/translation/model.py | 3 +- requirements.txt | 2 +- requirements/datatype_text.txt | 2 +- tests/text/seq2seq/core/test_metrics.py | 34 ----------------------- 7 files changed, 28 insertions(+), 45 deletions(-) delete mode 100644 tests/text/seq2seq/core/test_metrics.py diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index a1375fca9b..7623c91df0 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -93,9 +93,9 @@ def _compare_version(package: str, op, version) -> bool: _TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse") _TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric") _TORCHAUDIO_AVAILABLE = _module_available("torchaudio") -_ROUGE_SCORE_AVAILABLE = _module_available("rouge_score") _SENTENCEPIECE_AVAILABLE = _module_available("sentencepiece") _DATASETS_AVAILABLE = _module_available("datasets") +_TM_TEXT_AVAILABLE: bool = _module_available("torchmetrics.text") if Version: _TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0") @@ -103,9 +103,9 @@ def _compare_version(package: str, op, version) -> bool: _TEXT_AVAILABLE = all( [ _TRANSFORMERS_AVAILABLE, - _ROUGE_SCORE_AVAILABLE, _SENTENCEPIECE_AVAILABLE, _DATASETS_AVAILABLE, + _TM_TEXT_AVAILABLE, ] ) _TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE diff --git a/flash/text/seq2seq/core/metrics.py b/flash/text/seq2seq/core/metrics.py index 621bb23d74..40ac5b28db 100644 --- a/flash/text/seq2seq/core/metrics.py +++ b/flash/text/seq2seq/core/metrics.py @@ -23,6 +23,7 @@ import torch from torch import tensor from torchmetrics import Metric +from warnings import warn from flash.core.utilities.imports import _TEXT_AVAILABLE, requires_extras from flash.text.seq2seq.core.utils import add_newline_to_end_of_each_sentence @@ -64,14 +65,23 @@ class BLEUScore(Metric): >>> metric = BLEUScore() >>> metric(translate_corpus, reference_corpus) tensor(0.7598) - """ + .. deprecated:: v0.5 + Use :func:`torchmetrics.text.BLEUScore` instead. Will be removed in v0.6. + + """ def __init__(self, n_gram: int = 4, smooth: bool = False): """ Args: n_gram: Gram value ranged from 1 to 4 (Default 4) smooth: Whether or not to apply smoothing – Lin et al. 2004 """ + warn( + "Metric `text.seq2seq.core.metrics.BLEUScore` is deprecated in v0.5 and will be removed in v0.6." + " Use `torchmetrics.text.BLEUScore` instead.", + DeprecationWarning, + ) + super().__init__() self.n_gram = n_gram self.smooth = smooth @@ -151,6 +161,10 @@ class RougeMetric(Metric): 'rougeLsum_fmeasure': 0.25, 'rougeLsum_precision': 0.25, 'rougeLsum_recall': 0.25} + + .. deprecated:: v0.5 + Use :func:`torchmetrics.text.ROUGEScore` instead. Will be removed in v0.6. + """ @requires_extras("text") @@ -160,6 +174,11 @@ def __init__( use_stemmer: bool = False, rouge_keys: Tuple[str] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), ): + warn( + "Metric `text.seq2seq.core.metrics.RougeScore` is deprecated in v0.5 and will be removed in v0.6." 
+ " Use `torchmetrics.text.ROUGEScore` instead.", + DeprecationWarning, + ) super().__init__() self.rouge_newline_sep = rouge_newline_sep diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py index af7820b10e..b3e80cea8a 100644 --- a/flash/text/seq2seq/summarization/model.py +++ b/flash/text/seq2seq/summarization/model.py @@ -14,9 +14,8 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union import torch -from torchmetrics import Metric +from torchmetrics import Metric, ROUGEScore -from flash.text.seq2seq.core.metrics import RougeMetric from flash.text.seq2seq.core.model import Seq2SeqTask @@ -66,8 +65,8 @@ def __init__( val_target_max_length=val_target_max_length, num_beams=num_beams, ) - self.rouge = RougeMetric( - rouge_newline_sep=rouge_newline_sep, + self.rouge = ROUGEScore( + newline_sep=rouge_newline_sep, use_stemmer=use_stemmer, ) diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py index ad99f47e31..aba48b6652 100644 --- a/flash/text/seq2seq/translation/model.py +++ b/flash/text/seq2seq/translation/model.py @@ -14,9 +14,8 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union import torch -from torchmetrics import Metric +from torchmetrics import Metric, BLEUScore -from flash.text.seq2seq.core.metrics import BLEUScore from flash.text.seq2seq.core.model import Seq2SeqTask diff --git a/requirements.txt b/requirements.txt index 0693689f06..888883abfc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ torch -torchmetrics +torchmetrics>=0.5.0 pytorch-lightning>=1.4.0rc0 pyDeprecate PyYAML>=5.1 diff --git a/requirements/datatype_text.txt b/requirements/datatype_text.txt index 9953e12545..314ffc34a7 100644 --- a/requirements/datatype_text.txt +++ b/requirements/datatype_text.txt @@ -1,5 +1,5 @@ -rouge-score>=0.0.4 sentencepiece>=0.1.95 filelock transformers>=4.5 datasets>=1.2, <1.3 +torchmetrics[text] diff --git a/tests/text/seq2seq/core/test_metrics.py b/tests/text/seq2seq/core/test_metrics.py deleted file mode 100644 index c16f828c37..0000000000 --- a/tests/text/seq2seq/core/test_metrics.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import pytest -import torch - -from flash.text.seq2seq.core.metrics import BLEUScore, RougeMetric -from tests.helpers.utils import _TEXT_TESTING - - -@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") -def test_rouge(): - preds = "My name is John".split() - target = "Is your name John".split() - metric = RougeMetric() - assert torch.allclose(torch.tensor(metric(preds, target)["rouge1_recall"]).float(), torch.tensor(0.25), 1e-4) - - -@pytest.mark.parametrize("smooth, expected", [(False, 0.7598), (True, 0.8091)]) -def test_bleu_score(smooth, expected): - translate_corpus = ["the cat is on the mat".split()] - reference_corpus = [["there is a cat on the mat".split(), "a cat is on the mat".split()]] - metric = BLEUScore(smooth=smooth) - assert torch.allclose(metric(translate_corpus, reference_corpus), torch.tensor(expected), 1e-4) From 99f5c3028032947d028fe05e7a1c8e7709e31a8a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Aug 2021 10:10:48 +0000 Subject: [PATCH 02/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flash/text/seq2seq/core/metrics.py | 5 ++--- flash/text/seq2seq/translation/model.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/flash/text/seq2seq/core/metrics.py b/flash/text/seq2seq/core/metrics.py index 40ac5b28db..88ab71d63a 100644 --- a/flash/text/seq2seq/core/metrics.py +++ b/flash/text/seq2seq/core/metrics.py @@ -18,12 +18,12 @@ # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score from collections import Counter from typing import Dict, List, Tuple +from warnings import warn import numpy as np import torch from torch import tensor from torchmetrics import Metric -from warnings import warn from flash.core.utilities.imports import _TEXT_AVAILABLE, requires_extras from flash.text.seq2seq.core.utils import add_newline_to_end_of_each_sentence @@ -68,8 +68,8 @@ class BLEUScore(Metric): .. deprecated:: v0.5 Use :func:`torchmetrics.text.BLEUScore` instead. Will be removed in v0.6. - """ + def __init__(self, n_gram: int = 4, smooth: bool = False): """ Args: @@ -164,7 +164,6 @@ class RougeMetric(Metric): .. deprecated:: v0.5 Use :func:`torchmetrics.text.ROUGEScore` instead. Will be removed in v0.6. 
- """ @requires_extras("text") diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py index aba48b6652..430a104c01 100644 --- a/flash/text/seq2seq/translation/model.py +++ b/flash/text/seq2seq/translation/model.py @@ -14,7 +14,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union import torch -from torchmetrics import Metric, BLEUScore +from torchmetrics import BLEUScore, Metric from flash.text.seq2seq.core.model import Seq2SeqTask From c361f6a6290ab8b2bd20c1414cb46891ee27d2f0 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Wed, 25 Aug 2021 07:45:23 +0900 Subject: [PATCH 03/10] Map from text metrics to torchmetrics --- flash/text/seq2seq/core/metrics.py | 237 ++--------------------------- flash/text/seq2seq/core/utils.py | 29 ---- 2 files changed, 14 insertions(+), 252 deletions(-) delete mode 100644 flash/text/seq2seq/core/utils.py diff --git a/flash/text/seq2seq/core/metrics.py b/flash/text/seq2seq/core/metrics.py index a3d42410f5..cdcdc02c9f 100644 --- a/flash/text/seq2seq/core/metrics.py +++ b/flash/text/seq2seq/core/metrics.py @@ -16,238 +16,29 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -from collections import Counter -from typing import Dict, List, Tuple -from warnings import warn +from functools import partial +from typing import Tuple -import numpy as np -import torch -from torch import tensor -from torchmetrics import Metric +from deprecate import deprecated, void +from pytorch_lightning.utilities import rank_zero_deprecation +from torchmetrics.text import BLEUScore as _BLEUScore +from torchmetrics.text import ROUGEScore as _ROUGEScore -from flash.core.utilities.imports import _TEXT_AVAILABLE, requires_extras -from flash.text.seq2seq.core.utils import add_newline_to_end_of_each_sentence +_deprecated_text_metrics = partial(deprecated, deprecated_in="0.5.0", remove_in="0.6.0", stream=rank_zero_deprecation) -if _TEXT_AVAILABLE: - from rouge_score import rouge_scorer - from rouge_score.scoring import AggregateScore, BootstrapAggregator, Score -else: - AggregateScore, Score, BootstrapAggregator = None, None, object - - -def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter: - """ - Counting how many times each word appears in a given text with ngram - Args: - ngram_input_list: A list of translated text or reference texts - n_gram: gram value ranged 1 to 4 - - Return: - ngram_counter: a collections.Counter object of ngram - """ - - ngram_counter = Counter() - - for i in range(1, n_gram + 1): - for j in range(len(ngram_input_list) - i + 1): - ngram_key = tuple(ngram_input_list[j : (i + j)]) - ngram_counter[ngram_key] += 1 - - return ngram_counter - - -class BLEUScore(Metric): - """Calculate BLEU score of machine translated text with one or more references. - - Example: - >>> translate_corpus = ['the cat is on the mat'.split()] - >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] - >>> metric = BLEUScore() - >>> metric(translate_corpus, reference_corpus) - tensor(0.7598) - - .. deprecated:: v0.5 - Use :func:`torchmetrics.text.BLEUScore` instead. Will be removed in v0.6. - """ +class BLEUScore(_BLEUScore): + @_deprecated_text_metrics(target=_BLEUScore) def __init__(self, n_gram: int = 4, smooth: bool = False): - """ - Args: - n_gram: Gram value ranged from 1 to 4 (Default 4) - smooth: Whether or not to apply smoothing – Lin et al. 
2004 - """ - warn( - "Metric `text.seq2seq.core.metrics.BLEUScore` is deprecated in v0.5 and will be removed in v0.6." - " Use `torchmetrics.text.BLEUScore` instead.", - DeprecationWarning, - ) - - super().__init__() - self.n_gram = n_gram - self.smooth = smooth - - self.add_state("c", tensor(0, dtype=torch.float), dist_reduce_fx="sum") - self.add_state("r", tensor(0, dtype=torch.float), dist_reduce_fx="sum") - self.add_state("numerator", torch.zeros(self.n_gram), dist_reduce_fx="sum") - self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum") - - def compute(self): - - trans_len = self.c.clone().detach() - ref_len = self.r.clone().detach() - - if min(self.numerator) == 0.0: - return tensor(0.0, device=self.r.device) - - if self.smooth: - precision_scores = (self.numerator + 1.0) / (self.denominator + 1.0) - else: - precision_scores = self.numerator / self.denominator - - log_precision_scores = tensor([1.0 / self.n_gram] * self.n_gram, device=self.r.device) * torch.log( - precision_scores - ) - geometric_mean = torch.exp(torch.sum(log_precision_scores)) - brevity_penalty = tensor(1.0, device=self.r.device) if self.c > self.r else torch.exp(1 - (ref_len / trans_len)) - bleu = brevity_penalty * geometric_mean - return bleu - - def update(self, translate_corpus, reference_corpus) -> None: - """ - Actual metric computation - Args: - translate_corpus: An iterable of machine translated corpus - reference_corpus: An iterable of iterables of reference corpus - """ - for (translation, references) in zip(translate_corpus, reference_corpus): - self.c += len(translation) - ref_len_list = [len(ref) for ref in references] - ref_len_diff = [abs(len(translation) - x) for x in ref_len_list] - self.r += ref_len_list[ref_len_diff.index(min(ref_len_diff))] - translation_counter = _count_ngram(translation, self.n_gram) - reference_counter = Counter() - - for ref in references: - reference_counter |= _count_ngram(ref, self.n_gram) - - ngram_counter_clip = translation_counter & reference_counter + void(n_gram, smooth) - for counter_clip in ngram_counter_clip: - self.numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip] - for counter in translation_counter: - self.denominator[len(counter) - 1] += translation_counter[counter] - - -class RougeMetric(Metric): - """Metric used for automatic summarization. https://www.aclweb.org/anthology/W04-1013/ - - Example: - - >>> target = "Is your name John".split() - >>> preds = "My name is John".split() - >>> rouge = RougeMetric() # doctest: +SKIP - >>> from pprint import pprint - >>> pprint(rouge(preds, target)) # doctest: +NORMALIZE_WHITESPACE +SKIP - {'rouge1_fmeasure': 0.25, - 'rouge1_precision': 0.25, - 'rouge1_recall': 0.25, - 'rouge2_fmeasure': 0.0, - 'rouge2_precision': 0.0, - 'rouge2_recall': 0.0, - 'rougeL_fmeasure': 0.25, - 'rougeL_precision': 0.25, - 'rougeL_recall': 0.25, - 'rougeLsum_fmeasure': 0.25, - 'rougeLsum_precision': 0.25, - 'rougeLsum_recall': 0.25} - - .. deprecated:: v0.5 - Use :func:`torchmetrics.text.ROUGEScore` instead. Will be removed in v0.6. - """ - - @requires_extras("text") +class RougeMetric(_ROUGEScore): + @_deprecated_text_metrics(target=_ROUGEScore) def __init__( self, - rouge_newline_sep: bool = False, + newline_sep: bool = False, use_stemmer: bool = False, rouge_keys: Tuple[str] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), ): - warn( - "Metric `text.seq2seq.core.metrics.RougeScore` is deprecated in v0.5 and will be removed in v0.6." 
- " Use `torchmetrics.text.ROUGEScore` instead.", - DeprecationWarning, - ) - super().__init__() - - self.rouge_newline_sep = rouge_newline_sep - self.rouge_keys = rouge_keys - self.use_stemmer = use_stemmer - self.aggregator = RougeBatchAggregator() - self.scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=self.use_stemmer) - - for key in rouge_keys: - self.add_state(key, []) - - def update(self, pred_lns: List[str], tgt_lns: List[str]): - for pred, tgt in zip(pred_lns, tgt_lns): - # rougeLsum expects "\n" separated sentences within a summary - if self.rouge_newline_sep: - pred = add_newline_to_end_of_each_sentence(pred) - tgt = add_newline_to_end_of_each_sentence(tgt) - results = self.scorer.score(pred, tgt) - for key, score in results.items(): - score = tensor([score.precision, score.recall, score.fmeasure]) - getattr(self, key).append(score) - - def compute(self) -> Dict[str, float]: - scores = {key: getattr(self, key) for key in self.rouge_keys} - self.aggregator.add_scores(scores) - result = self.aggregator.aggregate() - return format_rouge_results(result) - - def __hash__(self): - # override to hash list objects. - # this is a bug in the upstream pytorch release. - hash_vals = [self.__class__.__name__] - - for key in self._defaults: - value = getattr(self, key) - if isinstance(value, list): - value = tuple(value) - hash_vals.append(value) - - return hash(tuple(hash_vals)) - - -class RougeBatchAggregator(BootstrapAggregator): - """Aggregates rouge scores and provides confidence intervals.""" - - def aggregate(self): - """Override function to wrap the final results in `Score` objects. - - This is due to the scores being replaced with a list of torch tensors. - """ - result = {} - for score_type, scores in self._scores.items(): - # Stack scores into a 2-d matrix of (sample, measure). - score_matrix = np.vstack(tuple(scores)) - # Percentiles are returned as (interval, measure). - percentiles = self._bootstrap_resample(score_matrix) - # Extract the three intervals (low, mid, high). - intervals = tuple(Score(*percentiles[j, :]) for j in range(3)) - result[score_type] = AggregateScore(low=intervals[0], mid=intervals[1], high=intervals[2]) - return result - - def add_scores(self, scores): - self._scores = scores - - -def format_rouge_results(result: Dict[str, AggregateScore], decimal_places: int = 4) -> Dict[str, float]: - flattened_result = {} - for rouge_key, rouge_aggregate_score in result.items(): - for stat in ["precision", "recall", "fmeasure"]: - mid = rouge_aggregate_score.mid - score = round(getattr(mid, stat), decimal_places) - flattened_result[f"{rouge_key}_{stat}"] = score - return flattened_result + void(newline_sep, use_stemmer, rouge_keys) diff --git a/flash/text/seq2seq/core/utils.py b/flash/text/seq2seq/core/utils.py deleted file mode 100644 index e48248754c..0000000000 --- a/flash/text/seq2seq/core/utils.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2020 The PyTorch Lightning team and The HuggingFace Team. All rights reserved. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -import re - -from pytorch_lightning.utilities import _module_available - -nltk = None -if _module_available("nltk"): - import nltk - - nltk.download("punkt", quiet=True) - - -def add_newline_to_end_of_each_sentence(x: str) -> str: - """This was added to get rougeLsum scores matching published rougeL scores for BART and PEGASUS.""" - re.sub("", "", x) # remove pegasus newline char - assert nltk, "nltk must be installed to separate newlines between sentences. (pip install nltk)" - return "\n".join(nltk.sent_tokenize(x)) From 88e522d8fb9be478d3659c1296f0193fe660df39 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Wed, 25 Aug 2021 07:52:13 +0900 Subject: [PATCH 04/10] Fix docs --- docs/source/api/text.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source/api/text.rst b/docs/source/api/text.rst index f9177eec85..0a71cd2648 100644 --- a/docs/source/api/text.rst +++ b/docs/source/api/text.rst @@ -89,5 +89,4 @@ _______________ seq2seq.core.data.Seq2SeqPreprocess seq2seq.core.data.Seq2SeqSentencesDataSource seq2seq.core.metrics.BLEUScore - seq2seq.core.metrics.RougeBatchAggregator seq2seq.core.metrics.RougeMetric From bf9bb9e31c8255d032ab145d8ae504ba145592b9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 9 Nov 2021 11:22:02 +0000 Subject: [PATCH 05/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flash/text/seq2seq/core/metrics.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flash/text/seq2seq/core/metrics.py b/flash/text/seq2seq/core/metrics.py index 706047e570..cdcdc02c9f 100644 --- a/flash/text/seq2seq/core/metrics.py +++ b/flash/text/seq2seq/core/metrics.py @@ -32,6 +32,7 @@ class BLEUScore(_BLEUScore): def __init__(self, n_gram: int = 4, smooth: bool = False): void(n_gram, smooth) + class RougeMetric(_ROUGEScore): @_deprecated_text_metrics(target=_ROUGEScore) def __init__( From b4ee1ca68ba01451ac586e38f088f6340cae3b64 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 9 Nov 2021 11:32:09 +0000 Subject: [PATCH 06/10] Pre-commit --- flash/text/seq2seq/translation/model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py index f9f5ae5ef9..553adb6b7a 100644 --- a/flash/text/seq2seq/translation/model.py +++ b/flash/text/seq2seq/translation/model.py @@ -13,8 +13,7 @@ # limitations under the License. 
from typing import Any, Dict, List, Optional -import torch -from torchmetrics import BLEUScore, Metric +from torchmetrics import BLEUScore from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.text.seq2seq.core.model import Seq2SeqTask From 0b5fac2d349da15c5e7557a8e85a0816fa6974ab Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 9 Nov 2021 11:33:02 +0000 Subject: [PATCH 07/10] Update deprecated / remove in --- flash/text/seq2seq/core/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash/text/seq2seq/core/metrics.py b/flash/text/seq2seq/core/metrics.py index cdcdc02c9f..5eee448851 100644 --- a/flash/text/seq2seq/core/metrics.py +++ b/flash/text/seq2seq/core/metrics.py @@ -24,7 +24,7 @@ from torchmetrics.text import BLEUScore as _BLEUScore from torchmetrics.text import ROUGEScore as _ROUGEScore -_deprecated_text_metrics = partial(deprecated, deprecated_in="0.5.0", remove_in="0.6.0", stream=rank_zero_deprecation) +_deprecated_text_metrics = partial(deprecated, deprecated_in="0.6.0", remove_in="0.7.0", stream=rank_zero_deprecation) class BLEUScore(_BLEUScore): From 4b27867198e44b19060c90ff5c839b69ee19d782 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 9 Nov 2021 11:36:40 +0000 Subject: [PATCH 08/10] Prune docs --- docs/source/api/text.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/source/api/text.rst b/docs/source/api/text.rst index 383aab96f3..50750dad05 100644 --- a/docs/source/api/text.rst +++ b/docs/source/api/text.rst @@ -99,5 +99,3 @@ _______________ seq2seq.core.data.Seq2SeqOutputTransform seq2seq.core.data.Seq2SeqPreprocess seq2seq.core.data.Seq2SeqSentencesDataSource - seq2seq.core.metrics.BLEUScore - seq2seq.core.metrics.RougeMetric From 713ba482fc1ebbdcbe0ff3799b0e84e1b08d3ce5 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 9 Nov 2021 11:37:57 +0000 Subject: [PATCH 09/10] Update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 61a3665282..b74da5e12c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `Task.serializer` in favour of `Task.output` ([#927](https://github.com/PyTorchLightning/lightning-flash/pull/927)) +- Deprecated `flash.text.seq2seq.core.metrics` in favour of `torchmetrics[text]` ([#648](https://github.com/PyTorchLightning/lightning-flash/pull/648)) + ### Fixed ### Removed From 2e4c02c0cfc4cb76f9a928d41bd85a6e5e2eac1f Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 9 Nov 2021 12:06:43 +0000 Subject: [PATCH 10/10] Fix --- flash/text/question_answering/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash/text/question_answering/model.py b/flash/text/question_answering/model.py index 7087239c18..1c17b0eca5 100644 --- a/flash/text/question_answering/model.py +++ b/flash/text/question_answering/model.py @@ -125,7 +125,7 @@ def __init__( self._initialize_model_specific_parameters() self.rouge = RougeMetric( - rouge_newline_sep=rouge_newline_sep, + newline_sep=rouge_newline_sep, use_stemmer=use_stemmer, )
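
Migration note (editorial, not part of the patches above): after this series, `BLEUScore` and
`RougeMetric` in `flash.text.seq2seq.core.metrics` become thin deprecation shims that forward to
the torchmetrics implementations, and the tasks consume torchmetrics directly. Below is a minimal
usage sketch, assuming `torchmetrics>=0.5.0` with the `[text]` extra installed (as pinned in the
requirements changes above); the expected BLEU value is taken from the docstring example retained
in PATCH 01, and the argument names mirror the model changes in PATCH 01 and PATCH 10.

    # Preferred imports going forward (these are the targets of the deprecation shims):
    from torchmetrics.text import BLEUScore, ROUGEScore

    # The old import path still resolves, but now emits a deprecation message
    # via pyDeprecate's `deprecated` wrapper introduced in PATCH 03:
    from flash.text.seq2seq.core.metrics import BLEUScore as FlashBLEUScore

    translate_corpus = ["the cat is on the mat".split()]
    reference_corpus = [["there is a cat on the mat".split(), "a cat is on the mat".split()]]

    bleu = BLEUScore(n_gram=4, smooth=False)
    print(bleu(translate_corpus, reference_corpus))  # tensor(0.7598), per the docstring example

    legacy_bleu = FlashBLEUScore(n_gram=4, smooth=False)  # warns, then behaves like the metric above

    # RougeMetric's `rouge_newline_sep` constructor argument maps to `newline_sep` on
    # torchmetrics' ROUGEScore, as reflected in the summarization and question answering changes:
    rouge = ROUGEScore(newline_sep=False, use_stemmer=False)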