Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Deprecate text metrics #648

Merged
merged 17 commits into from
Nov 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Deprecated `Task.serializer` in favour of `Task.output` ([#927](https://github.com/PyTorchLightning/lightning-flash/pull/927))

- Deprecated `flash.text.seq2seq.core.metrics` in favour of `torchmetrics[text]` ([#648](https://github.com/PyTorchLightning/lightning-flash/pull/648))

### Fixed

### Removed
Expand Down
3 changes: 0 additions & 3 deletions docs/source/api/text.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,3 @@ _______________
seq2seq.core.data.Seq2SeqOutputTransform
seq2seq.core.data.Seq2SeqPreprocess
seq2seq.core.data.Seq2SeqSentencesDataSource
seq2seq.core.metrics.BLEUScore
seq2seq.core.metrics.RougeBatchAggregator
seq2seq.core.metrics.RougeMetric
4 changes: 2 additions & 2 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def _compare_version(package: str, op, version) -> bool:
_TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
_TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric")
_TORCHAUDIO_AVAILABLE = _module_available("torchaudio")
_ROUGE_SCORE_AVAILABLE = _module_available("rouge_score")
_SENTENCEPIECE_AVAILABLE = _module_available("sentencepiece")
_DATASETS_AVAILABLE = _module_available("datasets")
_TM_TEXT_AVAILABLE: bool = _module_available("torchmetrics.text")
_ICEVISION_AVAILABLE = _module_available("icevision")
_ICEDATA_AVAILABLE = _module_available("icedata")
_LEARN2LEARN_AVAILABLE = _module_available("learn2learn") and _compare_version("learn2learn", operator.ge, "0.1.6")
Expand All @@ -123,9 +123,9 @@ class Image:
_TEXT_AVAILABLE = all(
[
_TRANSFORMERS_AVAILABLE,
_ROUGE_SCORE_AVAILABLE,
_SENTENCEPIECE_AVAILABLE,
_DATASETS_AVAILABLE,
_TM_TEXT_AVAILABLE,
]
)
_TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE and _FORECASTING_AVAILABLE
Expand Down
2 changes: 1 addition & 1 deletion flash/text/question_answering/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(
self._initialize_model_specific_parameters()

self.rouge = RougeMetric(
rouge_newline_sep=rouge_newline_sep,
newline_sep=rouge_newline_sep,
use_stemmer=use_stemmer,
)

Expand Down
219 changes: 14 additions & 205 deletions flash/text/seq2seq/core/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,220 +16,29 @@
# Authors: torchtext authors and @sluks
# Date: 2020-07-18
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
from collections import Counter
from typing import Dict, List, Tuple
from functools import partial
from typing import Tuple

import numpy as np
import torch
from torch import tensor
from torchmetrics import Metric
from deprecate import deprecated, void
from pytorch_lightning.utilities import rank_zero_deprecation
from torchmetrics.text import BLEUScore as _BLEUScore
from torchmetrics.text import ROUGEScore as _ROUGEScore

from flash.core.utilities.imports import _TEXT_AVAILABLE, requires
from flash.text.seq2seq.core.utils import add_newline_to_end_of_each_sentence
_deprecated_text_metrics = partial(deprecated, deprecated_in="0.6.0", remove_in="0.7.0", stream=rank_zero_deprecation)

if _TEXT_AVAILABLE:
from rouge_score import rouge_scorer
from rouge_score.scoring import AggregateScore, BootstrapAggregator, Score
else:
AggregateScore, Score, BootstrapAggregator = None, None, object


def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
"""
Counting how many times each word appears in a given text with ngram
Args:
ngram_input_list: A list of translated text or reference texts
n_gram: gram value ranged 1 to 4

Return:
ngram_counter: a collections.Counter object of ngram
"""

ngram_counter = Counter()

for i in range(1, n_gram + 1):
for j in range(len(ngram_input_list) - i + 1):
ngram_key = tuple(ngram_input_list[j : (i + j)])
ngram_counter[ngram_key] += 1

return ngram_counter


class BLEUScore(Metric):
"""Calculate BLEU score of machine translated text with one or more references.

Example:
>>> translate_corpus = ['the cat is on the mat'.split()]
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
>>> metric = BLEUScore()
>>> metric(translate_corpus, reference_corpus)
tensor(0.7598)
"""

class BLEUScore(_BLEUScore):
@_deprecated_text_metrics(target=_BLEUScore)
def __init__(self, n_gram: int = 4, smooth: bool = False):
"""
Args:
n_gram: Gram value ranged from 1 to 4 (Default 4)
smooth: Whether or not to apply smoothing – Lin et al. 2004
"""
super().__init__()
self.n_gram = n_gram
self.smooth = smooth

self.add_state("c", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
self.add_state("r", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
self.add_state("numerator", torch.zeros(self.n_gram), dist_reduce_fx="sum")
self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum")

def compute(self):

trans_len = self.c.clone().detach()
ref_len = self.r.clone().detach()

if min(self.numerator) == 0.0:
return tensor(0.0, device=self.r.device)

if self.smooth:
precision_scores = (self.numerator + 1.0) / (self.denominator + 1.0)
else:
precision_scores = self.numerator / self.denominator

log_precision_scores = tensor([1.0 / self.n_gram] * self.n_gram, device=self.r.device) * torch.log(
precision_scores
)
geometric_mean = torch.exp(torch.sum(log_precision_scores))
brevity_penalty = tensor(1.0, device=self.r.device) if self.c > self.r else torch.exp(1 - (ref_len / trans_len))
bleu = brevity_penalty * geometric_mean
return bleu

def update(self, translate_corpus, reference_corpus) -> None:
"""
Actual metric computation
Args:
translate_corpus: An iterable of machine translated corpus
reference_corpus: An iterable of iterables of reference corpus
"""
for (translation, references) in zip(translate_corpus, reference_corpus):
self.c += len(translation)
ref_len_list = [len(ref) for ref in references]
ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
self.r += ref_len_list[ref_len_diff.index(min(ref_len_diff))]
translation_counter = _count_ngram(translation, self.n_gram)
reference_counter = Counter()

for ref in references:
reference_counter |= _count_ngram(ref, self.n_gram)

ngram_counter_clip = translation_counter & reference_counter

for counter_clip in ngram_counter_clip:
self.numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]
void(n_gram, smooth)

for counter in translation_counter:
self.denominator[len(counter) - 1] += translation_counter[counter]


class RougeMetric(Metric):
"""Metric used for automatic summarization. https://www.aclweb.org/anthology/W04-1013/

Example:

>>> target = "Is your name John".split()
>>> preds = "My name is John".split()
>>> rouge = RougeMetric() # doctest: +SKIP
>>> from pprint import pprint
>>> pprint(rouge(preds, target)) # doctest: +NORMALIZE_WHITESPACE +SKIP
{'rouge1_fmeasure': 0.25,
'rouge1_precision': 0.25,
'rouge1_recall': 0.25,
'rouge2_fmeasure': 0.0,
'rouge2_precision': 0.0,
'rouge2_recall': 0.0,
'rougeL_fmeasure': 0.25,
'rougeL_precision': 0.25,
'rougeL_recall': 0.25,
'rougeLsum_fmeasure': 0.25,
'rougeLsum_precision': 0.25,
'rougeLsum_recall': 0.25}
"""

@requires("text")
class RougeMetric(_ROUGEScore):
@_deprecated_text_metrics(target=_ROUGEScore)
def __init__(
self,
rouge_newline_sep: bool = False,
newline_sep: bool = False,
use_stemmer: bool = False,
rouge_keys: Tuple[str] = ("rouge1", "rouge2", "rougeL", "rougeLsum"),
):
super().__init__()

self.rouge_newline_sep = rouge_newline_sep
self.rouge_keys = rouge_keys
self.use_stemmer = use_stemmer
self.aggregator = RougeBatchAggregator()
self.scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=self.use_stemmer)

for key in rouge_keys:
self.add_state(key, [])

def update(self, pred_lns: List[str], tgt_lns: List[str]):
for pred, tgt in zip(pred_lns, tgt_lns):
# rougeLsum expects "\n" separated sentences within a summary
if self.rouge_newline_sep:
pred = add_newline_to_end_of_each_sentence(pred)
tgt = add_newline_to_end_of_each_sentence(tgt)
results = self.scorer.score(pred, tgt)
for key, score in results.items():
score = tensor([score.precision, score.recall, score.fmeasure])
getattr(self, key).append(score)

def compute(self) -> Dict[str, float]:
scores = {key: getattr(self, key) for key in self.rouge_keys}
self.aggregator.add_scores(scores)
result = self.aggregator.aggregate()
return format_rouge_results(result)

def __hash__(self):
# override to hash list objects.
# this is a bug in the upstream pytorch release.
hash_vals = [self.__class__.__name__]

for key in self._defaults:
value = getattr(self, key)
if isinstance(value, list):
value = tuple(value)
hash_vals.append(value)

return hash(tuple(hash_vals))


class RougeBatchAggregator(BootstrapAggregator):
    """Aggregates rouge scores and provides confidence intervals."""

    def aggregate(self):
        """Aggregate the stored scores into one ``AggregateScore`` per score type.

        Overrides the parent implementation so that the final results are
        wrapped in ``Score`` objects; this is needed because the stored scores
        are lists of torch tensors rather than ``Score`` instances.

        Return:
            A dict mapping each score type to an ``AggregateScore`` with
            ``low``, ``mid`` and ``high`` entries.
        """
        result = {}
        for score_type, scores in self._scores.items():
            # Stack scores into a 2-d matrix of (sample, measure).
            score_matrix = np.vstack(tuple(scores))
            # Percentiles are returned as (interval, measure).
            # NOTE(review): `_bootstrap_resample` is presumably inherited from
            # `BootstrapAggregator` -- confirm against rouge_score.
            percentiles = self._bootstrap_resample(score_matrix)
            # Extract the three intervals (low, mid, high).
            low, mid, high = (Score(*percentiles[j, :]) for j in range(3))
            result[score_type] = AggregateScore(low=low, mid=mid, high=high)
        return result

    def add_scores(self, scores):
        """Replace the stored scores with ``scores``.

        Args:
            scores: Mapping from score type to the collection of per-sample
                scores; ``aggregate`` iterates it with ``.items()``.
        """
        # Assignment, not accumulation: any previously stored scores are dropped.
        self._scores = scores


def format_rouge_results(result: Dict[str, "AggregateScore"], decimal_places: int = 4) -> Dict[str, float]:
    """Flatten aggregated rouge scores into a single-level dict of rounded values.

    Only the ``mid`` entry of each aggregate score is reported.

    Args:
        result: Mapping from rouge key to an aggregate score exposing a ``mid``
            attribute with ``precision``, ``recall`` and ``fmeasure`` attributes.
        decimal_places: Number of decimal places each value is rounded to.

    Return:
        A dict keyed by ``"{rouge_key}_{stat}"`` for every rouge key and each of
        the ``precision``, ``recall`` and ``fmeasure`` statistics.
    """
    # The annotation is quoted so that defining this function does not require
    # the optional `AggregateScore` symbol to be importable.
    flattened_result = {}
    for rouge_key, rouge_aggregate_score in result.items():
        mid = rouge_aggregate_score.mid  # same for all three stats; look it up once
        for stat in ("precision", "recall", "fmeasure"):
            flattened_result[f"{rouge_key}_{stat}"] = round(getattr(mid, stat), decimal_places)
    return flattened_result
void(newline_sep, use_stemmer, rouge_keys)
29 changes: 0 additions & 29 deletions flash/text/seq2seq/core/utils.py

This file was deleted.

6 changes: 3 additions & 3 deletions flash/text/seq2seq/summarization/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from typing import Any, Dict, List, Optional

import torch
from torchmetrics import ROUGEScore

from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE
from flash.text.seq2seq.core.metrics import RougeMetric
from flash.text.seq2seq.core.model import Seq2SeqTask


Expand Down Expand Up @@ -72,8 +72,8 @@ def __init__(
num_beams=num_beams,
enable_ort=enable_ort,
)
self.rouge = RougeMetric(
rouge_newline_sep=rouge_newline_sep,
self.rouge = ROUGEScore(
newline_sep=rouge_newline_sep,
use_stemmer=use_stemmer,
)

Expand Down
3 changes: 2 additions & 1 deletion flash/text/seq2seq/translation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
# limitations under the License.
from typing import Any, Dict, List, Optional

from torchmetrics import BLEUScore

from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE
from flash.text.seq2seq.core.metrics import BLEUScore
from flash.text.seq2seq.core.model import Seq2SeqTask


Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
packaging
numpy
torch>=1.7.1
torchmetrics>=0.4.0,!=0.5.1
torchmetrics>=0.5.1
pytorch-lightning>=1.4.0
pyDeprecate
pandas<1.3.0
Expand Down
2 changes: 1 addition & 1 deletion requirements/datatype_text.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
rouge-score>=0.0.4
sentencepiece>=0.1.95
filelock
transformers>=4.5
torchmetrics[text]>=0.5.1
datasets>=1.8,<1.13
34 changes: 0 additions & 34 deletions tests/text/seq2seq/core/test_metrics.py

This file was deleted.