Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPT-Neo] Simplify local attention #13491

Merged
merged 3 commits into from
Sep 10, 2021

Conversation

patil-suraj
Copy link
Contributor

@patil-suraj patil-suraj commented Sep 9, 2021

What does this PR do?

Co-authored-by: finetuneanon [email protected]

This PR is a continuation of #11630 which simplifies GPT Neo's local attention implementation. All credit to @finetuneanon for finding this issue, providing the fix, and a detailed explanation. Thanks a lot for working on this :)

The issue is described in #11320 and performance evaluation results are available here:
#12106 (comment)

This PR does some cleanup and updates tests on top of @finetuneanon changes.

Fixes #11320, Fixes #11787, Fixes #12964, Fixes #11096

# in the causal_mask.
causal_mask = causal_mask * attention_mask
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9))
Copy link
Contributor

@patrickvonplaten patrickvonplaten Sep 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this doesn't break in fp16? It's -1e4 in modeling_gpt2.py I think


query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this correctly convert to fp16?

torch.tensor(-1e9).to(torch.float16)

gives -inf -> which could later lead to problems in fp16 no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@patrickvonplaten @sgugger

-1e9 is the value used in the original codebase and actually, the attention weights are always computed in fp32

query = query.to(torch.float32)
key = key.to(torch.float32)
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.Softmax(dim=-1)(attn_weights)
attn_weights = attn_weights.to(value.dtype)
attn_weights = attn_dropout(attn_weights)

so the masked_bias is always fp32. The attn_weights are only cast back to the original dtype after softmax.

However, as the models are trained on TPU and possibly with bf16 (see #11076 (comment)) I'm not sure we can guarantee that the models will always work with fp16. See #11076.

@@ -232,6 +230,86 @@ def create_and_check_gpt_neo_model_past(self, config, input_ids, input_mask, hea
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

def create_and_check_gpt_neo_model_attention_mask_past(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome new tests! Could we also add one/two fp16 tests? To make sure generation and forward pass works correctly :-)

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great clean-up! Think this should solve the vram memory issue

The only thing that would be good to check IMO is if this implementation is 100% fp16 compatible (would be great to add some tests as well for this)

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this! Agreed with Patrick on the FP16 tests to make sure it all works fine.

src/transformers/models/gpt_neo/modeling_gpt_neo.py Outdated Show resolved Hide resolved
Copy link
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good!

Comment on lines +148 to +149
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will save to the state dict - is that voluntary? If not, you can add the persistent=False flag (introduced in pytorch 1.6)

@patil-suraj patil-suraj merged commit 010965d into huggingface:master Sep 10, 2021
@patil-suraj patil-suraj deleted the fix-gpt-neo-local-attn branch September 10, 2021 17:22
@tianleiwu
Copy link
Contributor

I had error in running:

  File "src\transformers\models\gpt_neo\configuration_gpt_neo.py", line 218, in __init__
    from .modeling_gpt_neo import GPTNeoAttentionMixin
ImportError: cannot import name 'GPTNeoAttentionMixin'

Need a patch for this PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
6 participants