google/pegasus-cnn_dailymail generates blank file #11289

Closed · chz816 opened this issue Apr 16, 2021 · 10 comments

chz816 commented Apr 16, 2021

Environment info

  • transformers version: 4.2.0 and 4.5.1
  • Platform: linux
  • Python version: 3.6
  • PyTorch version (GPU?): 1.7.1
  • Tensorflow version (GPU?): NA
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Yes (I also tried without distributed; the problem persists)

Who can help

@patrickvonplaten, @patil-suraj

Information

Model I am using (Bert, XLNet ...): google/pegasus-cnn_dailymail

The problem arises when using:

  • the official example scripts: examples/legacy/seq2seq/run_distributed_eval.py (see below)

The tasks I am working on are:

  • an official GLUE/SQUaD task: summarization with ROUGE
  • my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

I am trying to generate summaries from Pegasus on the CNN/DM and XSUM datasets. I use the same datasets shared by HuggingFace (from the README.md at https://github.com/huggingface/transformers/tree/master/examples/legacy/seq2seq). My experiments run on 3 V100 GPUs. I use google/pegasus-cnn_dailymail for CNN/DM and google/pegasus-xsum for XSUM.

  1. The results on XSUM are as expected. I run the following command and get ROUGE scores of: {'rouge1': 47.0271, 'rouge2': 24.4924, 'rougeL': 39.2529, 'n_obs': 11333, 'seconds_per_sample': 0.035, 'n_gpus': 3}
python -m torch.distributed.launch --nproc_per_node=3  run_distributed_eval.py \
    --model_name google/pegasus-xsum  \
    --save_dir $OUTPUT_DIR \
    --data_dir $DATA_DIR \
    --bs 64 \
    --fp16
  2. I was expecting similarly strong (SOTA) performance on CNN/DM, so I ran the following command and received: {"n_gpus": 3, "n_obs": 11490, "rouge1": 0.1602, "rouge2": 0.084, "rougeL": 0.1134, "seconds_per_sample": 0.1282}.

(Note: the batch size is reduced here due to memory limits. Although the experiments run on the same devices, CNN/DM needs more memory because its inputs are longer.)

python -m torch.distributed.launch --nproc_per_node=3  run_distributed_eval.py \
    --model_name google/pegasus-cnn_dailymail  \
    --save_dir $OUTPUT_DIR \
    --data_dir $DATA_DIR \
    --bs 32 \
    --fp16
  3. I looked at the generated test_generations.txt file to figure out why google/pegasus-cnn_dailymail doesn't work, and found that most of its lines are blank. (See the attached image for an example.)

[screenshot: test_generations.txt, mostly blank lines]

Expected behavior

It is strange that google/pegasus-xsum works perfectly while google/pegasus-cnn_dailymail does not generate summaries successfully. I was confused, so I switched the transformers version (4.2.0 and 4.5.1) and re-ran the experiments on different GPUs, but the problem persists. Could you please give me any suggestions? Thank you!

patil-suraj (Contributor) commented
Hi @chz816

I can reproduce the issue. This is because pegasus doesn't really work with fp16: it was trained in bfloat16, so in most cases it overflows and returns nan logits. The model works as expected in fp32, so if you run the above command without the --fp16 flag, it should give the expected results.
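If it helps to see it directly, here is a minimal sketch (illustrative, not the repro script above; it assumes a CUDA device, and the test sentence is made up) that runs one forward pass in half precision and checks the logits:

import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer

name = "google/pegasus-cnn_dailymail"
tokenizer = PegasusTokenizer.from_pretrained(name)
model = PegasusForConditionalGeneration.from_pretrained(name).half().cuda().eval()

batch = tokenizer(["A short test article."], return_tensors="pt").to("cuda")
start = torch.tensor([[model.config.decoder_start_token_id]], device="cuda")
with torch.no_grad():
    out = model(**batch, decoder_input_ids=start)
# Expected to print True in fp16 per the overflow described above;
# dropping .half() (i.e. fp32) should give finite logits instead.
print(torch.isnan(out.logits).any().item() or torch.isinf(out.logits).any().item())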

cc @stas00

chz816 (Author) commented Apr 20, 2021

Thank you @patil-suraj!

I have generated the summaries using pegasus-cnn_dailymail with the following performance: {'rouge1': 43.146, 'rouge2': 20.7292, 'rougeL': 30.4596, 'n_obs': 11490, 'seconds_per_sample': 0.2415, 'n_gpus': 3}. It is lower than expected, but I think the smaller batch size, forced by the memory limitation, may explain it.

python -m torch.distributed.launch --nproc_per_node=3  run_distributed_eval.py \
    --model_name google/pegasus-cnn_dailymail  \
    --save_dir $OUTPUT_DIR \
    --data_dir $DATA_DIR \
    --bs 16

Can you maybe explain why this problem does not exist for google/pegasus-xsum? Thank you!

stas00 (Contributor) commented Apr 20, 2021

As @patil-suraj pointed out, many models trained in bfloat16 can't run under mixed-precision fp16 (although the PyTorch team is discussing bfloat16 mixed precision).
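(Just to make the range difference concrete, a minimal sketch, nothing model-specific:)

import torch
# bfloat16 keeps fp32's exponent range at reduced mantissa precision,
# while fp16 has a much narrower range, so values that are fine in
# bfloat16/fp32 can overflow or underflow in fp16.
print(torch.finfo(torch.bfloat16).tiny, torch.finfo(torch.bfloat16).max)
# ~1.18e-38 ~3.39e+38
print(torch.finfo(torch.float16).tiny, torch.finfo(torch.float16).max)
# ~6.10e-05 65504.0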

pegasus-cnn_dailymail has an issue of underflow under fp16:

Let's take a single frame, the Linear forward for lm_head:

fp32:

abs min  abs max  metadata
                  lm_head Linear
4.66e-10 1.13e+01 weight
6.29e-07 4.47e+00 input[0]
1.63e-07 3.00e+01 output

fp16:

abs min  abs max  metadata
                  lm_head Linear
0.00e+00 1.13e+01 weight
6.76e-07 5.38e+00 input[0]
0.00e+00 3.08e+01 output

As you can see, 4.66e-10 underflows to 0.0 under fp16.

edit: well, actually this would be the case if we did model.half() (which is what deepspeed does, and that's where it'd immediately underflow on the very first use), so here it's probably something else. I will need some time to try to understand what's going on here.
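(To see that static model.half() underflow in isolation, a tiny sketch:)

import torch

w = torch.tensor(4.66e-10)   # the abs-min lm_head weight from the fp32 trace
print(w.half())              # tensor(0., dtype=torch.float16): flushed to zero
# fp16 subnormals bottom out near 5.96e-08, which is exactly the "abs min"
# value that keeps showing up in the fp16 traces; anything smaller becomes 0.0.
print(torch.tensor(5.96e-08).half())  # still representable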

The abs min/max report above is from WIP PR #11274 - still polishing some nuances, but it should be ready soon.

Let me check google/pegasus-xsum

patil-suraj (Contributor) commented
Regarding the cnn_dailymail scores, please see this issue #6844

stas00 (Contributor) commented Apr 20, 2021

@chz816, meanwhile, could you please give me a way to reproduce your case? Ideally with a public dataset, and preferably with the current version of the examples (master or the last release), which would use examples/seq2seq/run_summarization.py

e.g.:

python examples/seq2seq/run_summarization.py --model_name_or_path google/pegasus-cnn_dailymail \
--do_train --do_eval --dataset_name cnn_dailymail --dataset_config "3.0.0" --source_prefix \
"summarize: " --output_dir /tmp/tst-summarization --per_device_train_batch_size=1 \
--per_device_eval_batch_size=1 --overwrite_output_dir --predict_with_generate 

stas00 (Contributor) commented Apr 20, 2021

Thank you, @patil-suraj.

Oh, this is the legacy script, so it does:

    if fp16:
        model = model.half()

stas00 (Contributor) commented Apr 20, 2021

wget https://cdn-datasets.huggingface.co/summarization/pegasus_data/cnn_dailymail.tar.gz
tar -xvzf cnn_dailymail.tar.gz
python -m torch.distributed.launch --nproc_per_node=1  run_distributed_eval.py \
--model_name google/pegasus-cnn_dailymail  --save_dir output_dir --data_dir cnn_dailymail \
--bs 8 --fp16

So the detection is quick (I had to bolt it on manually, since this script isn't using the Trainer; a minimal sketch of such a hook follows the log below):

Detected inf/nan during batch_number=0
Last 10 forward frames:
abs min  abs max  metadata
                  model.encoder.layers.14.fc1 Linear
0.00e+00 1.88e+01 weight
2.73e-05 2.54e+00 bias
5.96e-08 9.05e+00 input[0]
0.00e+00 3.16e+02 output
                  model.encoder.layers.14.fc2 Linear
5.96e-08 3.29e+01 weight
5.40e-03 2.66e+01 bias
0.00e+00 1.03e+02 input[0]
0.00e+00 8.00e+03 output
                  model.encoder.layers.14 PegasusEncoderLayer
0.00e+00 6.45e+04 input[0]
0.00e+00 0.00e+00 input[1]
0.00e+00 6.45e+04 output[0]
                  model.encoder.layers.15.self_attn_layer_norm LayerNorm
5.63e-03 3.85e-01 weight
1.69e-05 2.49e-01 bias
0.00e+00 6.45e+04 input[0]
0.00e+00 1.50e+00 output
                  model.encoder.layers.15.self_attn.q_proj Linear
8.34e-07 2.95e+00 weight
0.00e+00 0.00e+00 bias
0.00e+00 1.50e+00 input[0]
5.96e-08 8.52e+00 output
                  model.encoder.layers.15.self_attn.k_proj Linear
2.38e-07 1.85e+00 weight
0.00e+00 0.00e+00 bias
0.00e+00 1.50e+00 input[0]
1.19e-07 9.30e+00 output
                  model.encoder.layers.15.self_attn.v_proj Linear
5.96e-08 4.03e+00 weight
0.00e+00 0.00e+00 bias
0.00e+00 1.50e+00 input[0]
6.56e-07 2.95e+01 output
                  model.encoder.layers.15.self_attn.out_proj Linear
5.96e-08 2.25e+01 weight
0.00e+00 0.00e+00 bias
5.96e-08 1.25e+01 input[0]
3.58e-07 1.29e+03 output
                  model.encoder.layers.15.self_attn PegasusAttention
3.58e-07 1.29e+03 output[0]
             None output[1]
             None output[2]
                  model.encoder.layers.15.final_layer_norm LayerNorm
7.32e-02 2.69e+00 weight
2.00e-05 1.02e+00 bias
0.00e+00      inf input[0]
     nan      nan output
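For reference, here is a minimal sketch of that kind of bolted-on detection using forward hooks (illustrative only, not the actual PR #11274 code; attach_inf_nan_detector is a made-up helper name):

import torch

def attach_inf_nan_detector(model):
    # Map each submodule back to its qualified name for reporting.
    names = {m: n for n, m in model.named_modules()}

    def hook(module, inputs, output):
        outs = output if isinstance(output, tuple) else (output,)
        for t in outs:
            if torch.is_tensor(t) and (torch.isinf(t).any() or torch.isnan(t).any()):
                print(f"inf/nan in output of {names[module]} "
                      f"(abs max {t.abs().max().item():.2e})")

    for m in model.modules():
        m.register_forward_hook(hook)

# Call attach_inf_nan_detector(model) once before generation; the first
# module whose output goes inf/nan is then reported by name.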

stas00 (Contributor) commented Apr 20, 2021

I'm able to reproduce this with the "modern" version of the script:

rm -rf output_dir; USE_TF=0 PYTHONPATH=src python examples/seq2seq/run_summarization.py \
--model_name_or_path google/pegasus-cnn_dailymail --do_eval --dataset_name cnn_dailymail \
--dataset_config "3.0.0" --output_dir output_dir \
--per_device_eval_batch_size=16 --predict_with_generate --fp16_full_eval --max_val_samples 10

[...]

***** eval metrics *****
  eval_gen_len              =        9.0
  eval_loss                 =        nan
  eval_mem_cpu_alloc_delta  =      -55MB
  eval_mem_cpu_peaked_delta =       55MB
  eval_mem_gpu_alloc_delta  =     1089MB
  eval_mem_gpu_peaked_delta =     7241MB
  eval_rouge1               =        0.0
  eval_rouge2               =        0.0
  eval_rougeL               =        0.0
  eval_rougeLsum            =        0.0
  eval_runtime              = 0:00:07.71
  eval_samples              =         10
  eval_samples_per_second   =      1.295
  init_mem_cpu_alloc_delta  =        0MB
  init_mem_cpu_peaked_delta =        0MB
  init_mem_gpu_alloc_delta  =        0MB
  init_mem_gpu_peaked_delta =        0MB

chz816 (Author) commented Apr 20, 2021

Thank you for your response, @stas00! Yes, I am able to resolve the issue without --fp16, but I am still a little confused about why google/pegasus-xsum works well with the --fp16 flag, since both come from the same seq2seq model. Any ideas? Thank you!

stas00 (Contributor) commented Apr 21, 2021

For some reason I can't even run google/pegasus-xsum (#11344), so I'm not able to look inside.

I can only guess that perhaps google/pegasus-xsum was trained in mixed precision fp16?
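One hedged way to probe that guess from the weights alone (this only inspects the parameters, not the activations, so it's indicative rather than conclusive):

import torch
from transformers import PegasusForConditionalGeneration

tiny = torch.finfo(torch.float16).tiny  # smallest normal fp16 value
for name in ("google/pegasus-xsum", "google/pegasus-cnn_dailymail"):
    model = PegasusForConditionalGeneration.from_pretrained(name)
    # Count nonzero weights too small to survive as fp16 normals.
    n_under = sum(((p.abs() < tiny) & (p != 0)).sum().item()
                  for p in model.parameters())
    n_total = sum(p.numel() for p in model.parameters())
    print(f"{name}: {n_under / n_total:.4%} of weights below fp16 normal range")

A checkpoint whose weights sit comfortably inside fp16's range would at least be consistent with fp16 (or fp16-safe) training.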

chz816 closed this as completed Apr 30, 2021