[RFC] Automatic Prefix Caching #2614

Closed · zhuohan123 opened this issue Jan 26, 2024 · 16 comments · Fixed by #2762
Labels: enhancement (New feature or request), RFC

Comments

zhuohan123 (Member) commented Jan 26, 2024

This RFC discusses our plan for implementing automatic prefix caching in vLLM.

High-level idea

We observe that every block in the KV cache can be uniquely identified by

hash(prefix tokens, tokens in this block)

With this, we can add another level of indirection in vLLM's KV cache management:

Logical block table --> hash table --> physical block table.

Then, all of the sharing in vLLM, including prefix sharing, can be achieved by having logical blocks point to the physical block with the same hash value. Automatic prefix caching follows from not immediately freeing blocks whose reference count drops to zero. In effect, this design lets us manage the KV blocks like an ordinary cache in an operating system.

We can maintain the following information in every block (see the sketch after this list):

  • Block's hash
  • Reference count
  • Last accessed time
  • Total access count
  • The prefix length of this block
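
As a rough, hypothetical sketch of this scheme (the names below — block_hash, PhysicalBlock, lookup_or_allocate — are illustrative, not vLLM's actual implementation):

```python
import hashlib
import time
from dataclasses import dataclass


def block_hash(prefix_tokens: tuple[int, ...], block_tokens: tuple[int, ...]) -> int:
    """hash(prefix tokens, tokens in this block): uniquely identifies a KV block."""
    data = repr((prefix_tokens, block_tokens)).encode()
    return int.from_bytes(hashlib.sha256(data).digest()[:8], "big")


@dataclass
class PhysicalBlock:
    block_number: int            # index into the GPU KV-cache tensor
    hash_value: int              # block's hash
    ref_count: int = 0           # reference count
    last_accessed: float = 0.0   # last accessed time
    num_accesses: int = 0        # total access count
    prefix_length: int = 0       # prefix length: tokens up to and including this block


# The extra level of indirection: block hash -> physical block.
cached_blocks: dict[int, PhysicalBlock] = {}


def lookup_or_allocate(prefix_tokens, block_tokens, allocate) -> PhysicalBlock:
    """Return a shared cached block on a hit, or a freshly allocated one on a miss."""
    h = block_hash(tuple(prefix_tokens), tuple(block_tokens))
    block = cached_blocks.get(h)
    if block is None:                    # cache miss
        block = allocate(h)              # hypothetical allocator returning a PhysicalBlock
        block.prefix_length = len(prefix_tokens) + len(block_tokens)
        cached_blocks[h] = block
    block.ref_count += 1                 # a logical block now points at this physical block
    block.last_accessed = time.monotonic()
    block.num_accesses += 1
    return block
```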

Then, for example, the following cache eviction policy gives the same behavior as RadixAttention (a minimal sketch of this policy follows the list):

  1. Check the reference count first. Only evict the blocks with ref count == 0.
  2. Then check the last accessed time. Prefer to free older blocks following LRU.
  3. If the last accessed times are the same, check the prefix length. Free the block with the longer prefix first.
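
A minimal sketch of this ordering, assuming block objects with the metadata fields listed earlier (the function name is illustrative):

```python
def select_eviction_candidate(blocks):
    """Pick the next block to evict, following the three rules above.

    `blocks` is any iterable of block objects with ref_count, last_accessed,
    and prefix_length attributes (e.g. the PhysicalBlock sketch above).
    """
    # 1. Only blocks that no request currently references are evictable.
    evictable = [b for b in blocks if b.ref_count == 0]
    if not evictable:
        return None
    # 2. LRU: prefer the block with the oldest last-accessed time.
    # 3. Tie-break: among equally old blocks, evict the one with the
    #    longest prefix first (i.e. the deepest leaf of the prefix tree).
    return min(evictable, key=lambda b: (b.last_accessed, -b.prefix_length))
```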

Major benefits of this design over a KV block Trie

  • Sometimes, caching is not limited to prefix caching:
    • With Mistral's sliding window attention, we only need to cache the last tokens in the sliding window.
    • With attention sinks, we need to cache the first few tokens and the latest tokens.
  • Maintaining a hash table is simpler than maintaining a tree.
  • Extensible to more advanced caching policies (the one above is just an example).

Notes

  • An arbitrary caching policy may randomly free a block in the middle of a prefix. Then we need an attention kernel that can compute attention on sequences like “ooooxxxxxooooxxxoooxxxooooxxx”, where we need to compute attention on all “x” tokens. This kernel can be implemented, but it is not required for the first version.
  • We only cache complete blocks; partial blocks are kept out of the hash table.

Deliverables

P0

  • Make every complete KV block cacheable. Do not immediately free KV blocks with ref count 0.
  • Implement the cache eviction policy above, behind a well-abstracted eviction-policy class.
    • The cache eviction policy class should take a sequence of token IDs (or block hashes) and return a list of blocks. Some of the blocks may already be in the cache, and others may be new blocks that were just evicted by the policy.
  • Refactor the current block table to use hashes:
    • Add the attributes above to every block object in vLLM.
    • For every request, keep the list of block objects as block table.
    • [new] Implement global tables of blocks (see the sketch after this list):
      • Two tables: a complete-block table and a partial-block table.
      • When an incomplete block becomes complete, we either merge it with an existing complete block or promote it to a new complete block.
  • The preemption policy is kept the same as before.
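
A hedged sketch of the global tables and the merge/promote step, with hypothetical names (BlockTables and the `free` callback are illustrative, not vLLM's code):

```python
class BlockTables:
    """Hypothetical sketch of the two global tables: complete blocks are keyed
    by their hash and shareable; partial blocks are private to one request."""

    def __init__(self):
        self.complete: dict[int, "PhysicalBlock"] = {}    # block hash -> block
        self.partial: dict[tuple, "PhysicalBlock"] = {}   # (request id, block index) -> block

    def promote(self, key, block, hash_value, free):
        """Called when a partial block fills up and becomes a complete block.

        `free` is a callback that returns a physical block's memory to the pool.
        """
        del self.partial[key]
        existing = self.complete.get(hash_value)
        if existing is not None:
            # Merge: an identical block is already cached, so reuse it
            # and release this block's GPU memory.
            existing.ref_count += block.ref_count
            free(block)
            return existing
        # Promote: this block becomes the canonical cached copy.
        self.complete[hash_value] = block
        return block
```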

P1

  • Make sure the policy works for sliding windows.
  • A faster hash function.

P2

  • Kernel optimization for the "ooxxxoooxxxoooxxx" case.
  • Better preemption strategy for OOM cases.
  • Support block swapping.
zhuohan123 added the RFC and enhancement (New feature or request) labels on Jan 26, 2024
robertgshaw2-neuralmagic (Collaborator) commented Jan 26, 2024

Are you looking for any contributions to deliver this feature, @zhuohan123?

simon-mo pinned this issue on Jan 26, 2024
zhuohan123 (Member, Author)

Are you looking for any contributions to deliver this feature, @zhuohan123?

Yes, indeed. If anyone is interested, please let me know!

zcnrex (Contributor) commented Jan 29, 2024

@zhuohan123 Please let me know how I can help deliver this feature.

robertgshaw2-neuralmagic (Collaborator) commented Jan 29, 2024

Are you looking for any contributions to deliver this feature, @zhuohan123?

Yes, indeed. If anyone is interested, please let me know!

I am very interested. I will reach out to you on Discord (and will look for you at the meetup on Wednesday).

robertgshaw2-neuralmagic (Collaborator)

I am working on a PR for this.

atyshka commented Jan 31, 2024

Would this cache be reset for each batch, or could it also be used to cache persistently across different batches/runs? I'm interested in low-latency applications, and being able to incrementally append input while retaining the KV cache would be very useful.

robertgshaw2-neuralmagic (Collaborator)

@atyshka This feature will automatically enable reuse of the cache across different batches and runs.

sh1ng (Contributor) commented Feb 9, 2024

Thank you @zhuohan123 for your analysis.
Correct me if I'm wrong, but I don't see much difference between a hash table and a tree if we consider KV blocks as leaves. IIUC, that is what you are proposing with

hash(prefix tokens, tokens in this block)

We can rewrite it as

hash(parent_block_hash, tokens in this block)
parent_block_hash = 0 for the first block
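
A minimal sketch of that chained formulation (names are illustrative):

```python
def chained_block_hash(parent_block_hash: int, block_tokens: tuple[int, ...]) -> int:
    # The parent hash already commits to the entire prefix, so hashing it
    # together with this block's tokens is equivalent to
    # hash(prefix tokens, tokens in this block).
    return hash((parent_block_hash, block_tokens))


# parent_block_hash = 0 for the first block of a sequence.
h0 = chained_block_hash(0, (11, 12, 13))
h1 = chained_block_hash(h0, (14, 15, 16))
```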

I did my own analysis of attention sinks in #1304 (comment); the main challenge there is the need to move the KV cache if you remove a token in the middle.
We could also reserve a few blocks for system/reusable prompts. An attention sink may use a block filled solely with \n tokens.

SageMoore (Contributor)

@zhuohan123 We're wrapping up our PR for P0 of this RFC, and we had a question about the prefix length that will be encoded in each block and used as a tie breaker for eviction.
Should this prefix length be:

  • The full prefix length of the Sequence.
  • The number of prefix tokens in the current block, i.e. a maximum value of block.block_size.
  • The number of prefix tokens before and including the current block? So each subsequent prefix block would have a higher value for block.prefix_length.

zhuohan123 (Member, Author)

Should this prefix length be: [the three options listed in the previous comment]

It should be the last one. The goal is to identify the leaf nodes in the prefix tree, so the prefix length here should be the same as the tree depth.
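
In other words (an illustrative sketch only, assuming fixed-size complete blocks):

```python
def prefix_length(block_index: int, block_size: int) -> int:
    """Prefix tokens up to and including the given complete block,
    i.e. the block's depth in the prefix tree measured in tokens."""
    return (block_index + 1) * block_size


assert prefix_length(0, 16) == 16   # first block of the sequence
assert prefix_length(2, 16) == 48   # deeper block: evicted first when LRU times tie
```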

gawainx commented Mar 26, 2024

Any experiment results to show how prefix caching could improve latency and throughput?

merrymercy (Contributor) commented May 16, 2024

Since many people are going to read this, I want to add a few clarifications here.

I do not see any items listed here that cannot be easily implemented in the tree structure.

  • For sliding window/attention sink, you can keep the metadata in the tree for matching and drop the GPU tensors for KV cache. You need to save the metadata anyway.
  • For multiple LoRA adapters, insert the ID of the LoRA adapter before the first token. Then it is done.
  • For image/video tokens, use the hash of the image/video as the key. It is already implemented in SGLang.
  • For advanced scheduling/eviction policy, having a tree structure can actually give you more information. You may want to evict a whole subtree or prioritize requests under a subtree. Having access to the tree structure makes the advanced scheduling policy possible.
  • With a radix tree, you can make matching very efficient; see "Optimize radix tree matching" (sgl-project/sglang#364).

In SGLang, we use TokenAttention (block_size/page_size = 1), which also simplifies many things. With optimized kernels (FlashInfer), it achieves the same performance as larger block sizes.
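
For the LoRA and image/video points above, a hedged sketch of how the cache key could be extended (the helper below is illustrative, not SGLang's actual code):

```python
def cache_key(token_ids, lora_id=None, image_hash=None):
    """Build a prefix-cache key so that different LoRA adapters never share
    entries, and image/video content is keyed by its hash instead of raw data."""
    key = []
    if lora_id is not None:
        key.append(("lora", lora_id))      # adapter ID sits before the first token
    if image_hash is not None:
        key.append(("image", image_hash))  # hash of the image/video stands in for it
    key.extend(token_ids)
    return tuple(key)
```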

zhyncs (Contributor) commented May 17, 2024

[quoting @merrymercy's comment above in full]

In LMDeploy TurboMind, we've implemented automatic prefix caching using both hash-table and radix-tree methods. The overall implementation is very straightforward, requiring no kernel modifications while remaining compatible with the existing framework features. Currently, since TurboMind does not support sliding window or LoRA, we have no plans for a second-phase optimization at this time. InternLM/lmdeploy#1450

josephrocca commented Jun 7, 2024

Any experiment results to show how prefix caching could improve latency and throughput?

@gawainx Big speedup in my testing with vLLM on 2x4090s running a 70B model. You just need to make sure it's compatible with your load balancer (if any) so requests get preferentially routed to the correct machine.

RE @zhyncs' comment above, I'm seeing a massive speedup in my testing of LMDeploy, in part because it allows 8-bit and 4-bit KV caching in conjunction with prefix caching. Currently vLLM does not allow quantized cache in conjunction with prefix caching (nor does it allow chunked prefill with prefix caching, as an aside), so the cache can only store about eight 2000-token prefixes, vs about 32 in LMDeploy (with 4-bit cache).

With LMDeploy, for a 2000-token prompt + 100 tokens of output, with the prompt cached and concurrency limited so that the maximum time-to-first-token stays at 3 seconds for all requests, end-to-end generated tokens/sec summed across all concurrent requests goes from ~10 to ~300 (!!) for a 70B Llama 2 model. For comparison, vLLM goes from ~10 to ~100 under the same assumptions. In a more realistic scenario (i.e. where not all prompts are cached), I'm expecting closer to a 4x speedup from LMDeploy, which is amazing. Definitely worth testing for your use case! I think all LMDeploy needs is to fix the stop param and add vLLM's include_stop_str_in_output, and it will be ready to deploy for most use cases.

AnaRhisT94

[quoting @josephrocca's comment above in full]

Is it hard to implement FP8 KV caching in vLLM? Can't they just take LMDeploy's implementation?

kzawora-intel pushed a commit to HabanaAI/vllm-fork that referenced this issue on Oct 29, 2024:

This PR enables automatic prefix caching on Intel Gaudi HPUs. Please refer to this [RFC](vllm-project#2614) for detailed information about prefix caching.
toilaluan

@robertgshaw2-neuralmagic Hi, can I save the prefix cache to disk and reload it when needed?
