
[WIP] Dynamic length in static cache #30862

Draft
wants to merge 18 commits into base: main

Conversation

@ydshieh (Collaborator) commented May 16, 2024

What does this PR do?

The current version is a minimal change that works; it may not be the best approach.

The current static cache is nice (when running with torch.compile). However, in each generation step, the new position (to be generated) computes attention against all positions in the cache, which is not optimal. In fact, we only need to compute attention against the positions prior to the current position.

This PR implements dynamic length computation with the static cache, which works with torch.compile. The table below demonstrates the speedup (with torch.compile) of this implementation over the current main branch.
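As a rough sketch of the idea (illustrative only, not the exact PR diff; the tensor names follow the usual Llama/Gemma attention layout, and current_length plays the role of this PR's _length argument), the per-step attention only needs the first current_length entries of the static cache:

import torch
import torch.nn.functional as F

def attention_over_valid_cache(query_states, key_states, value_states, causal_mask, current_length):
    # key_states / value_states come from the static cache and have shape
    # (batch, num_heads, max_cache_len, head_dim); only the first `current_length`
    # positions hold valid entries, so slice them (and the mask) before SDPA.
    key_states = key_states[:, :, :current_length, :]
    value_states = value_states[:, :, :current_length, :]
    if causal_mask is not None:
        causal_mask = causal_mask[:, :, :, :current_length]
    return F.scaled_dot_product_attention(
        query_states, key_states, value_states, attn_mask=causal_mask
    )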

The correctness is verified by

RUN_SLOW=1 TF_FORCE_GPU_ALLOW_GROWTH=true python3 -m pytest -v tests/models/gemma/test_modeling_gemma.py -k "test_compile_static_cache"

The data below is based on this script:
import os
import torch
import datetime

from transformers import AutoTokenizer, AutoModelForCausalLM

token = "ADD_YOUR_OWN_TOKEN"

os.environ["TOKENIZERS_PARALLELISM"] = "false"

batch_size = 1
n_iter = 5

ckpt = "google/gemma-2b"

tokenizer = AutoTokenizer.from_pretrained(ckpt, token=token)
model = AutoModelForCausalLM.from_pretrained(ckpt, token=token, torch_dtype=torch.float16).to("cuda")

model.generation_config.max_new_tokens = 1024

model.generation_config.cache_implementation = "static"
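# Compile the forward pass on top of the static cache; "reduce-overhead" mode uses CUDA graphs.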
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

input_text = "Why dogs are cute."
input_ids = tokenizer([input_text] * batch_size, return_tensors="pt").to("cuda")

for i in range(n_iter):
    s = datetime.datetime.now()
    outputs = model.generate(**input_ids, do_sample=False)
    t = datetime.datetime.now()
    e = (t-s).total_seconds()
    print(e)

with some modifications to run it with different configurations, running on an A100 with torch==2.3+cu121.

Benchmark

I will re-run (part of) the benchmark, as the following numbers are on top of an older commit of main.

benchmark data on the hub

Static cache compiled: full length vs. optimal length (this PR)

gemma-2b (18 layers)

seq. length    speedup
1024           1.03x
2048           1.11x
4096           1.24x
8192           1.38x

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -1218,6 +1231,7 @@ def prepare_inputs_for_generation(
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"_length": int(cache_position[-1]) + 1,
Collaborator Author:

This is redundant with cache_position; however, it is the only way I could figure out to make the dynamic length computation work with torch.compile.

@gante (Member) left a comment:

TBH I also think it makes sense to slice away the useless ops; I wondered about this question myself :) The speedup of 15% is very nice, and I can confirm the speedup on my setup as well! (RTX3090) 🔥

Regarding the API (_length): I understand why it is done. Without an integer in the signature of GemmaSdpaAttention.forward, slicing tensors like key_states will always either fail due to it being a data-dependent operation (= forbidden) OR produce a variable-length tensor. With an integer, each value of the integer gets its own compiled function with data-independent tensor slicing.

Still, if we are to go forward, we should find a better solution for the API. StaticCache already introduced the cache_position input, and this would further complicate the API. I see three possible paths:

  1. cache_position becomes a list of integers instead of a tensor, we use cache_position[-1] + 1 to slice the tensors;
  2. we pass the full cache_position array (a torch.arange up to the sequence length). The different shape of cache_position in each GemmaSdpaAttention.forward will trigger recompilation, solving the dynamic shape problem
  3. instead of cache_position, we use the sequence length (=_length, an int) to control generation with static cache.

Note that in all 3 cases, the StaticCache needs a tensor like the current cache_position. However, it should be trivial to build one from any of the solutions above. From a usage perspective, option 3 is probably the easiest to understand. @ArthurZucker @ydshieh WDYT?
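(A toy illustration of the integer-specialization point above, not code from this PR or from transformers: slicing by a plain Python int keeps shapes static inside each compiled graph, whereas slicing by a value read out of a tensor would be data-dependent. Depending on the torch version, a new integer value either triggers one recompilation or switches the graph to dynamic shapes.)

import torch

@torch.compile(fullgraph=True)
def sliced_sum(cache: torch.Tensor, length: int) -> torch.Tensor:
    # `length` is a Python int, so this slice has a static shape in each specialized graph.
    return cache[:, :length].sum()

cache = torch.zeros(2, 16)
print(sliced_sum(cache, 4))  # first compilation, specialized on length=4
print(sliced_sum(cache, 8))  # one more compilation (or a dynamic-shape graph), then reused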

@ydshieh (Collaborator, Author) commented May 21, 2024

Regarding the API (_length): I understand why it is done. Without an integer in the signature of GemmaSdpaAttention.forward, slicing tensors like key_states will always either fail due to it being a data-dependent operation (= forbidden) OR produce a variable-length tensor.

Exactly!

For option 2, I am a bit worried. For example,

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

Currently, cache_position has only 1 element (after the first step). If we go for option 2, it will be full length, and then we would be updating the whole cache. Of course, the key_states and value_states arguments in update are just the last part (of the sequence), so we would have to slice cache_position here too. The data-dependent operation issue thus pops up here as well.

I personally would prefer option 3 for its simplicity, as long as we can re-build the tensor (cache_position) that is required by update and the other places that need it. I would like to hear from @ArthurZucker too.

@ArthurZucker (Collaborator):

Hey, before having a look: when you mention speedups, I don't think it makes sense to measure anything that does not use the full number of layers.
Also, how many tokens are generated? Is this speedup only for the prefill phase?

@ydshieh (Collaborator, Author) commented May 23, 2024

@ArthurZucker I am running on an A10, so even with gemma-2b (18 layers), I can only compile with a 768 sequence length.
However, as you can see from the tables, more layers means more speedup, and a longer sequence means more speedup too.

Also how many tokens are generated?

from 256 to 8192 (as long as it could compile within A10 GPU memory).

The speedup gain and the reason behind it are fairly easy to see. However, are there any particular extra cases you would like me to run?

@ArthurZucker (Collaborator):

That is what was not clear to me: I wanted to know the number of generated tokens, not the prefill 😉
And most importantly, the new argument is pretty annoying 😓

@ydshieh (Collaborator, Author) commented May 23, 2024

Yes. We probably need to come up with a good new approach, as @gante suggested.
I will run full layers (18) for google/gemma-2b in the meantime.

@ydshieh ydshieh mentioned this pull request May 23, 2024
@ArthurZucker (Collaborator) left a comment:

Waiting for the final bench!

Comment on lines 575 to 581
if _length > 0:
    key_states = key_states[:, :, :_length, :]
    value_states = value_states[:, :, :_length, :]
    causal_mask = causal_mask[:, :, :, :_length] if causal_mask is not None else causal_mask
Collaborator:

This can only be an int; if it's a list, there is bound to be a device transfer.

Collaborator Author:

Yeah, so far it is an int. Let me run the final bench first and come back to the API design.

@zucchini-nlp (Member):

@ydshieh I tried to use your implementation in my PR. I am also trying to get the actual length in compiled models, but in my case the length is used to decide which rope scaling to do. Therefore, passing the length as a model kwarg fails with dynamic control flow in the fullgraph setting.

So, what do you guys think about going back to the seen_tokens attribute of the cache class? It doesn't cause control-flow errors because the cache is still a model attribute, and we update seen_tokens as we generate instead of passing it on every forward pass. And I think it will work for the kv-cropping done in this PR.

cc @gante @ArthurZucker ?

@ydshieh (Collaborator, Author) commented May 24, 2024

@ydshieh I tried to use your implementation in my PR. I am also trying to get the actual length in compiled models, but in my case length is used to decide which rope scaling to do. Therefore, passing length as model kwarg fails with dynamic control flow in fullgraph setting.

Could you point me to the lines where the length is used and where the compile issue occurs? I could take a look.

Oh, you use the length in a conditional?

@zucchini-nlp (Member):

Yes, in Phi3 RoPE it's used in a conditional, and I've been trying to compile it.
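(For context, roughly the kind of conditional involved; this is a hedged paraphrase with approximate attribute names, not Phi3's exact code:)

import torch

def select_rope_scaling_factors(length, config, device):
    # Sketch: pick long- or short-context RoPE scaling factors depending on the
    # current sequence length; the branch needs `length` to be compile-friendly.
    if length > config.original_max_position_embeddings:
        return torch.tensor(config.rope_scaling["long_factor"], dtype=torch.float32, device=device)
    return torch.tensor(config.rope_scaling["short_factor"], dtype=torch.float32, device=device)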

@ydshieh (Collaborator, Author) commented May 24, 2024

Do you still have that commit (where you integrated your PR with mine and it led to a compile failure)? If so, could you share it please 🙏

@zucchini-nlp (Member):

Sorry, I reverted your changes, but I just pushed the version that works for me, with seen_tokens. I get the length here and then use it in self.rotary_emb.

@zucchini-nlp (Member):

Update: we discussed with @ydshieh using the length in cond control flow. It does work, but only in torch 2.3.0 or 2.4.0; in 2.2.0 it fails.

So this feature will also benefit Phi3 compilation, when merged :)

@ydshieh (Collaborator, Author) commented May 27, 2024

Hi @gante

When I run with TORCH_LOGS="recompiles" on main (95b3c381) and on this PR (862cde4c), the only recompilation in both commits happens at the second call to forward (see below), which makes sense.

So this PR doesn't introduce any extra recompilation.

(If we call generate with another input of a different sequence length, there is one more recompilation. After that, everything is ready to use and there is no further recompilation, even if a 3rd input with yet another length is given.)

V0527 12:12:26.470280 140416858965824 torch/_dynamo/guards.py:1425] [__recompiles] Recompiling function forward in /transformers/src/transformers/models/gemma/modeling_gemma.py:1058
V0527 12:12:26.470280 140416858965824 torch/_dynamo/guards.py:1425] [__recompiles]     triggered by the following guard failure(s):
V0527 12:12:26.470280 140416858965824 torch/_dynamo/guards.py:1425] [__recompiles]     - tensor 'L['input_ids']' stride mismatch at index 0. expected 6, actual 7
11.347857
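(For reference, logs of this kind can be produced by prefixing a benchmark run with the environment variable, along the lines of the following; the script name here is illustrative only.)

TORCH_LOGS="recompiles" python3 benchmark_generate.py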

@ydshieh ydshieh force-pushed the dynamic_length_in_static_cache branch from 862cde4 to b447901 Compare May 27, 2024 14:34
@ydshieh (Collaborator, Author) commented May 28, 2024

it should be trivial to build from any of the solutions above

  1. cache_position becomes a list of integers instead of a tensor, we use cache_position[-1] + 1 to slice the tensors;

But we will need a tensor in StaticCache.update (that is what @ArthurZucker told me), so I don't think this option is good.

  2. we pass the full cache_position array (a torch.arange up to the sequence length). The different shape of cache_position in each GemmaSdpaAttention.forward will trigger recompilation, solving the dynamic shape problem;
  3. instead of cache_position, we use the sequence length (= _length, an int) to control generation with static cache.

A length (say _length) or a full cache_position alone is not enough to reconstruct the (current) cache_position. The problem is that we don't know whether we are in the first generation step or a later step, and therefore whether to reconstruct a full cache_position or a single (current) position to be used in StaticCache.update.

We can probably use q_len, but it is obtained from an input tensor, and I don't know if this will work well with torch.compile.

@gante Do you have any comment regarding this and something you think I could give it a try?

@gante (Member) commented Jun 14, 2024

@ydshieh sorry for the delayed response, I've now placed this issue on top of my priorities 🤗

Regarding your previous comment:

  1. cache_position becomes a list of integers instead of a tensor, we use cache_position[-1] + 1 to slice the tensors;

But we will need a tensor in StaticCache.update (that is what @ArthurZucker told me), so this option is not good I think.

I believe we can convert the cache_positions (of list type) to a tensor right before calling StaticCache.update, getting the best of both worlds 🙌 I think this is the path with minimal API changes -- assuming it works, the only needed change is the type of the input!
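A minimal sketch of that suggestion (illustrative, not the exact change; it mirrors the update call quoted earlier, assuming cache_position is kept as a plain list of ints everywhere else):

        if past_key_value is not None:
            # Materialize a tensor only at the point where StaticCache.update needs one.
            cache_position_tensor = torch.tensor(cache_position, dtype=torch.long, device=key_states.device)
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position_tensor}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)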

@ydshieh (Collaborator, Author) left a comment:

Regarding this dynamic length computation within the static cache for torch.compile: it turns out that we get even more speedup with a long (enough) cache size (so we don't need to recreate the cache object, which triggers a recompile and is very slow) when the generation is short compared to the cache length.

The following figures show that we can get as much as a 6x speedup (when the cache size is 16384 and decoding finishes early, say after 256 tokens).

Of course, the usefulness of such a long cache size is arguable, and we also need to run against real datasets and/or batch generation, but even with a cache size of 4096, a 2x speedup is still there.

p.s. A long cache size has some compilation issues, but those issues are also present on our main branch and so far are not caused only by the dynamic length in this PR.
Attached figures (3 files):
diff_cache_size_decoding_steps_4096
diff_decoding_steps_cache_size_4096
diff_decoding_steps_cache_size_16384

Comment on lines 22 to 30
class CacheInfo:
    def __init__(self, position, length):
        self.position = position
        self._length = length
Collaborator Author:

This class is introduced to encapsulate all the information (position, length) instead of passing them as separate arguments.
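A hypothetical usage sketch (the cache_info keyword and the tensor construction below are illustrative only, not part of the actual diff):

# Bundle the current write position and the valid cache length into one object
# and thread it through forward instead of separate `cache_position` / `_length` args.
cache_info = CacheInfo(
    position=torch.tensor([cache_length - 1], dtype=torch.int32, device=device),
    length=cache_length,
)
outputs = model(input_ids, cache_info=cache_info)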

@ydshieh (Collaborator, Author) commented Jun 14, 2024

I believe we can convert the cache_positions (of list type) to a tensor right before calling StaticCache.update, getting the best of both worlds 🙌

OK, I can test this approach

@ydshieh (Collaborator, Author) commented Jun 14, 2024

@gante

I tried it, but there is some slowdown in the compile timing (the first/second iteration). See the numbers below.
(There is also some slowdown after the first 2 iterations; I have to check with longer decoding steps (say 1024 - 8192) to see how much it increases.)

See the changes here: they are not optimized, just a quick attempt to try the idea.

The slowdown is likely due to the overhead of python/torch switching/conversion.

Furthermore, compared to the new CacheInfo class approach, there are more places to change (i.e. some previous torch operations have to be changed to list operations).

Let me know what you think, especially about the increased compile time.

On a T4, decoding 256 steps (without dynamic length involved, just comparing the tensor vs. list approach):

with cache_position being a tensor:

79.57995
93.045676
6.057991
6.05349
6.057576

with cache_position being a list but converted to a tensor before calling update:

108.264561
102.169996
6.601498
6.358185
6.774343

On an A100 (with 1024 decoding steps):

79.57995
93.045676
6.057991
6.05349
6.057576

vs.

108.264561
102.169996
6.601498
6.358185
6.774343

@gante (Member) commented Jun 14, 2024

@ydshieh thank you for exploring the alternative! 💛 The execution time is indeed the key metric, and it is clearly inferior (probably because it leads to more recompiled code sections).

Last counter-proposal: have you tried using _length alone, i.e. NOT passing cache_position and rebuilding it inside the attention layers' forward with torch.arange, from _length + the input shape? In terms of API it would also be better: a simple integer is preferable to a custom class :)

(Happy to try it if you're low on bandwidth!)

@ydshieh (Collaborator, Author) commented Jun 14, 2024

I could try it next week. However, FYI, cache_position is also used in several places, like:

  • prepare_inputs_for_generation: past_length
  • _update_causal_mask: causal_mask depends on it
  • GemmaModel.forward: position_ids depends on it

Also, _assisted_decoding seems to have some special logic involving cache_position:

            if "cache_position" in candidate_kwargs:
                candidate_kwargs["cache_position"] = torch.cat(
                    (
                        candidate_kwargs["cache_position"],
                        torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
                    ),
                    dim=0,
                )

That is why I am somewhat afraid of breaking things 😅

@gante (Member) commented Jun 14, 2024

@ydshieh I think we can solve all those cases from _length :D Assuming it works, and that the speedups are similar, I think it's well worth the effort 💪

@ydshieh ydshieh force-pushed the dynamic_length_in_static_cache branch from c0300c3 to 9168904 Compare July 4, 2024 08:05
@ydshieh ydshieh force-pushed the dynamic_length_in_static_cache branch from 0258a4e to 9ab68d0 Compare July 4, 2024 08:13
@helunwencser (Contributor) commented Jul 29, 2024

Hi @ydshieh, @gante, is there any update on this PR? I want to use Phi-3 with a static KV cache. This PR seems super useful. Can we make a similar change for Phi-3 as well?

@ArthurZucker (Collaborator):

Static cache for Phi-3 will need a separate PR to support cache positions.

@ydshieh (Collaborator, Author) commented Jul 30, 2024

@gante Let me know your thoughts on the current POC whenever you get the time to take a look. Thanks.

@guangy10 (Contributor):

@ArthurZucker @gante What is the plan to move this work/PR forward?

@ArthurZucker (Collaborator) left a comment:

Sounds promising. Not sure if @ydshieh has had time to pick this back up, but it looks good overall. Needs heavy testing though!

Comment on lines +284 to +289
if q_len > 1:
    # prefill
    cache_position = torch.arange(cache_length, dtype=torch.int32, device=hidden_states.device)
else:
    # decoding
    cache_position = torch.tensor([cache_length - 1], dtype=torch.int32, device=hidden_states.device)
Collaborator:

I am very surprised that this now works with torch.compile in reduce-overhead mode; when testing, this used to always create the same tensor (constant cache length) and the outputs were different. It would need investigation into which torch version supports this!

@ArthurZucker (Collaborator):

@guangy10 would this help for torch export?
