use flash-attn via xformers #877

Merged
merged 2 commits into vllm-project:main on Aug 30, 2023
Conversation

tmm1
Contributor

@tmm1 tmm1 commented Aug 25, 2023

@WoosukKwon WoosukKwon self-requested a review August 26, 2023 01:14
@WoosukKwon
Collaborator

WoosukKwon commented Aug 26, 2023

Hi @tmm1, thanks for letting us know about the performance issue and for submitting the PR.

While using FA2 might improve performance, we have concerns about using it because it does not support attention biases like ALiBi, V100 GPUs, the FP32 data type, or head_size 256 (which is used by GPT-J). So, to use FA2, I believe we should add a fallback to the xformers cutlass backend.
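
For context, a minimal sketch of the kind of fallback being suggested here (not the PR's actual code): pick the xformers FlashAttention op only when the inputs are supported, and use the cutlass op otherwise. The `attention_forward` helper and the exact support checks are illustrative assumptions.

```python
import torch
import xformers.ops as xops
from xformers.ops.fmha import cutlass, flash

def attention_forward(q, k, v, attn_bias=None):
    """Illustrative fallback: FA2 via xformers when supported, cutlass otherwise."""
    head_size = q.shape[-1]
    use_flash = (
        q.dtype in (torch.float16, torch.bfloat16)        # FA2 has no FP32 kernels
        and torch.cuda.get_device_capability() >= (8, 0)  # excludes V100 (sm70)
        and attn_bias is None                             # no arbitrary bias such as ALiBi
        and head_size <= 128                              # e.g. GPT-J's head_size 256 unsupported
    )
    op = flash.FwOp if use_flash else cutlass.FwOp
    return xops.memory_efficient_attention_forward(q, k, v, attn_bias=attn_bias, op=op)
```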

@tmm1
Contributor Author

tmm1 commented Aug 26, 2023

> So, to use FA2, I believe we should add a fallback to the xformers cutlass backend.

Thanks @WoosukKwon, I updated the PR to let xformers fall back.

cc @danthe3rd

@zhaoyang-star
Contributor

> > So, to use FA2, I believe we should add a fallback to the xformers cutlass backend.
>
> Thanks @WoosukKwon, I updated the PR to let xformers fall back.
>
> cc @danthe3rd

Hi @tmm1, I am very interested in your PR, but I don't see where it allows xformers to fall back.

@danthe3rd

Hi, xformers maintainer here.
The way it's set up in this PR lets xformers decide which backend to use, which is also what we recommend. By default it will use Flash v2 if available, and fall back to cutlass if it isn't (e.g. when using a custom bias, FP32, or a V100).
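
Concretely, the pattern described above looks roughly like the following; the shapes, dtypes, and bias here are illustrative assumptions, not taken from the PR.

```python
import torch
import xformers.ops as xops

# [batch, seq_len, n_heads, head_size], half precision on an Ampere GPU
q = torch.randn(1, 1024, 16, 128, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# With op=None (the default), the xformers dispatcher picks FlashAttention v2
# when the inputs qualify, and falls back to the cutlass kernel otherwise
# (e.g. FP32 inputs, a V100, or a custom attention bias).
out = xops.memory_efficient_attention_forward(
    q, k, v, attn_bias=xops.LowerTriangularMask()
)
```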

@zhaoyang-star
Contributor

Hi @danthe3rd, thanks for your explanation. I just wonder why xformers.ops.memory_efficient_attention_forward is not used in the decoding stage? The API is already used in the prefill stage in vLLM. A hand-written kernel may have lower performance than xformers if we are not CUDA experts.

@danthe3rd

danthe3rd commented Aug 29, 2023

At the moment, memory_efficient_attention is mostly optimized for training, which is similar to the prefill stage in terms of problem sizes. We do have a backend for next-token decoding, but it's not yet fully optimized for all cases; we're working on it :)
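
To make the problem-size point concrete, a rough illustration of the shapes involved (all numbers are hypothetical):

```python
import torch

B, H, D, prompt_len, ctx_len = 8, 32, 128, 512, 2048

# Prefill looks like training: many query tokens per sequence, so the large
# tiled attention kernels are well utilized.
q_prefill = torch.empty(B, prompt_len, H, D)   # [8, 512, 32, 128]
kv_prefill = torch.empty(B, prompt_len, H, D)

# Decode: a single new query token per sequence attending over a long KV cache,
# a very different (and memory-bound) workload.
q_decode = torch.empty(B, 1, H, D)             # [8, 1, 32, 128]
kv_cache = torch.empty(B, ctx_len, H, D)       # [8, 2048, 32, 128]
```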

@zhaoyang-star
Contributor

zhaoyang-star commented Aug 29, 2023

> At the moment, memory_efficient_attention is mostly optimized for training, which is similar to the prefill stage in terms of problem sizes. We do have a backend for next-token decoding, but it's not yet fully optimized for all cases; we're working on it :)

Thanks a lot. So it makes sense that most LLM inference frameworks have hand-written CUDA kernels for their fused attention implementations.

Member

@zhuohan123 zhuohan123 left a comment

LGTM! Thank you for your contribution!

@zhuohan123 zhuohan123 merged commit 7547138 into vllm-project:main Aug 30, 2023
2 checks passed
liuyanyi pushed a commit to liuyanyi/vllm that referenced this pull request Sep 12, 2023
@KexinFeng

KexinFeng commented Sep 14, 2023

Hi @tmm1 @WoosukKwon
I have a follow-up question to this earlier question:

> I just wonder why xformers.ops.memory_efficient_attention_forward is not used in the decoding stage?

Is flash attention (or a similar algorithm where the softmax is computed in a streaming fashion with a fused kernel) also implemented inside vllm.attention_ops.single_query_cached_kv_attention? It looks to me that, in principle, the flash attention algorithm is very compatible with paged attention: the softmax could be computed in a streaming fashion with a fused kernel over paged tensor storage, too.
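
For reference, the streaming (online) softmax that makes this composition possible can be sketched in a few lines. This is plain PyTorch for illustration only, not vLLM's single_query_cached_kv_attention kernel; `kv_blocks` is a hypothetical list of cached K/V blocks.

```python
import torch

def single_query_attention_over_blocks(q, kv_blocks, scale):
    # q: [D]; kv_blocks: list of (K_block [block_size, D], V_block [block_size, D])
    m = torch.tensor(float("-inf"))             # running max of the logits
    l = torch.tensor(0.0)                       # running softmax denominator
    acc = torch.zeros_like(kv_blocks[0][1][0])  # running weighted sum of V
    for k_blk, v_blk in kv_blocks:
        s = (k_blk @ q) * scale                 # logits for this block
        m_new = torch.maximum(m, s.max())
        correction = torch.exp(m - m_new)       # rescale previous partial sums
        p = torch.exp(s - m_new)
        l = l * correction + p.sum()
        acc = acc * correction + p @ v_blk
        m = m_new
    return acc / l
```

Each block's contribution is folded into the running max, denominator, and weighted V sum, so the result is exact no matter how the K/V cache is split into pages.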

@Lvjinhong

Regarding "allow xformers to pick the best available implementation": I don't quite understand this change, so how should I use flash attention?

hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
sjchoi1 pushed a commit to casys-kaist-internal/vllm that referenced this pull request May 7, 2024
Successfully merging this pull request may close these issues.

Flash Attention V2
7 participants