[GPT-Neo] Simplify local attention #13491
Conversation
# in the causal_mask.
causal_mask = causal_mask * attention_mask
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9))
This doesn't break in fp16? It's -1e4 in modeling_gpt2.py, I think.
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
Does this correctly convert to fp16? torch.tensor(-1e9).to(torch.float16) gives -inf, which could later lead to problems in fp16, no?
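For context, a quick check in plain PyTorch (purely illustrative, not part of this PR) shows why the choice of constant matters in half precision:

import torch

# fp16 tops out around ±65504, so -1e9 overflows to -inf when cast
print(torch.tensor(-1e9).to(torch.float16))  # tensor(-inf, dtype=torch.float16)
# -1e4, the constant used in modeling_gpt2.py, stays finite
print(torch.tensor(-1e4).to(torch.float16))  # tensor(-10000., dtype=torch.float16)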
-1e9 is the value used in the original codebase, and the attention weights are actually always computed in fp32:
transformers/src/transformers/models/gpt_neo/modeling_gpt_neo.py
Lines 274 to 286 in 09549aa
query = query.to(torch.float32)
key = key.to(torch.float32)
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
if attention_mask is not None:
    # Apply the attention mask
    attn_weights = attn_weights + attention_mask
attn_weights = nn.Softmax(dim=-1)(attn_weights)
attn_weights = attn_weights.to(value.dtype)
attn_weights = attn_dropout(attn_weights)
So the masked_bias is always fp32. The attn_weights are only cast back to the original dtype after softmax.

However, as the models are trained on TPU, possibly with bf16 (see #11076 (comment)), I'm not sure we can guarantee that the models will always work with fp16. See #11076.
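A minimal sketch of the reasoning above (toy shapes, illustrative only): because masking and softmax happen in fp32, the -1e9 bias never has to survive a cast to fp16; only the post-softmax probabilities are cast back.

import torch
from torch import nn

query = torch.randn(1, 1, 4, 8, dtype=torch.float16)
key = torch.randn(1, 1, 4, 8, dtype=torch.float16)
value = torch.randn(1, 1, 4, 8, dtype=torch.float16)
masked_bias = torch.tensor(-1e9)  # kept in fp32
causal_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))

# upcast before computing the scores, as in the snippet above
attn_weights = torch.matmul(query.to(torch.float32), key.to(torch.float32).transpose(-1, -2))
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
attn_weights = nn.Softmax(dim=-1)(attn_weights)  # fp32 softmax; masked positions -> ~0
attn_weights = attn_weights.to(value.dtype)      # only now cast back to fp16
print(torch.isnan(attn_weights).any())           # tensor(False)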
@@ -232,6 +230,86 @@ def create_and_check_gpt_neo_model_past(self, config, input_ids, input_mask, hea
        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_gpt_neo_model_attention_mask_past(
Awesome new tests! Could we also add one or two fp16 tests, to make sure generation and the forward pass work correctly? :-)
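A sketch of what such a test could look like inside the existing model tester (the method name and exact assertions here are hypothetical, not necessarily what was merged):

def create_and_check_gpt_neo_for_causal_lm_fp16(self, config, input_ids, input_mask, *args):
    # hypothetical fp16 smoke test: run the forward pass and generation in half precision
    model = GPTNeoForCausalLM(config)
    model.to(torch_device)
    model.half()
    model.eval()

    # the forward pass should not overflow to NaN
    logits = model(input_ids, attention_mask=input_mask).logits
    self.parent.assertFalse(torch.isnan(logits).any().item())

    # greedy generation should also run without errors
    model.generate(input_ids, attention_mask=input_mask, do_sample=False, max_length=20)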
Great clean-up! I think this should solve the VRAM issue.

The only thing that would be good to check, IMO, is whether this implementation is 100% fp16 compatible (it would be great to add some tests for this as well).
Thanks for working on this! Agreed with Patrick on the FP16 tests to make sure it all works fine.
This looks good!
self.register_buffer("bias", bias) | ||
self.register_buffer("masked_bias", torch.tensor(-1e9)) |
This will be saved to the state dict. Is that intentional? If not, you can add the persistent=False flag (introduced in PyTorch 1.6).
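For reference, a minimal sketch of that suggestion, assuming PyTorch >= 1.6, which keeps both buffers out of the state dict:

self.register_buffer("bias", bias, persistent=False)
self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)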
I had an error when running this. Need a patch for this PR.
What does this PR do?
Co-authored-by: finetuneanon [email protected]
This PR is a continuation of #11630, which simplifies GPT Neo's local attention implementation. All credit to @finetuneanon for finding this issue, providing the fix, and giving a detailed explanation. Thanks a lot for working on this :)
The issue is described in #11320 and performance evaluation results are available here:
#12106 (comment)
This PR does some cleanup and updates the tests on top of @finetuneanon's changes.
Fixes #11320, Fixes #11787, Fixes #12964, Fixes #11096