Generation / FIX: Fix multi-device generation #30746
Conversation
The fix is to initialize the special tokens on the correct device at all times; I updated the description of the PR.
cc @gante!
Added a suggestion to enable this change on all modalities!
src/transformers/generation/utils.py (outdated):

device = None
if "input_ids" in model_kwargs and isinstance(model_kwargs["input_ids"], torch.Tensor):
    device = model_kwargs["input_ids"].device

self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
If I get it right: the `device` comes from the main model input, and not from the model itself. Assuming what I wrote above is correct, we should get the `device` variable after the `_prepare_model_inputs` call, which extracts the main model input from the different keywords we might see (for instance, Whisper does not use `input_ids`). In that case, I would move these lines to after L1532 (currently `batch_size = inputs_tensor.shape[0]`), and use `device=inputs_tensor.device` :D
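A rough sketch of how that suggestion reads inside `generate()` (a fragment of the method body; the surrounding calls and argument names are inferred from the snippets above, not copied from the exact diff):

```python
# Resolve the main model input first: _prepare_model_inputs picks the right
# keyword for each modality (input_ids for text models, input_features for Whisper, ...).
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
    inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]

# Take the device from the resolved input tensor rather than from the model,
# whose .device can report "meta" when its weights are offloaded.
device = inputs_tensor.device
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
```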
Makes total sense! Done!
perfect, thank you for iterating 👌
Thanks! cc @ArthurZucker for the final review
Thanks for fixing. A small test is welcome (instead of the slow one!) to make sure we catch this earlier!
@@ -476,6 +476,7 @@ def _prepare_attention_mask_for_generation(
    )
    can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
    attention_mask_from_padding = inputs.ne(pad_token_id).long()
weird that this is changed 😄
Thanks! I don't think we can add tests, as they would require a GPU; this is implicitly tested through our models + quantization slow tests, which is how I caught the bug.
OK if there is no way to repro with a minimal trick such as putting the weights on the meta device voluntarily!
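A rough sketch of the kind of minimal test being discussed (the test name is made up, and it assumes `accelerate`'s `cpu_offload` helper to voluntarily push the weights onto the meta device):

```python
import torch
from accelerate import cpu_offload
from transformers import AutoModelForCausalLM, AutoTokenizer


def test_generate_with_cpu_offloaded_model():
    # Tiny model commonly used in the test suite; everything stays on CPU.
    model_id = "hf-internal-testing/tiny-random-gpt2"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    # cpu_offload swaps the parameters for meta tensors and restores the real
    # weights at forward time, so model.device now reports "meta".
    cpu_offload(model, execution_device=torch.device("cpu"))
    assert model.device.type == "meta"

    # With this fix, the special tokens follow the inputs' device rather than
    # the (meta) model device, so generation still runs.
    inputs = tokenizer("Hello", return_tensors="pt")
    output = model.generate(**inputs, max_new_tokens=3)
    assert output.shape[0] == 1
```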
….41.1 Fixes #31. The handling of special tokens in `transformers` was changed in huggingface/transformers#30624 and huggingface/transformers#30746. This updates the XTTS streaming code accordingly.
What does this PR do?
Fixes failing tests for multi-device generation (e.g. multi-GPU, GPU + CPU, etc.). The fix is simply to make sure `pad_token_id` and all other special tokens are initialized on the correct device (e.g. for models offloaded to CPU, `self.device` returns `"meta"`, which breaks generation afterwards 😢).

cc @gante @ArthurZucker
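As a standalone illustration of the failure mode described above (not code from the PR): a special-token tensor created on an offloaded model's reported device ends up on `meta` and cannot be compared against real input ids, while one created on the inputs' device works fine:

```python
import torch

input_ids = torch.tensor([[1, 2, 0, 0]])  # real data, lives on CPU
pad_token_id = 0

# Token tensor placed on the device an offloaded model reports:
pad_on_meta = torch.tensor(pad_token_id, device="meta")
# input_ids.ne(pad_on_meta)  # errors: meta tensors hold no data and mismatch the CPU inputs

# What the fix does instead: place it on the device of the inputs.
pad_on_input_device = torch.tensor(pad_token_id, device=input_ids.device)
attention_mask = input_ids.ne(pad_on_input_device).long()
print(attention_mask)  # tensor([[1, 1, 0, 0]])
```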