-
Notifications
You must be signed in to change notification settings - Fork 26.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Mistral
] Add Flash Attention-2 support for mistral
#26464
[Mistral
] Add Flash Attention-2 support for mistral
#26464
Conversation
The documentation is not available anymore as the PR was closed or merged. |
Might be worth adding support for sliding window attention ? |
|
||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) | ||
|
||
use_sliding_windows = _is_flash_using_slicing_windows and kv_seq_len > self.config.sliding_window and not self.training |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo, should be "_is_flash_using_sliding_windows"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also I don't think this bool is needed.
If self.config.sliding_window is not None -> use sliding_window always, whether training or inferencing no ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm good point, for some reason I thought that feature works only for inference (from the source code's readme: https://github.com/mistralai/mistral-src#sliding-window-to-speed-up-inference-and-reduce-memory-pressure I have read "speed up inference" so I thought that was only available for inference) - will remove that condition
if not use_sliding_windows: | ||
attn_output_unpad = flash_attn_varlen_func( | ||
query_states, | ||
key_states, | ||
value_states, | ||
cu_seqlens_q=cu_seqlens_q, | ||
cu_seqlens_k=cu_seqlens_k, | ||
max_seqlen_q=max_seqlen_in_batch_q, | ||
max_seqlen_k=max_seqlen_in_batch_k, | ||
dropout_p=dropout, | ||
softmax_scale=softmax_scale, | ||
causal=True, | ||
) | ||
else: | ||
attn_output_unpad = flash_attn_varlen_func( | ||
query_states, | ||
key_states, | ||
value_states, | ||
cu_seqlens_q=cu_seqlens_q, | ||
cu_seqlens_k=cu_seqlens_k, | ||
max_seqlen_q=max_seqlen_in_batch_q, | ||
max_seqlen_k=max_seqlen_in_batch_k, | ||
dropout_p=dropout, | ||
softmax_scale=softmax_scale, | ||
causal=True, | ||
window_size=(self.config.sliding_window, self.config.sliding_window) | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could probably be factored ? something like window_size=(self.config.sliding_window or -1, -1) ?
) | ||
else: | ||
attn_output = flash_attn_func( | ||
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True, window_size=(self.config.sliding_window // 2, self.config.sliding_window // 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same comment on factoring.
I have no idea what's going on with // 2 here but I don't know what padding_mask is :|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh yeah ignore the //2 I used it for testing purpose :D
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
padding mask is the "pure" attention mask (not causal mask) 0 if padding token 1 if not --> I use it in the control flow for flash attention modules whether I need to pad / unpad or no
…ransformers into add-mistral-fa-2
For the sake of completeness, Script I used to benchmark transformers + FA2: https://gist.github.com/younesbelkada/691c1dec3da2f0a7de29c1d1096d860f Script I used to benchmark mistral original source code: https://gist.github.com/younesbelkada/ada0d9c2c48ab034486dbaaf95d29fae (assuming you have cloned their repository and run it under the root folder of the repo) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot! Looking good.
if not _is_flash_using_sliding_windows: | ||
logger.warning_once( | ||
"The current flash attention version does not support sliding window attention, for a more memory efficient implementation" | ||
" make sure to upgrade flash-attn library." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should go in the import instead of here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think so because it will raise the warning if you have FA installed even if you not use the flash attention converted mistral model
# In PEFT, usually we cast the layer norms in float32 for training stability reasons | ||
# therefore the input hidden states gets silently casted in float32. Hence, we need | ||
# cast them back in float16 just to be sure everything works as expected. | ||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms | ||
# in fp32. (LlamaRMSNorm handles it correctly) | ||
input_dtype = query_states.dtype | ||
if input_dtype == torch.float32: | ||
logger.warning_once( | ||
"The input hidden states seems to be silently casted in float32, this might be related to" | ||
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" | ||
" float16." | ||
) | ||
|
||
query_states = query_states.to(torch.float16) | ||
key_states = key_states.to(torch.float16) | ||
value_states = value_states.to(torch.float16) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure if this should be included since its peft only always casts to float16 even if input is bfloat16
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MistralRMSNorm
behaves exactly as LLamaRMSNorm so it will silently cast the hidden states in fp32, therefore this is needed. As mentioned offline I will address a proper fix for bf16 issues
logger.warning_once( | ||
"You are attempting to perform batched generation with padding_side='right'" | ||
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " | ||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. " | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
given how bad the outputs were might be a good idea to for padding side / raise and error
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@younesbelkada the implementation here doesn't seem specific to generation in any way. Is the error message wrong or is the implementation wrong (or am I missing something)? That is, should I be able to run forward with right padding and flash2?
Co-authored-by: Arthur <[email protected]>
…ransformers into add-mistral-fa-2
…ransformers into add-mistral-fa-2
Mistral
] Add mistral + FA 2Mistral
] Add Flash Attention-2 support for mistral
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks of iterating!
if past_key_value is not None: | ||
# Activate slicing cache only if the config has a value `sliding_windows` attribute | ||
if hasattr(self.config, "sliding_window") and kv_seq_len > self.config.sliding_window: | ||
slicing_tokens = kv_seq_len - self.config.sliding_window |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@younesbelkada sorry to bother again but, if kv_seq_len > N times sliding_windows for instance seq_len = 9000 and sliding_window = 4096 shouldn't slicing_tokens be 9000 - 2 x 4096 instead of 9000 - 4096 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm possibly yes, I need to double check with the original code, I'll get back to you on this!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@vince62s I could not find any relevant piece of code in the source code of mistral to confirm your statement, can you help me identifying the place where you think we indeed need to slice only N*4096?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I was mistaken. We are adding key/value one by one here so kv_seq_len is never > self.config.sliding_window + 1 it is exactly equal to. I was misled my line 374.
…26464) * add FA-2 support for mistral * fixup * add sliding windows * fixing few nits * v1 slicing cache - logits do not match * add comment * fix bugs * more mem efficient * add warning once * add warning once * oops * fixup * more comments * copy * add safety checker * fixup * Update src/transformers/models/mistral/modeling_mistral.py Co-authored-by: Arthur <[email protected]> * copied from * up * raise when padding side is right * fixup * add doc + few minor changes * fixup --------- Co-authored-by: Arthur <[email protected]>
…26464) * add FA-2 support for mistral * fixup * add sliding windows * fixing few nits * v1 slicing cache - logits do not match * add comment * fix bugs * more mem efficient * add warning once * add warning once * oops * fixup * more comments * copy * add safety checker * fixup * Update src/transformers/models/mistral/modeling_mistral.py Co-authored-by: Arthur <[email protected]> * copied from * up * raise when padding side is right * fixup * add doc + few minor changes * fixup --------- Co-authored-by: Arthur <[email protected]>
What does this PR do?
Adds Flash Attention 2 for Mistral For Causal - we still need to discuss how to integrate it with local attention
cc @ArthurZucker @LysandreJik