Prefix Caching with FP8 KV cache support #3234
Conversation
3116104 to 5d31fe7
We found that triton==2.1.0, which torch 2.1.2 depends on, does not support the fp8 kernel correctly. We can:
What do you think? @zhaoyang-star @WoosukKwon
I think the second option is simple and clear. @WoosukKwon What do you think about this?
@zhaoyang-star @chenxu2048 Can it wait a bit? We have some issues in upgrading the PyTorch version (#2804).
@WoosukKwon Sure. Do we have a schedule for the upgrade?
Use torch.fp8_e5m2 instead of torch.uint8 in python interface
LGTM
@WoosukKwon FYI, prefix caching is not working on Turing GPUs with …
@chenxu2048 Hello, we tested the 7B Qwen model with the fp8 modifications for prefix_prefill and found an accuracy loss compared to disabling prefix_prefill and using fp8 alone. Could this be because the fp8-to-bf16 conversion in Triton differs from the fp8_e5m2_unscaled::vec_conversion implementation?
Could you please run an evaluation with prefix_prefill enabled and fp8_e5m2 disabled? Just to verify whether the accuracy drop is caused by fp8_e5m2.
In the test with prefix_prefill enabled and fp8_e5m2 disabled, the results were normal. Only when both prefix_prefill and fp8_e5m2 are enabled is the accuracy affected.
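For reference, the two configurations being compared could be set up roughly as below. This is an illustrative sketch, not code from the PR, and assumes the vLLM engine arguments enable_prefix_caching and kv_cache_dtype available around this era of the codebase.

from vllm import LLM

# Prefix caching on, FP8 KV cache off: the baseline that behaves normally.
llm_baseline = LLM(model="Qwen/Qwen-7B-Chat",
                   trust_remote_code=True,
                   enable_prefix_caching=True)

# Prefix caching on, fp8_e5m2 KV cache on: the case showing an accuracy drop.
llm_fp8 = LLM(model="Qwen/Qwen-7B-Chat",
              trust_remote_code=True,
              enable_prefix_caching=True,
              kv_cache_dtype="fp8_e5m2")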
Hi, @snippetzero. Could you provide the model and some inputs for testing? Without prefix caching, the Key and Value computed in the pre-filling stage are in FP16, and FP8 is only used in the decoding stage. With prefix caching enabled, however, both the Key and Value in the KV cache are in FP8. I think additional precision loss might be introduced in pre-filling and in the prefix KV cache.
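To estimate how much of the drop can be explained by the extra fp8_e5m2 round trip on prefix keys/values, a minimal sketch like the following could measure the quantization error directly. It assumes torch >= 2.1, which exposes torch.float8_e5m2, and is not part of this PR.

import torch

def fp8_e5m2_roundtrip_error(x: torch.Tensor) -> float:
    # Cast down as if writing into the FP8 KV cache, then back up as the kernel does.
    x_back = x.to(torch.float8_e5m2).to(x.dtype)
    return ((x - x_back).abs() / x.abs().clamp_min(1e-6)).mean().item()

k = torch.randn(256, 128, dtype=torch.float16)
print(f"mean relative error of one fp8_e5m2 round trip: {fp8_e5m2_roundtrip_error(k):.3f}")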
FlashAttention 3 will support FP8 soon and FlashInfer already supports it, so I would like to bump this PR and hopefully put it on the roadmap again.
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
            mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
            other=0.0).to(q.dtype)
Are you sure you should just do this? What about the fp8 scaling factor?
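If a scaling factor were needed here, one option would be to dequantize right after the upcast. This is a hypothetical sketch, not this PR's approach (the PR uses unscaled fp8_e5m2 conversion), and k_scale is an assumed extra kernel argument.

k = tl.load(K_cache + off_k,
            mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
            other=0.0).to(q.dtype) * k_scale  # k_scale: per-tensor dequant factor for K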
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}

sm_scale = 1.0 / (Lq**0.5)
What happened to the FP8 scaling factor?
FlashInfer supports it in v0.2.0, which is being released, so we will be unblocked soon.
Fix #3156
The FP8 KV cache uses tensors with dtype torch.uint8 and converts them from fp8_e5m2 to float16 inside the paged attention kernel. The prefix-cache Triton kernel cannot handle key_cache and value_cache with the "wrong" dtype. Therefore, we reinterpret them as fp8_e5m2 before the kernel so they can be upcast to the correct dtype inside the kernel.
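As a rough illustration of that dtype handling (a sketch under the assumption that torch >= 2.2 exposes torch.float8_e5m2; the variable names are made up and this is not the PR's actual code):

import torch

# The KV cache is allocated as raw bytes; the stored bits are already fp8_e5m2.
key_cache_u8 = torch.empty(1024, 16, 128, dtype=torch.uint8)

# Reinterpret the same storage as fp8_e5m2 without copying (both dtypes are 1 byte),
# so the Triton kernel sees the expected dtype and can upcast it internally,
# e.g. via tl.load(...).to(tl.float16).
key_cache_fp8 = key_cache_u8.view(torch.float8_e5m2)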
Code was tested with Qwen/Qwen-7B-Chat.
This PR requires triton 2.2.0 and torch 2.2.x, and it depends on #2804.