Static cache is locked after torch.compile with model.generate #30351
Comments
cc @gante
Super weird, and I can indeed reproduce.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, attn_implementation="sdpa").cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.add_bos_token = False

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer(["<s> [INST] Write an essay about large language models [/INST]"], return_tensors="pt").to(model.device)
for _ in range(3):
    gen_out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=False)
    print(tokenizer.decode(gen_out[0]))

inputs = tokenizer(["<s> [INST] How to make a chocolate cake? [/INST]"], return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=False)
print(tokenizer.decode(gen_out[0]))
```
Running the original script with @ArthurZucker's suggested changes gives the same exceptions (because the first calls hit the same issue).
Ah! I did not get that and generated successfully; no idea what went wrong on your end.
Thanks @ArthurZucker! The reason, I guess, is that it's better to compile the forward pass here.

@gante I get that sometimes as well. I think it's a bit better with the torch nightly build.
@ArthurZucker @gante that CUDA issue mainly happens when you compile the whole forward pass; normally you only need to compile the forward for the decoding part (the input is 1 token with a fixed shape), not the prefill.
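As a rough illustration of that idea, here is a minimal sketch (not the actual approach used inside transformers) that routes fixed-shape single-token decode calls through a compiled forward while leaving the variable-length prefill eager; it assumes the model was loaded as in the snippet above but not yet compiled:

```python
import torch

eager_forward = model.forward  # keep the original, uncompiled forward for the prefill
decode_forward = torch.compile(eager_forward, mode="reduce-overhead", fullgraph=True)

def routed_forward(*args, **kwargs):
    input_ids = kwargs.get("input_ids")
    if input_ids is not None and input_ids.shape[-1] == 1:
        # Single-token decoding step with a fixed shape: use the compiled graph.
        return decode_forward(*args, **kwargs)
    # Variable-length prefill: run eagerly to avoid recompilation / CUDA graph issues.
    return eager_forward(*args, **kwargs)

model.forward = routed_forward
```

This relies on the same `model.forward = ...` monkey-patching pattern already used in the thread, since the module looks up `forward` on the instance at call time.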
@ArthurZucker thanks! I found a hack: warm-up with …

Update: the warm-up with the full torch.compile takes a lot of VRAM. The best would be to make it work with …

There's another problem if you compile using:

```python
model.forward = torch.compile(model.forward, **{"mode": "reduce-overhead", "fullgraph": True})

prompt = "Write an essay about large language models."

# warm-up
for _ in range(10):
    gen_out = model.generate(**tokenize_prompt(prompt), do_sample=False, cache_implementation="static", max_new_tokens=1000, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=False)

prompt = "How do I make a cake"

import time
t1 = time.time()
gen_out = model.generate(**tokenize_prompt(prompt), do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=True)
t2 = time.time()
print(len(gen_out[0]) / (t2 - t1), "tokens/sec")
```
I was not able to test the fix because there's another problem with 4.41.0: #30417
Super weird and we'll fix it asap
Might be related to #30414 as well
I was finally able to make it work without blowing up the VRAM:
With this approach, a 4-bit Llama2-7B takes ~5.6GB of VRAM at runtime with a max cache size of 1024. The only issue is the speed: with the approach above I get 165 tokens/sec, while it should be ~205 tokens/sec.

Update: the speed depends on the size of the initialized cache for some reason.

Update 2: it is actually not fixed; the outputs still mix in content from previous prompts.
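For reference, a rough sketch (a hypothetical helper, reusing the model and tokenizer from the snippets above) of how the cache-size effect could be probed; with cache_implementation="static" the cache is allocated up front from the generation length, so varying max_new_tokens also varies the cache size:

```python
import time

def tokens_per_sec(prompt, max_new_tokens):
    # Rough probe only: changing max_new_tokens changes both the static cache size
    # and the number of generated tokens, so the numbers are not a clean comparison.
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    t1 = time.time()
    out = model.generate(**inputs, do_sample=False, cache_implementation="static",
                         max_new_tokens=max_new_tokens,
                         pad_token_id=tokenizer.pad_token_id,
                         temperature=None, top_p=None)
    t2 = time.time()
    return len(out[0]) / (t2 - t1)

for budget in (128, 512, 1024):
    print(f"max_new_tokens={budget}: {tokens_per_sec('How do I make a cake?', budget):.1f} tokens/sec")
```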
Wow, thanks a lot for all this valuable debugging, would really love to fix this!
Thanks @ArthurZucker
BTW we are going to move forward with #30476
Thank you for the update! |
System Info
transformers version: 4.39.0.dev0

Who can help?

@ArthurZucker

Information

Tasks

An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
When using torch.compile(model.forward) with a static cache, the cache seems to be locked to the first prompt that was used at compilation time. I re-implemented the generate logic and the same issue happens, so it's not just a bug with model.generate. This happens with both older and newer versions of transformers. Here's a code snippet:
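A minimal sketch of such a setup, assuming the same Llama-2 checkpoint and generation flags as in the comments above:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-chat-hf"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).cuda().eval()

# Compile the forward pass; with a static cache, later prompts appear "locked"
# to whatever prompt was used when the graph was first compiled.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

def generate(prompt):
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100)
    return tokenizer.decode(out[0])

print(generate("[INST] Write an essay about large language models [/INST]"))  # prompt used at compile time
print(generate("[INST] How to make a chocolate cake? [/INST]"))               # output may still follow the first prompt
```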
Expected behavior
The output should correspond to the input prompt, not the prompt the model was first compiled with.
Thank you!