Fix -1e4 as attn mask #17306

Merged: 19 commits merged into huggingface:main on Jun 20, 2022

Conversation

@ydshieh (Collaborator) commented May 17, 2022

What does this PR do?

Fix the issues caused by using -1e4 as the attention mask value.

Fixes #17215, #17121, #14859
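
(For context, a minimal runnable sketch, not the PR diff itself: it shows why a hard-coded constant such as -1e4 or -1e9 is fragile as an additive attention-mask value under float16, and the dtype-aware pattern this PR switches to. The mask_scores helper and the tensor shapes are hypothetical.)

```python
import torch

# In float16 the most negative finite value is about -65504, so -1e9 overflows
# to -inf (which can lead to NaNs, e.g. when an entire row is masked), while
# -1e4 may not be "negative enough" relative to large attention scores.
print(torch.tensor(-1e9, dtype=torch.float16))   # tensor(-inf, dtype=torch.float16)
print(torch.finfo(torch.float16).min)            # -65504.0
print(torch.finfo(torch.float32).min)            # ~ -3.4028e+38

def mask_scores(scores: torch.Tensor, keep: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: derive the fill value from the dtype of the scores
    # instead of hard-coding -1e4 / -1e9.
    mask_value = torch.finfo(scores.dtype).min
    return scores.masked_fill(~keep, mask_value)

scores = torch.randn(2, 4, 4)                                  # (batch, query, key)
keep = torch.tensor([[True, True, True, False],
                     [True, True, False, False]])[:, None, :]  # padding mask per batch
probs = mask_scores(scores, keep).softmax(dim=-1)              # masked keys get ~0 probability
```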

@HuggingFaceDocBuilderDev commented May 17, 2022

The documentation is not available anymore as the PR was closed or merged.

@ydshieh ydshieh changed the title [WIP] Fix -1e4 as attn mask Fix -1e4 as attn mask May 17, 2022
@ydshieh ydshieh marked this pull request as ready for review May 17, 2022 20:31
@sgugger (Collaborator) left a comment

The PyTorch implementation relies on self.device, which breaks model parallelism for big model inference, so we should avoid using it (I actually removed lots of instances where we used it recently, and will hunt down the other ones in another PR ;-) ).

(Three inline review comments on src/transformers/modeling_utils.py, now outdated and resolved.)
@patrickvonplaten (Contributor) commented May 17, 2022

Generally, this looks good to me. I'd prefer, though, not to factor out a one-liner into a function (even if we have to add the one-liner 100+ times). It's not good for readability to have to jump to modeling_utils.py, and the code saved is not worth it for a one-liner.

Also, I'd advocate making three separate PRs (one for PT, one for TF, one for Flax). I think that would make the PRs easier both to maintain and to review.

A first test should then be that all slow tests pass. After that, it would be nice if we could run some fine-tuning for the most important models (BERT on GLUE, GPT-2 on causal LM, maybe T5 on translation). It may not even be necessary to verify everything with a training run if the slow tests all pass.

@ydshieh (Collaborator, Author) commented May 18, 2022

Hi,

@patrickvonplaten:

  • I removed the new function.
  • I have to modify FlaxT5Attention; otherwise, the PT/Flax T5 equivalence tests will fail.

@sgugger:

  • Since there is no longer a new mask_value() function, there is no more device issue. There is one place where I need to create a tensor on a specific device, though:

```python
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
```

Would this be a problem for model parallelism in big model inference? It uses attn_weights.device instead of self.device, though.
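
(For context, a minimal self-contained sketch of the GPT-2-style place where this snippet matters; apply_causal_mask and the shapes are hypothetical, not the exact code from the diff. torch.where expects the fill value as a tensor matching the dtype and device of attn_weights, which is why the Python float returned by torch.finfo(...).min is wrapped above.)

```python
import torch

def apply_causal_mask(attn_weights: torch.Tensor, causal_mask: torch.Tensor) -> torch.Tensor:
    # attn_weights: (batch, heads, query, key) attention scores
    # causal_mask:  bool tensor broadcastable to attn_weights; True where attending is allowed
    mask_value = torch.finfo(attn_weights.dtype).min
    # torch.finfo(...).min is a Python float (a double); wrapping it as a tensor with the same
    # dtype and device as attn_weights avoids the dtype/device RuntimeErrors quoted above.
    mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
    return torch.where(causal_mask, attn_weights, mask_value)

# Hypothetical usage
attn_weights = torch.randn(1, 2, 4, 4)
causal_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))
masked = apply_causal_mask(attn_weights, causal_mask)
```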

@patil-suraj (Contributor) left a comment

It looks good in general. I have pretty much the same comments as Patrick. I would advocate doing some fine-tuning even if the slow tests pass, to make sure it doesn't break anything, especially with models like T5, which have had issues with attention_mask.

@sgugger (Collaborator) commented May 18, 2022

@ydshieh Using the weight device is perfectly fine, thanks for checking!

@LysandreJik (Member):

Cool, exciting!

@ydshieh ydshieh force-pushed the no_-1e4_for_attn_mask branch 5 times, most recently from 40c0ce7 to 70eb792 on May 25, 2022 08:35
@ydshieh (Collaborator, Author) commented May 25, 2022

Hi, @patrickvonplaten @patil-suraj @sgugger @LysandreJik

This PR is ready for review.

  • This PR only deals with PyTorch models, but FlaxT5 also needs a change to make the PT/Flax equivalence test pass.
  • In general, the masking constant is changed from -10000, -1e9, etc. to torch.finfo(<correct dtype>).min.
  • In particular, there are changes in modeling_utils.py (see the sketch after this list).
  • The change was verified by training a T5 model from scratch as well as by fine-tuning the t5-small checkpoint.
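
(A rough sketch of the kind of change made in modeling_utils.py, assuming the usual extended-attention-mask pattern; the real helper also handles 3D masks and causal masking, and the exact code may differ.)

```python
import torch

def get_extended_attention_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # attention_mask: (batch, seq_len) with 1 for tokens to attend to and 0 for padding.
    extended = attention_mask[:, None, None, :].to(dtype=dtype)   # (batch, 1, 1, seq_len)
    # Previously the additive bias was (1.0 - extended) * -10000.0; the fix uses the most
    # negative finite value of the compute dtype instead.
    return (1.0 - extended) * torch.finfo(dtype).min

mask = torch.tensor([[1, 1, 1, 0]])
bias = get_extended_attention_mask(mask, torch.float16)   # padding position gets about -65504
```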

@sgugger (Collaborator) left a comment

Thanks for fixing all of those!

@patrickvonplaten (Contributor) left a comment

Great work @ydshieh! Looks good to me.

(Inline review comment on src/transformers/models/hubert/modeling_hubert.py, now outdated and resolved.)
@ydshieh ydshieh merged commit d3cb288 into huggingface:main Jun 20, 2022
younesbelkada pushed a commit to younesbelkada/transformers that referenced this pull request Jun 25, 2022
* Use torch.finfo(self.dtype).min

* for GPTNeoX

* for Albert

* For Splinter

* Update src/transformers/models/data2vec/modeling_data2vec_audio.py

Co-authored-by: Patrick von Platen <[email protected]>

* fix -inf used in Bart-like models

* Fix a few remaining -inf

* more fix

* clean up

* For CLIP

* For FSMT

* clean up

* fix test

* Add dtype argument and use it for LayoutLMv3

* update FlaxLongT5Attention

Co-authored-by: ydshieh <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Successfully merging this pull request may close the linked issue: -1e9 constants in T5 implementation.