google/pegasus-cnn_dailymail generates blank file #11289
Comments
Hi @chz816, I can reproduce the issue. This is because pegasus doesn't really work with fp16. cc @stas00
Thank you @patil-suraj! I have generated the summaries using:

```shell
python -m torch.distributed.launch --nproc_per_node=3 run_distributed_eval.py \
    --model_name google/pegasus-cnn_dailymail \
    --save_dir $OUTPUT_DIR \
    --data_dir $DATA_DIR \
    --bs 16
```

Can you maybe explain why this problem does not exist for `google/pegasus-xsum`?
As @patil-suraj pointed out, many models trained in mixed precision don't survive fp16 inference: fp16's maximum representable value is only 65504, so large activations overflow to inf.

Let's take a single frame - the Linear forward - for fp32:

[debug output not captured in this copy]

fp16:

[debug output not captured in this copy]

As you can see [...]

edit: well, actually this would be the case if we did [...]

This is from WIP PR #11274 - still polishing some nuances, but it should be ready soon. Let me check.
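The overflow mechanism can be demonstrated without the model itself. This is a minimal numpy sketch (illustrative only, not the actual debug output from the PR): a dot product whose result exceeds fp16's maximum of 65504 is exact in fp32 but becomes inf in fp16.

```python
import numpy as np

# fp16 can represent magnitudes only up to 65504; fp32 goes up to ~3.4e38.
print(np.finfo(np.float16).max)   # → 65504.0
print(np.finfo(np.float32).max)   # → ~3.4e38

# A dot product whose result exceeds 65504 is fine in fp32 ...
x32 = np.full(64, 40.0, dtype=np.float32)
w32 = np.full(64, 40.0, dtype=np.float32)
print(np.dot(x32, w32))           # → 102400.0

# ... but overflows to inf in fp16, and the inf then propagates through
# the rest of the network, eventually producing garbage (or empty) output.
x16 = x32.astype(np.float16)
w16 = w32.astype(np.float16)
print(np.dot(x16, w16))           # → inf
```

This is why a model whose activations stay small (as pegasus-xsum's apparently do) can run in fp16, while one with larger activations breaks.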
Regarding the cnn_dailymail scores, please see this issue: #6844
@chz816, meanwhile could you please give me a way to reproduce your case? Ideally with some public dataset, and preferably with the current version of the examples (master or the last release), which would be using e.g.: [...]
Thank you, @patil-suraj. Oh, this is the legacy script, so it does do: [...]
So the detection is quick (had to bolt it on manually, since this script isn't using the `Trainer`): [...]
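For illustration, the kind of check being bolted on can be sketched as a tiny helper (a hypothetical `detect_overflow`, not the actual transformers debug utility, which is more elaborate and hooks into the model's modules):

```python
import numpy as np

def detect_overflow(arr, name):
    """Return True and print a report if `arr` contains inf or nan.

    Hypothetical helper for illustration only; the real detection in
    transformers inspects intermediate activations module by module.
    """
    arr = np.asarray(arr)
    has_inf = bool(np.isinf(arr).any())
    has_nan = bool(np.isnan(arr).any())
    if has_inf or has_nan:
        print(f"{name}: inf={has_inf} nan={has_nan}")
    return has_inf or has_nan

# Example: a value that is representable in fp32 but overflows in fp16
act = np.array([1.0, 70000.0, -3.0], dtype=np.float32)
detect_overflow(act, "fc1.fp32")                     # fine, returns False
detect_overflow(act.astype(np.float16), "fc1.fp16")  # 70000 -> inf, returns True
```

Sprinkling such checks after each layer quickly pinpoints the first place where fp16 values blow up.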
I'm able to reproduce this with the "modern" version of the script:
Thank you for your response @stas00! Yeah, I am able to resolve the issue without `--fp16`.
For some reason I can't even run [...]. I can only guess that perhaps [...]
Environment info

- `transformers` version: 4.2.0 and 4.5.1

Who can help

@patrickvonplaten, @patil-suraj

Information

Model I am using (Bert, XLNet ...): google/pegasus-cnn_dailymail

The problem arises when using: the official example scripts

The tasks I am working on is: summarization
To reproduce
Steps to reproduce the behavior:
I am trying to generate the summaries from Pegasus on the CNN/DM and XSUM datasets. I use the same datasets shared by HuggingFace (from the README.md in https://github.com/huggingface/transformers/tree/master/examples/legacy/seq2seq). My experiments are run on 3 V100 GPUs. I use `google/pegasus-cnn_dailymail` for CNN/DM and `google/pegasus-xsum` for XSUM.

Results for XSUM: {'rouge1': 47.0271, 'rouge2': 24.4924, 'rougeL': 39.2529, 'n_obs': 11333, 'seconds_per_sample': 0.035, 'n_gpus': 3}

Results for CNN/DM: {"n_gpus": 3, "n_obs": 11490, "rouge1": 0.1602, "rouge2": 0.084, "rougeL": 0.1134, "seconds_per_sample": 0.1282}

(Note: the batch size is changed here due to memory limitations. Although the experiments are performed on the same devices, CNN/DM requires more memory given the characteristics of the dataset itself.)
I checked the `test_generations.txt` file to try to figure out why `google/pegasus-cnn_dailymail` doesn't work, and found that most of the lines in `test_generations.txt` are blank. (Please see the attached image for an example.)

Expected behavior
It is so weird that `google/pegasus-xsum` works perfectly while `google/pegasus-cnn_dailymail` does not generate summaries successfully. I was confused, so I switched the transformers version (4.2.0 and 4.5.1) and re-ran the experiments on different GPUs, but the problem persists. Could you please give me any suggestions? Thank you!