
Prefix Caching with FP8 KV cache support #3234

Closed
wants to merge 11 commits

Conversation

chenxu2048
Contributor

@chenxu2048 chenxu2048 commented Mar 6, 2024

Fix #3156

The FP8 KV cache stores tensors with dtype torch.uint8 and converts them from fp8_e5m2 to float16 inside the paged attention kernel. The prefix-caching Triton kernel cannot handle key_cache and value_cache with this "wrong" dtype, so we reinterpret them as fp8_e5m2 before launching the kernel, where they can then be upcast to the correct dtype.

Code was tested with Qwen/Qwen-7B-Chat.

This PR requires triton 2.2.0 and torch 2.2.x, and it depends on #2804.
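
For illustration, here is a minimal, self-contained sketch of the idea (not the PR's actual kernel; `upcast_fp8_kernel` and `upcast_kv_cache` are names made up for this example). It assumes triton>=2.2, torch>=2.2, and a GPU that supports fp8_e5m2 conversion: the cache is allocated as torch.uint8, viewed as torch.float8_e5m2 on the host, and upcast to float16 inside a Triton kernel with `.to()`.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def upcast_fp8_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    # x_ptr points at fp8_e5m2 elements, so Triton knows how to upcast them.
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    tl.store(out_ptr + offs, x.to(tl.float16), mask=mask)


def upcast_kv_cache(kv_cache_u8: torch.Tensor) -> torch.Tensor:
    # The cache is allocated as torch.uint8; reinterpret the raw bytes as
    # fp8_e5m2 before launching the kernel instead of passing uint8 through.
    kv_fp8 = kv_cache_u8.view(torch.float8_e5m2)
    out = torch.empty(kv_fp8.shape, dtype=torch.float16, device=kv_fp8.device)
    n = kv_fp8.numel()
    grid = (triton.cdiv(n, 1024),)
    upcast_fp8_kernel[grid](kv_fp8, out, n, BLOCK=1024)
    return out
```

In the PR itself, the same reinterpretation is applied to key_cache and value_cache right before the prefix-prefill kernel, and the upcast happens inside that kernel via `.to(q.dtype)`.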

@zhaoyang-star zhaoyang-star mentioned this pull request Mar 11, 2024
@chenxu2048
Contributor Author

We found that triton==2.1.0, which torch 2.1.2 depends on, does not support the fp8 kernel correctly. We can:

  1. keep triton 2.1.0 by default and guide users to install triton 2.2.0 in the docs (a version guard like the sketch below could enforce this).
  2. upgrade PyTorch to 2.2.1, which depends on triton==2.2.0.
  3. use a triton nightly build, whose version is 2.1.0.xxxx and does not break the torch dependency in pip.

What do you think? @zhaoyang-star @WoosukKwon
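
A hypothetical runtime guard for option 1 (not code from this PR; the function name and message are made up here) could fail fast when an incompatible triton is installed:

```python
from packaging import version

import triton

MIN_TRITON = "2.2.0"  # first version where the fp8_e5m2 path works correctly


def check_triton_for_fp8_prefix_caching() -> None:
    # Fail fast instead of miscompiling the prefix-prefill kernel at runtime.
    if version.parse(triton.__version__) < version.parse(MIN_TRITON):
        raise RuntimeError(
            f"FP8 KV cache with prefix caching requires triton>={MIN_TRITON}, "
            f"but triton=={triton.__version__} is installed. Please upgrade "
            "triton, or upgrade torch to a release that pulls it in."
        )
```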

Contributor

@zhaoyang-star zhaoyang-star left a comment


I think the second option is simple and clear. @WoosukKwon What do you think about this ?

vllm/model_executor/layers/attention/ops/paged_attn.py (outdated review thread, resolved)
@WoosukKwon
Collaborator

@zhaoyang-star @chenxu2048 Can it wait a bit? We have some issues in upgrading the PyTorch version (#2804).

@chenxu2048
Contributor Author

> @zhaoyang-star @chenxu2048 Can it wait a bit? We have some issues in upgrading the PyTorch version (#2804).

@WoosukKwon Sure. Do we have a schedule for the upgrade?

Contributor

@zhaoyang-star zhaoyang-star left a comment


LGTM

@esmeetu
Collaborator

esmeetu commented Mar 14, 2024

@WoosukKwon FYI, prefix caching is currently not working on Turing GPUs with triton==2.1.0 either; it also needs the upgrade to 2.2.0.

@chenxu2048 chenxu2048 changed the title Prefix Caching with FP8 KV cache support Draft: Prefix Caching with FP8 KV cache support Mar 15, 2024
@chenxu2048 chenxu2048 changed the title Draft: Prefix Caching with FP8 KV cache support Prefix Caching with FP8 KV cache support Mar 15, 2024
@chenxu2048 chenxu2048 marked this pull request as draft March 15, 2024 02:13
@snippetzero

@chenxu2048 Hello, we tested the fp8 modifications for prefix_prefill on the 7B Qwen model and found a loss in accuracy compared to disabling prefix_prefill and using fp8 alone. Could this be because the fp8-to-bf16 conversion in Triton differs from the fp8_e5m2_unscaled::vec_conversion implementation?

@zhaoyang-star
Contributor

> @chenxu2048 Hello, we tested the fp8 modifications for prefix_prefill on the 7B Qwen model and found a loss in accuracy compared to disabling prefix_prefill and using fp8 alone. Could this be because the fp8-to-bf16 conversion in Triton differs from the fp8_e5m2_unscaled::vec_conversion implementation?

Could you please run an evaluation with prefix_prefill enabled but fp8_e5m2 disabled? Just to verify whether the accuracy drop is caused by fp8_e5m2.

@snippetzero

> Could you please run an evaluation with prefix_prefill enabled but fp8_e5m2 disabled? Just to verify whether the accuracy drop is caused by fp8_e5m2.

In the test with prefix_prefill enabled and fp8_e5m2 disabled, the results were normal. The accuracy is only affected when both prefix_prefill and fp8_e5m2 are enabled.

@chenxu2048
Contributor Author

> @chenxu2048 Hello, we tested the fp8 modifications for prefix_prefill on the 7B Qwen model and found a loss in accuracy compared to disabling prefix_prefill and using fp8 alone. Could this be because the fp8-to-bf16 conversion in Triton differs from the fp8_e5m2_unscaled::vec_conversion implementation?

Hi, @snippetzero. Could you provide the model and some inputs for testing?

Without prefix caching, the Key and Value computed in the pre-filling stage are in FP16, and FP8 is only used in the decoding stage. However, when prefix caching is enabled, the Key and Value read from the KV cache during pre-filling are also in FP8.

So additional precision loss might be introduced in pre-filling from the FP8 prefix KV cache.
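
To make that hypothesis concrete, a toy comparison like the following (illustrative only, not taken from the PR; float32 is used instead of fp16/bf16 just to keep the sketch backend-agnostic) shows the extra error when prefill attention reads K/V that were round-tripped through fp8_e5m2, as happens with prefix caching:

```python
import torch

torch.manual_seed(0)
# (batch, heads, seq_len, head_dim) -- arbitrary toy shapes
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# Simulate storing K/V in an fp8_e5m2 KV cache and reading them back for prefill.
k_cached = k.to(torch.float8_e5m2).to(torch.float32)
v_cached = v.to(torch.float8_e5m2).to(torch.float32)

ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
out = torch.nn.functional.scaled_dot_product_attention(q, k_cached, v_cached)
print("max |ref - out| =", (ref - out).abs().max().item())
```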

@laurens-gs

FlashAttention 3 will support FP8 soon and FlashInfer already supports it, so I would like to bump this PR and hopefully get it back on the roadmap.

(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
            mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
            other=0.0).to(q.dtype)
Contributor


Are you sure you should just do this? What about the fp8 scaling factor?
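
For context, the existing e5m2 path appears to be unscaled (cf. fp8_e5m2_unscaled::vec_conversion mentioned above), which may be why no scale shows up here. If a scaled fp8 cache were used instead, the upcast would need to fold the scale in, roughly as in this hypothetical host-side illustration (not this kernel's code; `k_scale` is an assumed per-tensor scale):

```python
import torch

k_fp8 = torch.randn(16, 64).to(torch.float8_e5m2)  # quantized cache contents
k_scale = torch.tensor(0.5)                        # assumed per-tensor scale

# Dequantize: upcast first, multiply by the scale, then cast to the query dtype.
k = (k_fp8.to(torch.float32) * k_scale).to(torch.float16)
```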

assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}

sm_scale = 1.0 / (Lq**0.5)
Contributor


What happened to the FP8 scaling factor?

@comaniac
Collaborator

comaniac commented Aug 6, 2024

FlashInfer supports it in v0.2.0, which is being released, so we will be unblocked soon.

Successfully merging this pull request may close these issues.

FP8 KV cache doesn't work with prefix caching