Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[examples/run_s2s] remove task_specific_params and update rouge computation #10133

Merged
merged 5 commits into from
Feb 12, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 44 additions & 19 deletions examples/seq2seq/run_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
from dataclasses import dataclass, field
from typing import Optional

import nltk # Here to have a nice missing dependency error message early on
stas00 marked this conversation as resolved.
Show resolved Hide resolved
import numpy as np
from datasets import load_dataset, load_metric

import transformers
from filelock import FileLock
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
Expand All @@ -44,6 +46,10 @@
from transformers.trainer_utils import get_last_checkpoint, is_main_process


with FileLock(".lock") as lock:
nltk.download("punkt", quiet=True)


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -110,10 +116,22 @@ class DataTrainingArguments:
default=None,
metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
)
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
train_file: Optional[str] = field(
default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
)
validation_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
metadata={
"help": "An optional input evaluation data file to evaluate the metrics (rouge/sacreblue) on "
"(a jsonlines or csv file)."
},
)
test_file: Optional[str] = field(
default=None,
metadata={
"help": "An optional input test data file to evaluate the metrics (rouge/sacreblue) on "
"(a jsonlines or csv file)."
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
Expand Down Expand Up @@ -298,6 +316,9 @@ def main():
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.validation_file.split(".")[-1]
if data_args.test_file is not None:
data_files["test"] = data_args.test_file
extension = data_args.test_file.split(".")[-1]
datasets = load_dataset(extension, data_files=data_files)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
Expand Down Expand Up @@ -335,15 +356,7 @@ def main():
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

# Get the default prefix if None is passed.
if data_args.source_prefix is None:
task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
prefix = task_specific_params.get("prefix", "")
else:
prefix = ""
else:
prefix = data_args.source_prefix
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

# Preprocessing the datasets.
# We need to tokenize inputs and targets.
Expand Down Expand Up @@ -487,6 +500,19 @@ def preprocess_function(examples):
metric_name = "rouge" if data_args.task.startswith("summarization") else "sacrebleu"
metric = load_metric(metric_name)

def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds]
labels = [label.strip() for label in labels]

# rougeLSum expects newline after each sentence
if metric_name == "rouge":
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
else: # sacrebleu
labels = [[label] for label in labels]

return preds, labels

def compute_metrics(eval_preds):
preds, labels = eval_preds
if isinstance(preds, tuple):
Expand All @@ -498,22 +524,19 @@ def compute_metrics(eval_preds):
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

# Some simple post-processing
decoded_preds = [pred.strip() for pred in decoded_preds]
decoded_labels = [label.strip() for label in decoded_labels]
if metric_name == "sacrebleu":
decoded_labels = [[label] for label in decoded_labels]
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

result = metric.compute(predictions=decoded_preds, references=decoded_labels)

# Extract a few results from ROUGE
if metric_name == "rouge":
result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
# Extract a few results from ROUGE
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
else:
result = metric.compute(predictions=decoded_preds, references=decoded_labels)
result = {"bleu": result["score"]}

prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
result["gen_len"] = np.mean(prediction_lens)

result = {k: round(v, 4) for k, v in result.items()}
return result

# Initialize our Trainer
Expand Down Expand Up @@ -555,6 +578,7 @@ def compute_metrics(eval_preds):
logger.info("*** Evaluate ***")

results = trainer.evaluate(max_length=data_args.val_max_target_length, num_beams=data_args.num_beams)
results = {k: round(v, 4) for k, v in results.items()}

output_eval_file = os.path.join(training_args.output_dir, "eval_results_seq2seq.txt")
if trainer.is_world_process_zero():
Expand All @@ -574,6 +598,7 @@ def compute_metrics(eval_preds):
num_beams=data_args.num_beams,
)
test_metrics = test_results.metrics
test_metrics["test_loss"] = round(test_metrics["test_loss"], 4)

output_test_result_file = os.path.join(training_args.output_dir, "test_results_seq2seq.txt")
if trainer.is_world_process_zero():
Expand Down