Skip to content

Commit

Permalink
[examples/run_s2s] remove task_specific_params and update rouge compu…
Browse files Browse the repository at this point in the history
…tation (#10133)

* fix rouge metrics and task specific params

* fix typo

* round metrics

* typo

* remove task_specific_params
  • Loading branch information
patil-suraj authored Feb 12, 2021
1 parent 3124577 commit f51188c
Showing 1 changed file with 44 additions and 19 deletions.
63 changes: 44 additions & 19 deletions examples/seq2seq/run_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
from dataclasses import dataclass, field
from typing import Optional

import nltk # Here to have a nice missing dependency error message early on
import numpy as np
from datasets import load_dataset, load_metric

import transformers
from filelock import FileLock
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
Expand All @@ -44,6 +46,10 @@
from transformers.trainer_utils import get_last_checkpoint, is_main_process


with FileLock(".lock") as lock:
nltk.download("punkt", quiet=True)


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -110,10 +116,22 @@ class DataTrainingArguments:
default=None,
metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
)
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
train_file: Optional[str] = field(
default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
)
validation_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
metadata={
"help": "An optional input evaluation data file to evaluate the metrics (rouge/sacreblue) on "
"(a jsonlines or csv file)."
},
)
test_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input test data file to evaluate the metrics (rouge/sacreblue) on "
"(a jsonlines or csv file)."
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
Expand Down Expand Up @@ -298,6 +316,9 @@ def main():
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.validation_file.split(".")[-1]
if data_args.test_file is not None:
data_files["test"] = data_args.test_file
extension = data_args.test_file.split(".")[-1]
datasets = load_dataset(extension, data_files=data_files)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
Expand Down Expand Up @@ -335,15 +356,7 @@ def main():
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

# Get the default prefix if None is passed.
if data_args.source_prefix is None:
task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
prefix = task_specific_params.get("prefix", "")
else:
prefix = ""
else:
prefix = data_args.source_prefix
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

# Preprocessing the datasets.
# We need to tokenize inputs and targets.
Expand Down Expand Up @@ -487,6 +500,19 @@ def preprocess_function(examples):
metric_name = "rouge" if data_args.task.startswith("summarization") else "sacrebleu"
metric = load_metric(metric_name)

def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds]
labels = [label.strip() for label in labels]

# rougeLSum expects newline after each sentence
if metric_name == "rouge":
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
else: # sacrebleu
labels = [[label] for label in labels]

return preds, labels

def compute_metrics(eval_preds):
preds, labels = eval_preds
if isinstance(preds, tuple):
Expand All @@ -498,22 +524,19 @@ def compute_metrics(eval_preds):
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

# Some simple post-processing
decoded_preds = [pred.strip() for pred in decoded_preds]
decoded_labels = [label.strip() for label in decoded_labels]
if metric_name == "sacrebleu":
decoded_labels = [[label] for label in decoded_labels]
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

result = metric.compute(predictions=decoded_preds, references=decoded_labels)

# Extract a few results from ROUGE
if metric_name == "rouge":
result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
# Extract a few results from ROUGE
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
else:
result = metric.compute(predictions=decoded_preds, references=decoded_labels)
result = {"bleu": result["score"]}

prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
result["gen_len"] = np.mean(prediction_lens)

result = {k: round(v, 4) for k, v in result.items()}
return result

# Initialize our Trainer
Expand Down Expand Up @@ -555,6 +578,7 @@ def compute_metrics(eval_preds):
logger.info("*** Evaluate ***")

results = trainer.evaluate(max_length=data_args.val_max_target_length, num_beams=data_args.num_beams)
results = {k: round(v, 4) for k, v in results.items()}

output_eval_file = os.path.join(training_args.output_dir, "eval_results_seq2seq.txt")
if trainer.is_world_process_zero():
Expand All @@ -574,6 +598,7 @@ def compute_metrics(eval_preds):
num_beams=data_args.num_beams,
)
test_metrics = test_results.metrics
test_metrics["test_loss"] = round(test_metrics["test_loss"], 4)

output_test_result_file = os.path.join(training_args.output_dir, "test_results_seq2seq.txt")
if trainer.is_world_process_zero():
Expand Down

0 comments on commit f51188c

Please sign in to comment.