diff --git a/t5/evaluation/metrics.py b/t5/evaluation/metrics.py index f51ef62f..908cf828 100644 --- a/t5/evaluation/metrics.py +++ b/t5/evaluation/metrics.py @@ -65,20 +65,22 @@ def bleu(targets, predictions): return {"bleu": bleu_score.score} -def rouge(targets, predictions, score_keys=None): +def rouge(targets, predictions, score_keys=None, spm_model=None): """Computes rouge score. Args: targets: list of strings predictions: list of strings score_keys: list of strings with the keys to compute. + spm_model: string, path to SentencePieceModel model. If provided, this model + is used for tokenizing the targets and predictions. Returns: dict with score_key: rouge score across all targets and predictions """ if score_keys is None: score_keys = ["rouge1", "rouge2", "rougeLsum"] - scorer = rouge_scorer.RougeScorer(score_keys) + scorer = rouge_scorer.RougeScorer(score_keys, spm_model=spm_model) aggregator = scoring.BootstrapAggregator() def _prepare_summary(summary):