
[run_summarization.py] wrong dataset leads to CUDA errors #11344

Closed
stas00 opened this issue Apr 20, 2021 · 24 comments · Fixed by #13559

stas00 (Contributor) commented Apr 20, 2021

Feeding --dataset_name cnn_dailymail to --model_name_or_path google/pegasus-xsum leads to a flood of errors from PyTorch; perhaps there is a way to detect that the dataset is inappropriate and raise a clear, relevant assertion instead?

You'd think that --dataset_name cnn_dailymail and --dataset_name xsum should be interchangeable...

python examples/seq2seq/run_summarization.py --model_name_or_path google/pegasus-xsum --do_train \
--do_eval --dataset_name cnn_dailymail --dataset_config "3.0.0"  \
--output_dir /tmp/tst-summarization --per_device_train_batch_size=1 --per_device_eval_batch_size=1 \
--overwrite_output_dir --predict_with_generate
[....]
/workspace/pytorch/aten/src/ATen/native/cuda/Indexing.cu:666: indexSelectLargeIndex: block: [290,0,0], thread: [96,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/workspace/pytorch/aten/src/ATen/native/cuda/Indexing.cu:666: indexSelectLargeIndex: block: [290,0,0], thread: [97,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/workspace/pytorch/aten/src/ATen/native/cuda/Indexing.cu:666: indexSelectLargeIndex: block: [290,0,0], thread: [98,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
(crashes w/o traceback here)

If I run it on a single GPU I get:

[...]
/workspace/pytorch/aten/src/ATen/native/cuda/Indexing.cu:666: indexSelectLargeIndex: block: [138,0,0], thread: [95,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
    return forward_call(*input, **kwargs)
  File "/mnt/nvme1/code/huggingface/transformers-gpt-neo-nan/src/transformers/models/pegasus/modeling_pegasus.py", line 763, in forward
    layer_outputs = encoder_layer(
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1015, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/nvme1/code/huggingface/transformers-gpt-neo-nan/src/transformers/models/pegasus/modeling_pegasus.py", line 323, in forward
    hidden_states, attn_weights, _ = self.self_attn(
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1015, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/nvme1/code/huggingface/transformers-gpt-neo-nan/src/transformers/models/pegasus/modeling_pegasus.py", line 190, in forward
    query_states = self.q_proj(hidden_states) * self.scaling
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1015, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 94, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/functional.py", line 1860, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`

Thanks.

@sgugger, @patil-suraj

stas00 (Contributor, Author) commented Apr 20, 2021

This fails too:

CUDA_LAUNCH_BLOCKING=1 python examples/seq2seq/run_summarization.py \
--model_name_or_path google/pegasus-xsum --do_eval --dataset_name xsum --output_dir output_dir \
--per_device_eval_batch_size=16 --predict_with_generate --max_val_samples 20
***** Running Evaluation *****
  Num examples = 20
  Batch size = 16
/workspace/pytorch/aten/src/ATen/native/cuda/Indexing.cu:666: indexSelectLargeIndex: block: [174,0,0], thread: [64,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
[... same assertion repeated for threads [65,0,0] through [95,0,0] ...]
Traceback (most recent call last):
  File "examples/seq2seq/run_summarization.py", line 591, in <module>
    main()
  File "examples/seq2seq/run_summarization.py", line 547, in main
    metrics = trainer.evaluate(
  File "/mnt/nvme1/code/huggingface/transformers-gpt-neo-nan/src/transformers/trainer_seq2seq.py", line 75, in evaluate
    return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
  File "/mnt/nvme1/code/huggingface/transformers-gpt-neo-nan/src/transformers/trainer.py", line 1853, in evaluate
    output = eval_loop(
  File "/mnt/nvme1/code/huggingface/transformers-gpt-neo-nan/src/transformers/trainer.py", line 2005, in evaluation_loop
    loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
  File "/mnt/nvme1/code/huggingface/transformers-gpt-neo-nan/src/transformers/trainer_seq2seq.py", line 167, in prediction_step
    generated_tokens = self.model.generate(
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/nvme1/code/huggingface/transformers-gpt-neo-nan/src/transformers/generation_utils.py", line 931, in generate
    model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
  File "/mnt/nvme1/code/huggingface/transformers-gpt-neo-nan/src/transformers/generation_utils.py", line 413, in _prepare_encoder_decoder_kwargs_for_generation
    model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1015, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/nvme1/code/huggingface/transformers-gpt-neo-nan/src/transformers/models/pegasus/modeling_pegasus.py", line 721, in forward
    embed_pos = self.embed_positions(input_shape)
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1015, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/nvme1/code/huggingface/transformers-gpt-neo-nan/src/transformers/models/pegasus/modeling_pegasus.py", line 139, in forward
    return super().forward(positions)
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/modules/sparse.py", line 156, in forward
    return F.embedding(
  File "/home/stas/anaconda3/envs/py38-pt18/lib/python3.8/site-packages/torch/nn/functional.py", line 2037, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: CUDA error: device-side assert triggered
Collecting environment information...
PyTorch version: 1.9.0a0+git548765d
Is debug build: False
CUDA used to build PyTorch: 11.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 11.2.152
GPU models and configuration:
GPU 0: GeForce GTX 1070 Ti
GPU 1: GeForce RTX 3090

sgugger (Collaborator) commented Apr 20, 2021

I'm not sure it's a dataset thing. I think there is something wrong inside the Pegasus model; there have been multiple issues with it not working with the Trainer.

stas00 (Contributor, Author) commented Apr 21, 2021

Hmm, after updating datasets to the latest version, the command line in the OP started to work. But it crashes in the same way if I add --max_train_samples 20 --max_val_samples 20.

xuyeliu commented Apr 24, 2021

Hi, do you know how to use the GPU when running run_summarization.py? I have 2 GPUs on my machine, but the script didn't use them... Thank you very much!

stas00 (Contributor, Author) commented Apr 24, 2021

@liubest, please kindly use https://discuss.huggingface.co/ if you run into trouble after reading the README.md, which should cover most questions about this example's usage.

xuyeliu commented May 1, 2021

> @liubest, please kindly use https://discuss.huggingface.co/ if you run into trouble after reading the README.md, which should cover most questions about this example's usage.

Thank you for your reply. I have one more question that I couldn't find an answer to in the forum. When using run_summarization.py, how can I train transformer models like t5-small or facebook/bart-large-cnn without loading pre-trained weights? I only want to train the original model architecture from scratch, not the pre-trained model. Thank you very much!

stas00 (Contributor, Author) commented May 2, 2021

You will probably find dozens of tutorials if you google: huggingface train model from scratch.

Please let's not derail this issue by asking unrelated questions. If you still have a problem please start a new Issue. Thank you!

patrickvonplaten (Contributor):

I'm also interested in solving this problem. @stas00, let me know if I should look into it.

stas00 (Contributor, Author) commented May 2, 2021

Yes, please, @patrickvonplaten - thank you!

patrickvonplaten (Contributor):

@stas00, I checked and the problem simply seems to be that max_source_length is too high. It's set to 1024 by default even though Pegasus can only handle 512. So, the following command should just run fine:

python examples/pytorch/summarization/run_summarization.py --model_name_or_path google/pegasus-xsum --do_train \
--do_eval --dataset_name cnn_dailymail --dataset_config "3.0.0"  \
--output_dir /tmp/tst-summarization --per_device_train_batch_size=1 --per_device_eval_batch_size=1 \
--overwrite_output_dir --predict_with_generate --max_source_length 512
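
For illustration, here is a minimal sketch of the kind of guard a script could apply before tokenization; the clamp_source_length helper and its placement are hypothetical, not part of run_summarization.py:

```python
from transformers import AutoConfig

def clamp_source_length(model_name_or_path: str, requested_length: int) -> int:
    """Cap the requested source length at the model's position-embedding size."""
    config = AutoConfig.from_pretrained(model_name_or_path)
    limit = getattr(config, "max_position_embeddings", None)
    if limit is not None and requested_length > limit:
        print(
            f"max_source_length={requested_length} exceeds "
            f"max_position_embeddings={limit} for {model_name_or_path}; using {limit}."
        )
        return limit
    return requested_length

# google/pegasus-xsum ships with max_position_embeddings=512, so the default
# max_source_length of 1024 would be clamped down to 512 here.
max_source_length = clamp_source_length("google/pegasus-xsum", 1024)
```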

patrickvonplaten (Contributor):

By the way, errors like /workspace/pytorch/aten/src/ATen/native/cuda/Indexing.cu:666: indexSelectLargeIndex: block: [174,0,0], thread: [95,0,0] Assertion `srcIndex < srcSelectDimSize` failed are in my experience very often out-of-index errors, and it helps to run the same code on CPU, which then gives a better error message.

stas00 (Contributor, Author) commented May 13, 2021

> @stas00, I checked and the problem simply seems to be that max_source_length is too high. It's set to 1024 by default even though Pegasus can only handle 512. So, the following command should just run fine:
>
> python examples/pytorch/summarization/run_summarization.py --model_name_or_path google/pegasus-xsum --do_train \
> --do_eval --dataset_name cnn_dailymail --dataset_config "3.0.0"  \
> --output_dir /tmp/tst-summarization --per_device_train_batch_size=1 --per_device_eval_batch_size=1 \
> --overwrite_output_dir --predict_with_generate --max_source_length 512

Thank you for investigating this, @patrickvonplaten - could we programmatically defend against this mismatch?

stas00 (Contributor, Author) commented May 13, 2021

> By the way, errors like /workspace/pytorch/aten/src/ATen/native/cuda/Indexing.cu:666: indexSelectLargeIndex: block: [174,0,0], thread: [95,0,0] Assertion `srcIndex < srcSelectDimSize` failed are in my experience very often out-of-index errors, and it helps to run the same code on CPU, which then gives a better error message.

Yes! So run with CUDA_VISIBLE_DEVICES="".

We should document this at https://huggingface.co/transformers/troubleshooting.html

CUDA_LAUNCH_BLOCKING=1 is another important debugging technique for the GPU.
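
To illustrate why the CPU run helps, here is a minimal, self-contained sketch (arbitrary sizes, not the Pegasus code): the same out-of-range position index that triggers the opaque device-side assert on GPU raises a readable IndexError on CPU:

```python
import torch
import torch.nn as nn

# A position-embedding table with 512 entries, as in google/pegasus-xsum.
embed_positions = nn.Embedding(512, 64)

# Positions 512..1023 are out of range. On CPU this fails immediately with
# "IndexError: index out of range in self" and a clean Python traceback;
# on GPU the same lookup surfaces only as the asynchronous
# "Assertion `srcIndex < srcSelectDimSize` failed" device-side assert.
positions = torch.arange(1024)
try:
    embed_positions(positions)
except IndexError as err:
    print(err)
```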

patil-suraj (Contributor):

@stas00, @patrickvonplaten, Pegasus actually uses SinusoidalPositionalEmbedding, so there is no sequence-length limit. We should resize the embedding if the current length is greater than the default length. That's what we do in FSMT and M2M100.
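
For illustration, a minimal sketch of what such a resize could look like for fixed sinusoidal embeddings; the helper names and the weight layout follow the standard Transformer sin/cos scheme and are assumptions here, not the FSMT/M2M100 implementation:

```python
import math
import torch
import torch.nn as nn

def sinusoidal_weights(num_positions: int, dim: int) -> torch.Tensor:
    """Deterministic sin/cos position table (dim assumed even)."""
    position = torch.arange(num_positions, dtype=torch.float).unsqueeze(1)
    inv_freq = torch.exp(
        torch.arange(0, dim, 2, dtype=torch.float) * (-math.log(10000.0) / dim)
    )
    table = torch.zeros(num_positions, dim)
    table[:, 0::2] = torch.sin(position * inv_freq)
    table[:, 1::2] = torch.cos(position * inv_freq)
    return table

def resize_sinusoidal_embedding(embed: nn.Embedding, new_num_positions: int) -> nn.Embedding:
    """Rebuild the table at the new size; nothing is lost because the weights are
    a deterministic function of the position index, not trained parameters."""
    resized = nn.Embedding(new_num_positions, embed.embedding_dim)
    resized.weight = nn.Parameter(
        sinusoidal_weights(new_num_positions, embed.embedding_dim),
        requires_grad=False,
    )
    return resized

old = nn.Embedding(512, 64)                    # e.g. the default 512-position table
new = resize_sinusoidal_embedding(old, 1024)   # now covers 1024-token inputs
```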

patrickvonplaten (Contributor):

On the other hand, Pegasus has only been trained on a max length of 512, so I'm not sure whether it's a good idea to "silently" extend the input to a length of 1024, since the model will probably produce garbage. Or have you had different experiences, @stas00 @patil-suraj?

I think I'd prefer to reduce the max length automatically to model.config.max_position_embeddings and throw a warning.

patil-suraj (Contributor):

That makes sense, but even though Pegasus is pre-trained with 512, different max_position_embeddings values are used when fine-tuning.

For example, for the xsum model max_position_embeddings is 512: https://huggingface.co/google/pegasus-xsum/blob/main/config.json#L44

and for cnn_dm and pubmed it is 1024: https://huggingface.co/google/pegasus-pubmed/blob/main/config.json#L38

stas00 (Contributor, Author) commented May 14, 2021

> I think I'd prefer to reduce the max length automatically to model.config.max_position_embeddings and throw a warning.

This is very likely to go unnoticed.

We overuse warnings; they are fine when you have 5 lines of output, but when you have hundreds of them the chance that the user will see one is close to 0. Especially when things seem to work, albeit with setting changes behind the scenes.

I feel that @patil-suraj's suggestion of granting the user's wish is the better one: if they get garbage, it's loud and clear that they did something wrong. Here, a warning that they asked for a longer value than the preset will work, as they are likely to search for the culprit.

And in situations where we know that what the user is asking for is surely not going to work, we should assert.

patrickvonplaten (Contributor):

Ok, good arguments! IMO we should only allow this resizing for models that use sinusoidal position embeddings, i.e. position embeddings that have requires_grad set to False.

In terms of implementation, I'd suggest adding a general resize_position_embeddings(self, max_position_embeddings) to PreTrainedModel that throws a NotImplementedError and is then overridden in Pegasus.
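
For illustration, a minimal sketch of that API shape; the class names and bodies below are hypothetical stand-ins, not the actual transformers implementation:

```python
import torch.nn as nn

class PreTrainedModelSketch:
    """Stand-in for transformers.PreTrainedModel showing only the proposed hook."""

    def resize_position_embeddings(self, max_position_embeddings: int):
        raise NotImplementedError(
            f"`resize_position_embeddings` is not implemented for {self.__class__.__name__}"
        )

class PegasusSketch(PreTrainedModelSketch):
    """Stand-in for a Pegasus model overriding the hook for its sinusoidal table."""

    def __init__(self, max_position_embeddings: int = 512, d_model: int = 1024):
        self.embed_positions = nn.Embedding(max_position_embeddings, d_model)
        self.max_position_embeddings = max_position_embeddings

    def resize_position_embeddings(self, max_position_embeddings: int):
        # Sinusoidal embeddings are not trained, so growing them means rebuilding
        # the table at the new size (the sinusoidal fill itself is shown in the
        # earlier sketch; a plain nn.Embedding stands in for it here) and keeping
        # the config value in sync.
        d_model = self.embed_positions.embedding_dim
        self.embed_positions = nn.Embedding(max_position_embeddings, d_model)
        self.max_position_embeddings = max_position_embeddings

PegasusSketch().resize_position_embeddings(1024)  # extends the table from 512 to 1024
```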

patrickvonplaten (Contributor):

We should also overwrite the config.max_position_embeddings when doing so

stas00 (Contributor, Author) commented Jul 5, 2021

@patrickvonplaten, do you have some bandwidth to come back to this so that we can complete the issue? It looks like it fell through the cracks. Thank you.

patrickvonplaten (Contributor):

Ok, so the plan is to:

  1. Add a resize_position_embeddings to PreTrainedModel, just like we do for the word embeddings
  2. resize_position_embeddings should probably log or warn depending on whether the embeddings are sinusoidal or learned
  3. The function should overwrite config.max_position_embeddings

=> Happy to open a PR for this one, but it would be great to first hear @LysandreJik's and @sgugger's opinions on it as well.
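
For illustration, a hypothetical sketch of the script-side glue this plan implies; this is not the code from the eventual PR, just one way an example script could use the proposed hook, falling back to truncation when a model does not implement it:

```python
def adjust_positions(model, config, max_source_length: int) -> int:
    """Grow the model's position embeddings if possible, otherwise truncate inputs."""
    current = getattr(config, "max_position_embeddings", None)
    if current is None or max_source_length <= current:
        return max_source_length
    try:
        # Proposed hook: models with sinusoidal (non-trained) position embeddings
        # can rebuild their table; the hook is also expected to update the config.
        model.resize_position_embeddings(max_source_length)
        return max_source_length
    except NotImplementedError:
        print(
            f"{model.__class__.__name__} cannot resize position embeddings; "
            f"truncating inputs from {max_source_length} to {current} tokens."
        )
        return current
```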

sgugger (Collaborator) commented Aug 30, 2021

Works for me!

omerarshad:

@sgugger, can you share your working code?

sgugger (Collaborator) commented Aug 31, 2021

No, I meant that the plan suggested by @patrickvonplaten in the message above works for me.
