Decode error while inferencing a batch of prompts #340
I found that the batch size is only indirectly the reason, and it doesn't have anything to do with OOM or similar things. For example, if I just change the random seed to the following and keep the sequence length and batch size the same, then the bug doesn't happen anymore for this specific batch size, but it will happen for another, larger one:

```python
# Create an LLM.
llm = LLM(
    model="facebook/opt-125m",
    use_dummy_weights=True,
    seed=2,
)
```

The reason is that for some models there can be a mismatch between `config.vocab_size` and `len(tokenizer)`. The model outputs a distribution over tokens in the range `[0, config.vocab_size)`, but the tokenizer can only decode the first `len(tokenizer)` token ids, so decoding fails once a sampled id falls outside that range. To fix this, I can change the vocabulary size used at `vllm/vllm/model_executor/models/opt.py` line 276 (commit 58a072b), and the error no longer occurs.
I have also observed this for LLaMA & LLaMA-2, where it seems that for some models on Hugging Face the `config.vocab_size` does not match `len(tokenizer)`:

```python
from transformers import AutoTokenizer, PretrainedConfig

print(len(AutoTokenizer.from_pretrained('meta-llama/Llama-2-13b-hf')))                     # 32000
print(PretrainedConfig.from_pretrained('meta-llama/Llama-2-13b-hf').vocab_size)            # 32000
print(len(AutoTokenizer.from_pretrained('NousResearch/Nous-Hermes-Llama2-13b')))           # 32001
print(PretrainedConfig.from_pretrained('NousResearch/Nous-Hermes-Llama2-13b').vocab_size)  # 32032
```

A fix in vLLM could be to obtain the number of tokens from the tokenizer instead of the `config.vocab_size`. |
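To make the "obtain the number of tokens from the tokenizer" idea concrete, here is a minimal sketch (not vLLM's actual code; the helper name is made up) that masks logits for token ids the tokenizer cannot decode, so they can never be sampled:

```python
import torch
from transformers import AutoTokenizer

def mask_undecodable_logits(logits: torch.Tensor, tokenizer) -> torch.Tensor:
    """Set logits of token ids >= len(tokenizer) to -inf so they are never sampled."""
    decodable = len(tokenizer)
    if logits.shape[-1] > decodable:
        logits = logits.clone()
        logits[..., decodable:] = float("-inf")
    return logits

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Nous-Hermes-Llama2-13b")
fake_logits = torch.randn(1, 32032)  # config.vocab_size for this model
masked = mask_undecodable_logits(fake_logits, tokenizer)
assert torch.isinf(masked[0, len(tokenizer):]).all()  # out-of-vocab ids can no longer win
```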
@tju01 Thank you! Now it looks like larger batch sizes and sequence lengths just increase the probability that this error happens. |
How do I replace it? |
@xxm1668 Modifying the tokenizer is one way. Another solution would be modifying vLLM itself; for example, if you are using an OPT model, you should modify the corresponding line in `vllm/model_executor/models/opt.py` mentioned above. |
thx |
Would there be a fix for this bug? We should essentially compare `len(tokenizer)` and `config.vocab_size` and pick the smaller value. |
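A minimal sketch of that "pick the smaller value" idea (the helper name is made up, not an existing vLLM function):

```python
from transformers import AutoConfig, AutoTokenizer

def safe_vocab_size(model_name: str) -> int:
    """Largest vocabulary size that both the model config and the tokenizer agree on."""
    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return min(config.vocab_size, len(tokenizer))

print(safe_vocab_size("NousResearch/Nous-Hermes-Llama2-13b"))  # 32001, not 32032
```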
Hope this information helps. |
This is still an issue... |
Hi all. Since this issue only affects a few models, it would be better to add a warning log about it (#3500). |
A warning log doesn't fix the issue with generation, though... this problem is very common in fine-tuned versions of models with added tokens, such as anything fine-tuned with ChatML prompts, for example NousResearch/Hermes-2-Pro-Mistral-7B. These fine-tuned models are unusable with vLLM because of this bug. |
@simon-mo is there already a way to solve this problem without modifying the vLLM source code? I really would prefer not to have to maintain a fork of vLLM to use Nous-Hermes & other finetuned models... |
@esmeetu any suggestion here to have a good default way to support fine-tuned model? |
Looking into this:

```python
from transformers import AutoConfig, AutoTokenizer

config = AutoConfig.from_pretrained("NousResearch/Hermes-2-Pro-Mistral-7B")
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Hermes-2-Pro-Mistral-7B")

print(config.vocab_size)  # >> 32032
print(len(tokenizer))     # >> 32002
```

So it seems like 30 tokens are missing from the tokenizer that are included in the model. Not sure how this is possible, but I can see why this would cause issues if the model predicts one of the 30 tokens that are not in the vocabulary of the tokenizer. |
I think the issue is that the model's vocabulary (`config.vocab_size`) is larger than the tokenizer's. I think an approach we could take is to expand the tokenizer to have more pad tokens in this scenario. This will allow every token id the model can emit to be decoded.

Thoughts? |
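A rough sketch of the "expand the tokenizer" approach, done with plain transformers on the client side (the placeholder token names are made up; vLLM would need to do something equivalent internally):

```python
from transformers import AutoConfig, AutoTokenizer

model_name = "NousResearch/Hermes-2-Pro-Mistral-7B"
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

missing = config.vocab_size - len(tokenizer)  # 30 for this model, per the numbers above
if missing > 0:
    # Register dummy tokens so every id the model can emit is decodable.
    tokenizer.add_tokens([f"<extra_pad_token_{i}>" for i in range(missing)])

assert len(tokenizer) == config.vocab_size
```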
Had the same issue with LeoLM/leo-mistral-hessianai-7b-chat model from huggingface. Vocab size: 32128, tokenizer len: 32002. |
I took a look at how `transformers` handles token ids that don't exist in the tokenizer. Here is a demo:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
print(len(tokenizer))         # 50265
tokenizer.decode([340])       # a token id that exists, i.e. ' news'
tokenizer.decode([34000000])  # a token id that does not exist, returns empty string ''
```
|
@youkaichao I agree with the transformers solution. But now I can't reproduce this issue using large batch sizes. Could someone provide the reproduction code for this? cc @simon-mo @andersonbcdefg |
That's not easy to reproduce, but it can indeed happen, especially when we sample for many steps and end up with a token id that's outside the scope of the tokenizer. Basically an index-out-of-range problem. |
@andersonbcdefg @Blubberblub Could you try this PR #3685? |
you know the drill
So @Toska12138, how did you solve this problem with Yi-34B-Chat in vLLM? |
The naive way is to modify the tokenizer |
I'm trying to benchmark the performance of vLLM on OPT, but I find that when I pass a relatively large batch of prompts to vLLM, it raises a decode error once the sequence length reaches a threshold (which makes the problem look like an OOM).
A minimal reproduction for this issue:
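(A reconstruction along the lines described in this thread, using dummy OPT-125m weights, 128 identical prompts, and a few hundred generated tokens; the exact prompts and sampling settings here are assumptions, not the original script.)

```python
from vllm import LLM, SamplingParams

# Dummy weights are enough to trigger the issue, per the discussion above.
llm = LLM(model="facebook/opt-125m", use_dummy_weights=True)

prompts = ["Hello, my name is"] * 128  # bs=128
sampling_params = SamplingParams(temperature=1.0, max_tokens=512)

# With bs=128 the decode error appears around the ~108th generated token.
outputs = llm.generate(prompts, sampling_params)
```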
When `bs=128`, the error happens at approximately the 108th token. If I use a smaller bs, the "threshold" also increases (>108); for example, it's around 210 when `bs=64`. It seems that there is a limit on `bs * length`.