Static cache is locked after torch.compile with model.generate #30351

Closed
2 of 4 tasks
mobicham opened this issue Apr 19, 2024 · 17 comments · Fixed by #30476

Comments

@mobicham
Contributor

mobicham commented Apr 19, 2024

System Info

  • transformers version: 4.39.0.dev0
  • Platform: Linux-5.15.0-89-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.20.1
  • Safetensors version: 0.4.1
  • Accelerate version: 0.21.0

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

When using torch.compile(model.forward) with static cache, the cache seems to be locked to the first prompt used at compilation time. I re-implemented the generate logic and the same issue happens, so it's not just a bug in model.generate. This happens with both older and newer versions of transformers.

Here's a code snippet:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id  = "meta-llama/Llama-2-7b-chat-hf"
model     = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, attn_implementation="sdpa").cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.add_bos_token = False  # the prompts below already start with <s>

# Compile the full forward pass; mode="reduce-overhead" uses CUDA graphs
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

# First prompt: the first calls trigger compilation
inputs = tokenizer(["<s> [INST] Write an essay about large language models [/INST]"], return_tensors="pt").to(model.device)
for _ in range(3):
    gen_out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None)
print(tokenizer.decode(gen_out[0]))

# Output: OK
#  <s>  [INST] Write an essay about large language models [/INST]   Large language models have revolutionized the field of natural language processing in recent years. 
# These models are trained on vast amounts of text data and are capable of generating text, classifying text, and answering questions with remarkable accuracy. 
# In this essay, we will explore the current state of large language models, their potential applications, and the challenges and limitations that come with their use.....

inputs = tokenizer(["<s> [INST] How to make a chocolate cake? [/INST]"], return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None)
print(tokenizer.decode(gen_out[0]))

# Output: WRONG still talks about the previous prompt.
# <s>  [INST] How to make a chocolate cake? [/INST]  ge language models (LLMs) are a class of artificial intelligence (AI) models that have gained significant 
# attention in recent years due to their impressive language processing capabilities. Here, we will explore the concept of LLMs, their applications, 
# and their potential impact on various fields.
# What are Large Language Models?
# LLMs are neural network-based models that are trained on vast amounts of text data to generate language outputs that are coherent and natural

Expected behavior

The output should correspond to the input prompt, not the prompt the model was first compiled with.

Thank you!

@amyeroberts
Collaborator

cc @gante

@ArthurZucker
Collaborator

ArthurZucker commented Apr 22, 2024

Super weird, and I can indeed reproduce.
The fix is to pass use_cache=False to generate. It's counterintuitive, but this works:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id  = "meta-llama/Llama-2-7b-chat-hf"
model     = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, attn_implementation="sdpa").cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.add_bos_token = False  # the prompts below already start with <s>

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

# Same as the reproduction, except every generate() call passes use_cache=False
inputs = tokenizer(["<s> [INST] Write an essay about large language models [/INST]"], return_tensors="pt").to(model.device)
for _ in range(3):
    gen_out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=False)
print(tokenizer.decode(gen_out[0]))

inputs = tokenizer(["<s> [INST] How to make a chocolate cake? [/INST]"], return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, do_sample=False, cache_implementation="static", max_new_tokens=100, pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None, use_cache=False)
print(tokenizer.decode(gen_out[0]))

@gante
Member

gante commented Apr 22, 2024

When trying the original script with torch==2.4.0.dev20240418+cu121, I get an Aborted (core dumped), preceded by RuntimeError: Triton Error [CUDA]: device-side assert triggered and a bunch of out-of-bounds memory accesses 👀

@ArthurZucker's suggested script gets the same exceptions (because the first calls hit the same issue)

@ArthurZucker
Collaborator

Ah! I did not get that, and generation succeeded for me. No idea what went wrong on your side.

@ArthurZucker
Collaborator

(Screenshot) That's what I got, and it was pretty fast.

@mobicham
Contributor Author

mobicham commented Apr 22, 2024

Thanks @ArthurZucker !
Is there a way to replicate this use_cache=False behavior when writing the generate function manually, like this one (based on your code): https://github.com/mobiusml/hqq/blob/master/hqq/utils/generation_hf.py

The reason is that it's better to compile only the decode_one_token function rather than the whole forward pass, to avoid an annoying recompilation every time the input prompt shape changes.

I guess I should pass use_cache=False here? https://github.com/mobiusml/hqq/blob/master/hqq/utils/generation_hf.py#L72
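A toy illustration of the recompilation argument, under a simplifying assumption: the compiler specializes on the exact input shape and compiles once per new shape. Real torch.compile is more nuanced (for example, it may mark a dimension as dynamic after seeing it change), and ShapeSpecializingCompiler, fake_forward, and count_compilations are hypothetical names used only for this sketch.

```python
from typing import Callable, List, Set, Tuple

Shape = Tuple[int, int]  # (batch_size, seq_len)


def fake_forward(shape: Shape) -> str:
    """Stand-in for a model forward pass; only the input shape matters here."""
    return f"forward on {shape}"


class ShapeSpecializingCompiler:
    """Toy compiler that 'compiles' the first time it sees each input shape."""

    def __init__(self, fn: Callable[[Shape], str]) -> None:
        self.fn = fn
        self.seen_shapes: Set[Shape] = set()

    def __call__(self, shape: Shape) -> str:
        if shape not in self.seen_shapes:
            self.seen_shapes.add(shape)  # a new shape costs one compilation
        return self.fn(shape)

    @property
    def num_compilations(self) -> int:
        return len(self.seen_shapes)


def count_compilations(prompt_lengths: List[int], new_tokens: int, compile_prefill: bool) -> int:
    compiled = ShapeSpecializingCompiler(fake_forward)
    for prompt_len in prompt_lengths:
        # Prefill: the whole prompt in one call, so seq_len changes per prompt.
        if compile_prefill:
            compiled((1, prompt_len))
        else:
            fake_forward((1, prompt_len))  # run eagerly, never compiled
        # Decode: one token per step, so the shape is always (1, 1).
        for _ in range(new_tokens - 1):
            compiled((1, 1))
    return compiled.num_compilations


print(count_compilations([12, 9, 17], new_tokens=5, compile_prefill=True))   # 4
print(count_compilations([12, 9, 17], new_tokens=5, compile_prefill=False))  # 1
```

With three prompts of different lengths, compiling the prefill costs one compilation per prompt length plus one for decoding, while compiling only the decode step costs a single compilation.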

@gante I get that sometimes as well. I think it's a bit better with the torch nightly build.

@mobicham
Contributor Author

@ArthurZucker use_cache=False is not really a solution: generation is much slower than with use_cache=True.
I was not able to make it work properly by setting use_cache=False directly in the model forward pass either.

@gante that CUDA issue mainly happens when you compile the whole forward pass. Normally you only need to compile the forward pass for the decoding step (the input is a single token, so its shape is fixed), not for the prefill.
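Why use_cache=False costs so much speed: without a KV cache, every decoding step re-runs the model over the entire sequence generated so far instead of only the newest token. A rough count of token positions processed during greedy generation, under a simplified model where every position costs the same (tokens_processed is a hypothetical helper, not a transformers API):

```python
def tokens_processed(prompt_len: int, max_new_tokens: int, use_cache: bool) -> int:
    """Count token positions fed through the model while generating max_new_tokens.

    Step 0 is the prefill, which processes the whole prompt either way. With a
    cache, each later step processes only the newest token; without one, each
    step re-processes the full sequence so far.
    """
    total = 0
    seq_len = prompt_len
    for step in range(max_new_tokens):
        if use_cache and step > 0:
            total += 1        # only the token generated in the previous step
        else:
            total += seq_len  # prefill, or full recompute without a cache
        seq_len += 1          # one token is appended per step
    return total


print(tokens_processed(20, 100, use_cache=True))   # 119
print(tokens_processed(20, 100, use_cache=False))  # 6950
```

The cached path grows linearly with the number of new tokens, while the uncached path grows quadratically, so the gap widens as max_new_tokens increases.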

@ArthurZucker
Collaborator

@mobicham, normally you should not have this issue with the script that compiles decode_one_token. I pushed a fix to main that should have solved this (#30380); the cache was probably not being overwritten.
I think reset_cache might not work as expected.
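One plausible reading of why a reset might not take effect (not confirmed in this thread): mode="reduce-overhead" uses CUDA graphs, which replay kernels against the memory addresses recorded at capture time. If a reset allocates new cache tensors instead of clearing the existing ones in place, the replayed graph keeps reading the old storage. A pure-Python analogy, where ToyStaticCache, capture, and demo are hypothetical names and a list reference stands in for a recorded memory address:

```python
from typing import Callable, List


class ToyStaticCache:
    """Hypothetical preallocated cache: a fixed number of token slots."""

    def __init__(self, max_len: int) -> None:
        self.slots: List[str] = [""] * max_len

    def reset_in_place(self) -> None:
        # Overwrite the existing storage instead of allocating new storage.
        for i in range(len(self.slots)):
            self.slots[i] = ""


def capture(cache: ToyStaticCache) -> Callable[[], List[str]]:
    """Toy 'graph capture': keeps a reference to the storage seen at record time."""
    storage = cache.slots  # analogous to a CUDA graph recording a memory address

    def replay() -> List[str]:
        return [tok for tok in storage if tok]  # non-empty slots are "visible"

    return replay


def demo(reset_in_place: bool) -> List[str]:
    cache = ToyStaticCache(max_len=4)
    cache.slots[0], cache.slots[1] = "essay", "LLMs"  # left over from prompt 1
    replay = capture(cache)
    if reset_in_place:
        cache.reset_in_place()                 # the captured storage is cleared
    else:
        cache.slots = [""] * len(cache.slots)  # "delete and re-create": new storage
    return replay()


print(demo(reset_in_place=False))  # ['essay', 'LLMs']: stale state leaks
print(demo(reset_in_place=True))   # []
```

In this analogy, only the in-place reset is visible to the captured replay, which would also explain why manually deleting and re-creating the cache does not help.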

@mobicham
Contributor Author

mobicham commented Apr 23, 2024

@ArthurZucker thanks! I found a hack: warm up with use_cache=False the very first time you compile, then use use_cache=True for generation. It still needs to warm up again with use_cache=True, but at least the output is correct.

Update: the warm-up with the full torch.compile takes a lot of VRAM. The best option would be to make it work with decode_one_token. I still haven't found a proper way of doing it.

There's another problem: if you compile with, for example, max_new_tokens=100 and then use max_new_tokens=1000 after the warm-up, you get RuntimeError: CUDA error: device-side assert triggered. The trick is to use a larger max_new_tokens at compilation time; generation then works with any value below that.

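import time
import torch

# Assumes `model` and `tokenizer` are set up as in the original reproduction above.
def tokenize_prompt(prompt):
    # Hypothetical helper (not shown in the thread): wrap the prompt in the Llama-2 chat format used above
    return tokenizer([f"<s> [INST] {prompt} [/INST]"], return_tensors="pt").to(model.device)

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
gen_kwargs = dict(do_sample=False, cache_implementation="static", pad_token_id=tokenizer.pad_token_id, temperature=None, top_p=None)

# Warm-up: use_cache=False and a large max_new_tokens
prompt = "Write an essay about large language models."
for _ in range(10):
    gen_out = model.generate(**tokenize_prompt(prompt), max_new_tokens=1000, use_cache=False, **gen_kwargs)

# Timed generation with use_cache=True
prompt = "How do I make a cake"
inputs = tokenize_prompt(prompt)
torch.cuda.synchronize()  # make sure no GPU work is pending before timing
t1 = time.time()
gen_out = model.generate(**inputs, max_new_tokens=100, use_cache=True, **gen_kwargs)
torch.cuda.synchronize()
t2 = time.time()
num_new_tokens = gen_out.shape[1] - inputs["input_ids"].shape[1]  # exclude the prompt tokens
print(num_new_tokens / (t2 - t1), "tokens/sec")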
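A toy model of the max_new_tokens failure described above, assuming the static cache is allocated once with room for the prompt plus max_new_tokens positions and that later calls reuse that allocation. A request needing more positions then indexes past the end of the buffer, which inside a compiled CUDA kernel surfaces as a device-side assert rather than a Python exception. check_fits and fill_static_buffer are hypothetical helpers:

```python
from typing import List, Optional


def check_fits(prompt_len: int, max_new_tokens: int, max_cache_len: int) -> None:
    """Raise before generating if a request needs more slots than were allocated."""
    needed = prompt_len + max_new_tokens  # conservative: one slot per token
    if needed > max_cache_len:
        raise ValueError(f"needs {needed} cache slots, only {max_cache_len} allocated")


def fill_static_buffer(max_cache_len: int, prompt_len: int, max_new_tokens: int) -> List[Optional[int]]:
    """Write one entry per position into a fixed-size list, as a static cache would."""
    buffer: List[Optional[int]] = [None] * max_cache_len
    for pos in range(prompt_len + max_new_tokens):
        buffer[pos] = pos  # raises IndexError once pos >= max_cache_len
    return buffer


# Buffer sized during warm-up for a 12-token prompt and max_new_tokens=100
max_cache_len = 12 + 100
check_fits(12, 100, max_cache_len)   # fine: reuses the warm-up size
try:
    check_fits(12, 1000, max_cache_len)
except ValueError as err:
    print(err)  # needs 1012 cache slots, only 112 allocated
```

Validating the request size before calling the compiled model turns an opaque device-side assert into a readable error.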

@mobicham
Contributor Author

I was not able to test the fix because of another problem with 4.41.0: #30417

@ArthurZucker
Collaborator

Super weird and we'll fix it asap

@ArthurZucker
Collaborator

Might be related to #30414 as well

@mobicham
Contributor Author

mobicham commented Apr 23, 2024

I was finally able to make it work without blowing up the VRAM:

  1. Compile with inputs of size [batch_size, 1]: https://github.com/mobiusml/hqq/blob/master/hqq/utils/generation_hf.py#L57-L72
  2. Warm up with 3 prompts with use_cache=False

With this approach, a 4-bit Llama2-7B uses ~5.6GB of VRAM at runtime with a maximum cache size of 1024.
If I try the same with model.generate(), I run out of VRAM after the 2nd or 3rd warm-up prompt.
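For a sense of how much of that memory the static cache itself takes, its size can be estimated from the model config. A back-of-the-envelope sketch, assuming an fp16 cache, batch size 1, and Llama-2-7B's config (32 layers, 32 attention heads, hidden size 4096, so head_dim 128); kv_cache_bytes is a hypothetical helper:

```python
def kv_cache_bytes(
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    max_cache_len: int,
    batch_size: int = 1,
    bytes_per_elem: int = 2,  # fp16
) -> int:
    """Size of a preallocated key/value cache: one K and one V tensor per layer."""
    per_tensor = batch_size * num_kv_heads * max_cache_len * head_dim * bytes_per_elem
    return num_layers * 2 * per_tensor


# Llama-2-7B config: 32 layers, 32 heads, head_dim = 4096 / 32 = 128
size = kv_cache_bytes(num_layers=32, num_kv_heads=32, head_dim=128, max_cache_len=1024)
print(size / 2**20, "MiB")  # 512.0 MiB
```

The estimate grows linearly with max_cache_len and batch size, so doubling the cache to 2048 positions adds another 512 MiB.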

The only issue is the speed: with the approach above I get 165 tokens/sec, but it should be ~205 tokens/sec.

Update: the speed depends on the size of the initialized cache for some reason.
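One plausible explanation for the cache-size dependence, stated as an assumption: with a static cache, each decode step attends over all max_cache_len slots (unfilled slots are masked out rather than skipped), so per-step attention work scales with the allocated size instead of the tokens generated so far. A rough count per attention head, where attended_positions is a hypothetical helper:

```python
from typing import Optional


def attended_positions(prompt_len: int, num_decode_steps: int, max_cache_len: Optional[int] = None) -> int:
    """Total key positions scored by one attention head over all decode steps.

    max_cache_len=None models a cache that grows with the sequence; an integer
    models a static cache whose unfilled slots are masked but still computed.
    """
    total = 0
    for step in range(num_decode_steps):
        filled = prompt_len + step + 1  # prompt + every generated token fed back so far
        total += filled if max_cache_len is None else max_cache_len
    return total


print(attended_positions(20, 100))        # 7050
print(attended_positions(20, 100, 1024))  # 102400
print(attended_positions(20, 100, 4096))  # 409600
```

Under this model, doubling max_cache_len doubles the attention work per decode step even for short generations. The MLP layers do not depend on the cache size, so the end-to-end slowdown is smaller than the attention ratio suggests.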

Update 2: it is actually not fixed; the outputs still mix in content from previous generations.
I will try the fix as soon as #30417 is fixed.

@ArthurZucker
Collaborator

Wow thanks a lot for all this valuable debugging, would really love to fix this!

@mobicham
Contributor Author

mobicham commented Apr 24, 2024

Thanks @ArthurZucker
I spent the whole day playing with this; the latest version is here. Here's what I noticed so far:

  • For the warm-up, you need to feed it at least 3 different prompts sequentially, i.e. generate(prompt1), generate(prompt2), generate(prompt3). If you use the same prompt instead, the cache gets completely locked to prompt1.
  • Normally, you need to reset the cache before each generation. However, the compiled version crashes if you reset it. When I warm up the compiled function with small 1-token inputs, the output still looks a bit strange, which suggests the cache contains information from previous prompts. Manually deleting and re-creating the cache does not help either.
  • Cache sizes need to be powers of 2, otherwise it crashes with RuntimeError: CUDA error: device-side assert triggered
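Pulling together the cache-size observations in this thread (compile with the largest max_new_tokens you expect, and use a power-of-two cache size), a small helper for choosing the size. The power-of-two requirement is only an empirical observation here, not a documented rule; next_power_of_two and pick_cache_len are hypothetical names, and how the size is passed to the cache depends on the transformers version in use.

```python
def next_power_of_two(n: int) -> int:
    """Smallest power of two that is >= n, for n >= 1."""
    if n < 1:
        raise ValueError("n must be >= 1")
    return 1 << (n - 1).bit_length()


def pick_cache_len(prompt_len: int, max_new_tokens: int) -> int:
    """Room for the longest prompt plus all new tokens, rounded up to a power of two."""
    return next_power_of_two(prompt_len + max_new_tokens)


print(pick_cache_len(20, 1000))  # 1024
print(pick_cache_len(30, 1000))  # 2048
```

Sizing for the longest expected request once, at warm-up, means later requests with smaller values fit in the same buffer.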

@ArthurZucker
Collaborator

BTW, we are going to move forward with #30476

@mobicham
Contributor Author

Thank you for the update!
