
[Mistral] Add Flash Attention-2 support for mistral #26464

Merged: 28 commits merged into huggingface:main from add-mistral-fa-2 on Oct 3, 2023

Conversation

younesbelkada (Contributor, author):

What does this PR do?

Adds Flash Attention 2 support for Mistral (causal LM). We still need to discuss how to integrate it with local (sliding window) attention.

cc @ArthurZucker @LysandreJik

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load Mistral with the Flash Attention-2 code path enabled, in fp16, on GPU 0.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    use_flash_attention_2=True,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to(0)

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

text = "Hello my name is"
inputs = tokenizer(text, return_tensors="pt").to(0)

# Sample up to 4096 new tokens with the KV cache enabled.
out = model.generate(**inputs, max_new_tokens=4096, use_cache=True, do_sample=True)
print(tokenizer.batch_decode(out, skip_special_tokens=True))

HuggingFaceDocBuilderDev commented Sep 28, 2023:

The documentation is not available anymore as the PR was closed or merged.

timlacroix (Contributor):

Might be worth adding support for sliding window attention? The params are window_size_left = config.sliding_window - 1 and window_size_right = -1, I believe.
See the commit adding support in FlashAttention v2: Dao-AILab/flash-attention@083e8f5
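A minimal sketch of how those parameters could be forwarded to flash-attn (assuming a flash-attn build that already exposes the `window_size` argument on `flash_attn_func`; the `config` object and helper name here are illustrative, not the PR's code):

```python
# Sketch only: pass Mistral's sliding window to a flash-attn version with local attention.
# window_size is (left, right); -1 means "unlimited" on that side.
from flash_attn import flash_attn_func

def sliding_window_attention(query, key, value, config, dropout_p=0.0, softmax_scale=None):
    window = (-1, -1)
    if getattr(config, "sliding_window", None) is not None:
        # Causal attention only looks left, so only the left window is bounded.
        window = (config.sliding_window - 1, -1)
    return flash_attn_func(
        query, key, value,
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=True,
        window_size=window,
    )
```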


query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

use_sliding_windows = _is_flash_using_slicing_windows and kv_seq_len > self.config.sliding_window and not self.training
Contributor:

typo, should be "_is_flash_using_sliding_windows"

Contributor:

Also, I don't think this bool is needed. If self.config.sliding_window is not None, shouldn't we always use the sliding window, whether training or running inference?

younesbelkada (Contributor, author):

Hmm, good point. For some reason I thought that feature only works for inference (the source code's README, https://github.com/mistralai/mistral-src#sliding-window-to-speed-up-inference-and-reduce-memory-pressure, says "speed up inference", so I assumed it was inference-only). Will remove that condition.

Comment on lines 438 to 464
if not use_sliding_windows:
    attn_output_unpad = flash_attn_varlen_func(
        query_states,
        key_states,
        value_states,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_in_batch_q,
        max_seqlen_k=max_seqlen_in_batch_k,
        dropout_p=dropout,
        softmax_scale=softmax_scale,
        causal=True,
    )
else:
    attn_output_unpad = flash_attn_varlen_func(
        query_states,
        key_states,
        value_states,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_in_batch_q,
        max_seqlen_k=max_seqlen_in_batch_k,
        dropout_p=dropout,
        softmax_scale=softmax_scale,
        causal=True,
        window_size=(self.config.sliding_window, self.config.sliding_window),
    )
Contributor:

Could probably be factored? Something like window_size=(self.config.sliding_window or -1, -1)?
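A sketch of that factoring, under the assumption that the installed flash-attn accepts `window_size` (names mirror the surrounding diff; this is not the merged code):

```python
# Sketch: collapse the duplicated calls; (-1, -1) means "no window" in flash-attn.
window_size = (
    (self.config.sliding_window, self.config.sliding_window)
    if use_sliding_windows
    else (-1, -1)
)
attn_output_unpad = flash_attn_varlen_func(
    query_states,
    key_states,
    value_states,
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=max_seqlen_in_batch_q,
    max_seqlen_k=max_seqlen_in_batch_k,
    dropout_p=dropout,
    softmax_scale=softmax_scale,
    causal=True,
    window_size=window_size,
)
```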

    )
else:
    attn_output = flash_attn_func(
        query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True, window_size=(self.config.sliding_window // 2, self.config.sliding_window // 2)
Contributor:

Same comment on factoring. I have no idea what's going on with the // 2 here, but I also don't know what padding_mask is :|

younesbelkada (Contributor, author):

Oh yeah, ignore the // 2, I used it for testing purposes :D

younesbelkada (Contributor, author):

padding_mask is the "pure" attention mask (not the causal mask): 0 for padding tokens, 1 otherwise. I use it in the control flow of the flash attention modules to decide whether I need to pad / unpad or not.
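A tiny sketch of that control flow (hypothetical helper, only to illustrate the convention described above):

```python
from typing import Optional

import torch

def batch_needs_unpadding(padding_mask: Optional[torch.Tensor]) -> bool:
    # padding_mask: 1 for real tokens, 0 for padding tokens. If there is no mask, or
    # every position is a real token, the dense flash_attn_func path can be used;
    # otherwise the varlen path (unpad before the kernel, pad back after) is taken.
    return padding_mask is not None and not bool(padding_mask.all())
```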

younesbelkada (Contributor, author):

Sharing some results here vs. the official Mistral implementation, which uses xformers.memory_efficient_attention:

| Scenario | Implementation | Latency (s) | Tokens/s | Max allocated memory (bytes) |
| --- | --- | --- | --- | --- |
| Context length = 12, max_new_tokens=512, bs=1 | HF transformers + FA-2 | 15.12 | 33.85 | 15,218,032,640 |
| | Mistral + mem. efficient | 17.23 | 29.71 | 14,636,799,488 |
| Context length = 11K, max_new_tokens=512, bs=1 | HF transformers + FA-2 | 16.50 | 31.04 | 18,673,463,808 |
| | Mistral + mem. efficient | 22.51 | 22.75 | 17,303,250,944 |
| Context length = 11K, max_new_tokens=512, bs=2, 11K padding tokens on the second sequence | HF transformers + FA-2 | 33.96 | 15.08 | 22,320,273,408 |
| | Mistral + mem. efficient | 30.41 | 16.84 | 17,841,224,192 |
| Context length = 11K, max_new_tokens=512, bs=4, 11K padding tokens on the second, third, and fourth sequences | HF transformers + FA-2 | 48.86 | 10.48 | 29,610,738,688 |
| | Mistral + mem. efficient | 45.27 | 11.31 | 18,914,968,576 |

Obviously the pad / unpad overhead dominates for the HF implementation when there are many padding tokens, whereas the official repository handles padding tokens differently. Note also that the max allocated memory increases as padding tokens are added. Finally, the current cache slicing mechanism assumes a left-padding regime: generation should be performed with padding_side='left', whereas this has no impact for training since the cache is not used there.
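For illustration, a minimal left-padded batched generation setup (reusing the `model` from the snippet in the PR description; the prompts are arbitrary):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.padding_side = "left"            # required by the cache slicing described above
tokenizer.pad_token = tokenizer.eos_token  # Mistral ships no dedicated pad token

batch = tokenizer(
    ["Hello my name is", "The capital of France is"],
    return_tensors="pt",
    padding=True,
).to(0)

out = model.generate(**batch, max_new_tokens=32, use_cache=True)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```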

Here is a plot that compares pure forward on HF native vs HF + FA-2

[Two screenshots (2023-10-02): forward-pass benchmark plots]

younesbelkada (Contributor, author) commented Oct 2, 2023:

For the sake of completeness,

Script I used to benchmark transformers + FA2: https://gist.github.com/younesbelkada/691c1dec3da2f0a7de29c1d1096d860f

Script I used to benchmark the Mistral original source code: https://gist.github.com/younesbelkada/ada0d9c2c48ab034486dbaaf95d29fae (assuming you have cloned their repository and run the script from its root folder)
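Roughly, the numbers above come down to measurements like the following (a hypothetical helper, not the exact gist):

```python
import time

import torch

def benchmark_generate(model, inputs, max_new_tokens=512):
    # Returns (latency in seconds, throughput in tokens/s, peak GPU memory in bytes).
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=True)
    torch.cuda.synchronize()
    latency = time.perf_counter() - start
    new_tokens = out.shape[1] - inputs["input_ids"].shape[1]
    return latency, new_tokens / latency, torch.cuda.max_memory_allocated()
```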

ArthurZucker (Collaborator) left a comment:

Thanks a lot! Looking good.

docs/source/en/perf_infer_gpu_one.md (resolved)
src/transformers/models/mistral/modeling_mistral.py (outdated, resolved)
src/transformers/models/mistral/modeling_mistral.py (outdated, resolved)
Comment on lines 365 to 369
if not _is_flash_using_sliding_windows:
    logger.warning_once(
        "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
        " make sure to upgrade flash-attn library."
    )
Collaborator:

should go in the import instead of here

younesbelkada (Contributor, author):

I don't think so, because it would raise the warning whenever FA is installed, even if you don't use the Flash Attention version of the Mistral model.
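For reference, one way to perform the capability check itself is to inspect the kernel's signature at import time (a sketch, not necessarily what the PR ships):

```python
import inspect

from flash_attn import flash_attn_func

# True if the installed flash-attn exposes the window_size argument
# (i.e. supports sliding-window / local attention).
_flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
```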

Comment on lines 408 to 423
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
    logger.warning_once(
        "The input hidden states seems to be silently casted in float32, this might be related to"
        " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
        " float16."
    )

    query_states = query_states.to(torch.float16)
    key_states = key_states.to(torch.float16)
    value_states = value_states.to(torch.float16)
Collaborator:

Not sure if this should be included, since it's PEFT-specific and it always casts to float16 even if the input is bfloat16.

younesbelkada (Contributor, author):

MistralRMSNorm behaves exactly like LlamaRMSNorm, so it will silently cast the hidden states to fp32; therefore this is needed. As mentioned offline, I will address a proper fix for the bf16 issues.
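As a possible direction for that bf16 fix, the target dtype could be taken from the module's own weights (for example `self.q_proj.weight.dtype` inside the attention block) instead of being hardcoded to float16. A self-contained sketch under that assumption, not what this PR merges:

```python
import torch

def cast_for_flash(query_states, key_states, value_states, target_dtype):
    # Cast fp32 activations back to the compute dtype the attention projections were
    # loaded in (float16 or bfloat16), rather than assuming float16.
    if query_states.dtype == torch.float32:
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)
    return query_states, key_states, value_states
```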

Comment on lines 879 to 883
logger.warning_once(
    "You are attempting to perform batched generation with padding_side='right'"
    " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
Collaborator:

Given how bad the outputs were, it might be a good idea to force the padding side or raise an error.

Contributor:

@younesbelkada the implementation here doesn't seem specific to generation in any way. Is the error message wrong or is the implementation wrong (or am I missing something)? That is, should I be able to run forward with right padding and flash2?
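A sketch of the stricter behaviour suggested above, i.e. raising instead of warning when right padding is detected (the last-column check is an illustrative assumption, not the merged code):

```python
import torch

def check_left_padding(padding_mask: torch.Tensor) -> None:
    # padding_mask: 1 for real tokens, 0 for padding. With left padding the last
    # position of every row is a real token, so a 0 there indicates right padding.
    if padding_mask is not None and (padding_mask[:, -1] == 0).any():
        raise ValueError(
            "Detected right padding with the Flash Attention-2 version of Mistral; "
            "set `tokenizer.padding_side = 'left'` before tokenizing batched inputs."
        )
```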

@younesbelkada marked this pull request as ready for review on October 3, 2023 08:51
@younesbelkada changed the title from "[Mistral] Add mistral + FA 2" to "[Mistral] Add Flash Attention-2 support for mistral" on Oct 3, 2023
ArthurZucker (Collaborator) left a comment:

Thanks for iterating!

@younesbelkada merged commit ae9a344 into huggingface:main on Oct 3, 2023
18 checks passed
@younesbelkada deleted the add-mistral-fa-2 branch on October 3, 2023 11:44
@younesbelkada mentioned this pull request on Oct 4, 2023
if past_key_value is not None:
    # Activate slicing cache only if the config has a value `sliding_windows` attribute
    if hasattr(self.config, "sliding_window") and kv_seq_len > self.config.sliding_window:
        slicing_tokens = kv_seq_len - self.config.sliding_window

vince62s:

@younesbelkada sorry to bother again, but if kv_seq_len > N times sliding_window, for instance seq_len = 9000 and sliding_window = 4096, shouldn't slicing_tokens be 9000 - 2 x 4096 instead of 9000 - 4096?

younesbelkada (Contributor, author):

Hmm, possibly yes. I need to double-check against the original code; I'll get back to you on this!

younesbelkada (Contributor, author):

@vince62s I could not find any relevant piece of code in the Mistral source code to confirm your statement. Can you help me identify the place where you think we indeed need to slice only N*4096?

vince62s:

I think I was mistaken. We are adding keys/values one by one here, so kv_seq_len is never greater than self.config.sliding_window + 1; it is exactly equal to it. I was misled by line 374.
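To make the behaviour being discussed concrete, a rough sketch of KV-cache slicing under a sliding window (a (key, value) layout of shape [batch, heads, seq, dim] is assumed; this mirrors the idea, not the merged code):

```python
def slice_kv_cache(past_key, past_value, sliding_window):
    # Keep only the most recent sliding_window - 1 cached positions; together with the
    # token currently being processed, attention then never sees more than
    # sliding_window keys/values.
    if past_key.shape[2] > sliding_window - 1:
        past_key = past_key[:, :, -(sliding_window - 1):, :]
        past_value = past_value[:, :, -(sliding_window - 1):, :]
    return past_key, past_value
```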

blbadger pushed a commit to blbadger/transformers that referenced this pull request on Nov 8, 2023
…26464)

* add FA-2 support for mistral
* fixup
* add sliding windows
* fixing few nits
* v1 slicing cache - logits do not match
* add comment
* fix bugs
* more mem efficient
* add warning once
* add warning once
* oops
* fixup
* more comments
* copy
* add safety checker
* fixup
* Update src/transformers/models/mistral/modeling_mistral.py
  Co-authored-by: Arthur <[email protected]>
* copied from
* up
* raise when padding side is right
* fixup
* add doc + few minor changes
* fixup
---------
Co-authored-by: Arthur <[email protected]>
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request on Nov 18, 2023, with the same commit message as above.