[WIP] Dynamic length in static cache #30862
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```diff
@@ -1218,6 +1231,7 @@ def prepare_inputs_for_generation(
             "past_key_values": past_key_values,
             "use_cache": use_cache,
             "attention_mask": attention_mask,
+            "_length": int(cache_position[-1]) + 1,
```
This is redundant with cache_position; however, this is the only way I could figure out to make the dynamic length computation work with torch.compile.
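As a rough, self-contained illustration of that point (toy shapes and names, not the modeling code from this PR): slicing a cached tensor with a plain Python int lets torch.compile specialize a graph per length value, whereas deriving the length from a tensor at runtime would be a data-dependent operation.

```python
# Toy sketch: with dynamic=False, torch.compile specializes on the int value, so the
# slice below has a static shape inside each compiled graph.
import torch

@torch.compile(dynamic=False)
def attend_prefix(key_states: torch.Tensor, length: int) -> torch.Tensor:
    # `length` is a Python int, so this is not a data-dependent slice
    return key_states[:, :, :length, :]

key_states = torch.randn(1, 8, 4096, 64)  # (batch, heads, cache_size, head_dim)
out = attend_prefix(key_states, 16)       # graph specialized for length == 16
```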
TBH I also think it makes sense to slice away the useless ops; I wondered about this question myself :) The speedup of 15% is very nice, and I can confirm the speedup on my setup as well! (RTX3090) 🔥

Regarding the API (_length): I understand why it is done. Without an integer in the signature of GemmaSdpaAttention.forward, slicing tensors like key_states will always either fail because it is a data-dependent operation (= forbidden) or produce a variable-length tensor. With an integer, each value of the integer gets its own compiled function with data-independent tensor slicing.

Still, if we are to go forward, we should find a better solution for the API. StaticCache already introduced the cache_position input, and this would further complicate the API. I see three possible paths:

- cache_position becomes a list of integers instead of a tensor, and we use cache_position[-1] + 1 to slice the tensors;
- we pass the full cache_position array (a torch.arange up to the sequence length). The different shape of cache_position in each GemmaSdpaAttention.forward call will trigger recompilation, solving the dynamic shape problem;
- instead of cache_position, we use the sequence length (= _length, an int) to control generation with the static cache.

Note that in all 3 cases, the StaticCache still needs a tensor like the current cache_position. However, it should be trivial to build it from any of the solutions above. From a usage perspective, option 3 is probably the easiest to understand. @ArthurZucker @ydshieh WDYT?
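A minimal sketch of what option 3 could look like (hypothetical helper and names, not the actual transformers API): the model receives the sequence length as an int and rebuilds the cache_position-like tensor that the static cache needs from it.

```python
import torch

# Hypothetical helper: rebuild the tensor the static cache consumes from an int length.
def build_cache_position(length: int, q_len: int, device: torch.device) -> torch.Tensor:
    if q_len > 1:
        # prefill: all positions 0 .. length - 1
        return torch.arange(length, dtype=torch.long, device=device)
    # decoding: only the position currently being written
    return torch.tensor([length - 1], dtype=torch.long, device=device)

cache_position = build_cache_position(length=17, q_len=1, device=torch.device("cpu"))
```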
Exactly! For option 2, I am a bit worried. For example:

```python
if past_key_value is not None:
    # sin and cos are specific to RoPE models; cache_position needed for the static cache
    cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
```

Currently, I personally would prefer option 3 for its simplicity, as long as we can re-build the tensor.
Hey, before having a look: when you mention speedups, I don't think it makes sense to compute anything that does not use the full number of layers.
@ArthurZucker I am running on A10, so even with gemma-2b (18 layers), I can only compile with a 768 sequence length. The lengths tried range from 256 to 8192 (as long as it could compile within A10 GPU memory). The speedup gain and the reason behind it are fairly easy to see. However, is there any extra particular case you want me to run?
That is what was not clear to me: I wanted to know the amount of generated tokens, not the prefill 😉
Yes. We probably need to come up with a good new approach as @gante suggested.
Waiting for the final bench!
```python
if _length > 0:
    key_states = key_states[:, :, :_length, :]
    value_states = value_states[:, :, :_length, :]
    causal_mask = causal_mask[:, :, :, :_length] if causal_mask is not None else causal_mask
```
This can only be an int; if it's a list, there is bound to be a device transfer.
Yeah, so far it is an int. Let me run the final bench first and come back to the API design.
@ydshieh I tried to use your implementation in my PR. I am also trying to get the actual length in compiled models, but in my case the length is used to decide which RoPE scaling to do. Therefore, passing the length as a model kwarg fails with dynamic control flow in the fullgraph setting. So, what do you guys think about going back to that? cc @gante @ArthurZucker
Could you point me to the lines where the length is used and where the compile issues happen? I could take a look. Oh, you use the length as a conditional?
Yes, in Phi3 RoPE it's used as a conditional, and I've been trying to compile it.
Do you still have that commit (where you combine your PR with mine and it leads to the compile failure)? If so, could you share it please 🙏
Sorry, I reverted your changes, but I just pushed the one which works for me, with "seen_tokens". I get the length here and then use it in the RoPE scaling conditional.
Update: @ydshieh and I discussed using the length in cond control flow. It does work, but only in torch 2.3.0 or 2.4.0; in 2.2.0 it would fail. So this feature will also benefit Phi3 compilation, when merged :)
Hi @gante. When I run it, this PR doesn't introduce any extra recompilation.
But we will need a tensor in StaticCache.update. Given a length, we can probably use a torch.arange to build it. @gante Do you have any comment regarding this, or something you think I could give a try?
@ydshieh sorry for the delayed response, I've now placed this issue on top of my priorities 🤗 Regarding your previous comment: I believe we can convert the length into the tensor the cache needs.
Regarding this dynamic length computation within the static cache for torch.compile: it turns out that we get even more speedup for a long (enough) cache size (so we don't need to recreate the cache object, which triggers recompilation and is very slow) while the generation is short compared to the cache length.

The following figures show we can get up to a 6x speedup (when the cache size is 16384 and decoding finishes early, say after 256 tokens).

Of course, the usefulness of such a long cache size is arguable, and we also need to run against a real dataset and/or batch generation, but even with a cache size of 4096, a 2x speedup is still there.

P.S. a long cache size has some issues with compiling, but that issue is also present on our main branch and so far is not caused only by the dynamic length in this PR.
src/transformers/cache_utils.py
```python
class CacheInfo:
    def __init__(self, position, length):
        self.position = position
        self._length = length
```
Introduce this class to encapsulate all the information (position, length) instead of passing those as separate arguments.
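For illustration, a hypothetical way such a container could be constructed at a decoding step (building on the CacheInfo class sketched above; the attribute names mirror that snippet):

```python
import torch

# Assumes the CacheInfo class from the snippet above is in scope.
cache_length = 17
cache_info = CacheInfo(
    position=torch.tensor([cache_length - 1], dtype=torch.int32),  # tensor for the cache update
    length=cache_length,                                           # plain int for slicing
)
```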
OK, I can test this approach.
I tried it, but there is some slowdown in the compile timing (the first/second iteration). See the numbers below, and see the changes here: it's not optimized, just a quick way to try the idea. The slowdown is likely due to the overhead of python/torch switching/conversion. Let me know what you think, especially about the increased compile time.

On T4, decoding 256 steps (without dynamic length involved, just comparing the tensor vs. list approach):

with the current tensor cache_position:
79.57995
93.045676
6.057991
6.05349
6.057576

with the list-based approach:
108.264561
102.169996
6.601498
6.358185
6.774343

On A100 (with decoding steps: 1024):
@ydshieh thank you for exploring the alternative! 💛 The execution time is indeed the key metric, and it is clearly inferior (probably because it leads to more recompiled code sections). Last counter-proposal: have you tried using …? (Happy to try it if you're low on bandwidth!)
Could try it next week. However, FYI, there are also places like:

```python
if "cache_position" in candidate_kwargs:
    candidate_kwargs["cache_position"] = torch.cat(
        (
            candidate_kwargs["cache_position"],
            torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
        ),
        dim=0,
    )
```

That is why I am somewhat afraid of breaking things 😅
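For what it's worth, here is a hypothetical sketch (not code from this PR or from transformers) of what that assisted-generation update could become if cache_position were a plain Python list of ints, which is the kind of call site being worried about here:

```python
# Hypothetical adaptation: extend a list-based cache_position by the candidate length,
# mirroring the torch.cat / torch.arange logic above.
cur_len, candidate_length = 10, 3
candidate_kwargs = {"cache_position": list(range(cur_len))}

if "cache_position" in candidate_kwargs:
    candidate_kwargs["cache_position"] = candidate_kwargs["cache_position"] + list(
        range(cur_len, cur_len + candidate_length)
    )
```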
@ydshieh I think we can solve all those cases from …
Static cache for Phi3 will need a separate PR to support cache positions.
@gante Let me know your thoughts on the current POC whenever you get the time to take a look. Thanks.
@ArthurZucker @gante What is the plan to move this work/PR forward?
Sounds promising, not sure if @ydshieh had time to pick this back up but looks good overall. Needs heavy testing tho!
```python
if q_len > 1:
    # prefill
    cache_position = torch.arange(cache_length, dtype=torch.int32, device=hidden_states.device)
else:
    # decoding
    cache_position = torch.tensor([cache_length - 1], dtype=torch.int32, device=hidden_states.device)
```
I am very much surprised that this now works in torch compile with reduce-overhead, as when testing, this used to always create the same tensor (constant cache length) and the outputs were different. Would need investigation into which torch version supports this!
@guangy10 would this help for torch export?
What does this PR do?
The current version is a minimal change that works, maybe not in the best way.

The current static cache is nice (when running with torch.compile). However, in each generation step, the new position (to be generated) computes the attention against all positions in the cache, which is not optimal. In fact, we only need to compute the attention against the positions prior to the current position.

This PR implements dynamic length computation with the static cache, which works with torch.compile. The following table demonstrates the speedup gain (with torch.compile) of this implementation over the current main branch. The correctness is verified by …
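To make the idea concrete, here is a rough, self-contained sketch (toy shapes, not the actual modeling code) of attending only to the filled portion of a static cache at a single decoding step:

```python
import torch
import torch.nn.functional as F

# The cache is allocated at its full size, but the current query only needs the
# first `length` slots of the key/value caches.
batch, heads, cache_size, head_dim = 1, 8, 4096, 64
query = torch.randn(batch, heads, 1, head_dim)               # one decoding step
key_cache = torch.randn(batch, heads, cache_size, head_dim)
value_cache = torch.randn(batch, heads, cache_size, head_dim)

length = 256  # positions filled so far (a plain int, so it plays well with torch.compile)
attn_out = F.scaled_dot_product_attention(
    query,
    key_cache[:, :, :length, :],    # slice away the unused cache slots
    value_cache[:, :, :length, :],
)
print(attn_out.shape)  # torch.Size([1, 8, 1, 64])
```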
The data below is based on this script, with some modifications to run it with different configurations, running on A100 with torch==2.3+cu121.

Benchmark
I will re-run (part of) the benchmark, as the following numbers are on top of an older commit of main.

benchmark data on the hub
Static cache compiled: full length v.s. optimal length (this PR)
gemma-2b (18 layers)