
Decode error while inferencing a batch of prompts #340

Closed
SiriusNEO opened this issue Jul 3, 2023 · 22 comments · Fixed by #3500 or #3685
Assignees
Labels
bug Something isn't working

Comments

@SiriusNEO
Contributor

I'm trying to benchmark the performance of vLLM with OPT. I find that when I pass a relatively large batch of prompts to vLLM, it raises a decode error once the sequence length reaches a threshold (which makes the problem look like an OOM).

A minimal reproduction for this issue:

from vllm import LLM, SamplingParams

def make_input(bs):
    return ["Hello!" for _ in range(bs)]

bs = 128
generate_length = 200

# Create a sampling params object.
sampling_params = SamplingParams(
    temperature=0.8, 
    top_p=0.95, 
    max_tokens=generate_length)

# Create an LLM.
llm = LLM(
    model="facebook/opt-125m",
    use_dummy_weights=True,
)
input = make_input(bs)
out = llm.generate(input, sampling_params)

When bs=128, the error happens at roughly the 108th token. The error looks like:

Traceback (most recent call last):
  File "vllm-none-problem-repro.py", line 21, in <module>
    out = llm.generate(input, sampling_params)
  File "/llm-bench/vllm-src/vllm/entrypoints/llm.py", line 127, in generate
    return self._run_engine(use_tqdm)
  File "/llm-bench/vllm-src/vllm/entrypoints/llm.py", line 147, in _run_engine
    step_outputs = self.llm_engine.step()
  File "/llm-bench/vllm-src/vllm/engine/llm_engine.py", line 246, in step
    self._decode_sequences(seq_groups)
  File "/llm-bench/vllm-src/vllm/engine/llm_engine.py", line 263, in _decode_sequences
    new_token, new_output_text = detokenize_incrementally(
  File "/llm-bench/vllm-src/vllm/transformers_utils/tokenizer.py", line 73, in detokenize_incrementally
    output_text = tokenizer.convert_tokens_to_string(output_tokens)
  File "/opt/conda/lib/python3.8/site-packages/transformers/tokenization_utils_fast.py", line 533, in convert_tokens_to_string
    return self.backend_tokenizer.decoder.decode(tokens)
TypeError: argument 'tokens': 'NoneType' object cannot be converted to 'PyString'

If I use a smaller bs, the "threshold" increases (>108). For example, it's around 210 when bs=64. It seems there is a limit on bs * length.

@zhuohan123 zhuohan123 added the bug Something isn't working label Jul 3, 2023
@tju01

tju01 commented Jul 27, 2023

I found that the batch size is only an indirect cause and the problem has nothing to do with OOM or anything similar. For example, if I just change the random seed as follows and keep the sequence length and batch size the same, the bug no longer happens for this specific batch size, but it does happen for a larger one:

# Create an LLM.
llm = LLM(
    model="facebook/opt-125m",
    use_dummy_weights=True,
    seed=2,
)

The reason is that for some models there can be a mismatch between config.vocab_size and len(tokenizer). The model outputs a distribution over token ids in the range vocab_size, but only ids in the range len(tokenizer) should actually be sampled. The remaining ids are just padding; when one of them is sampled and decoded, the result is None instead of a string, and that is what throws the exception.

To fix this, if I change the config.vocab_size (which is 50272 for facebook/opt-125m) in the following line to 50265, i.e. len(tokenizer), then the bug doesn't happen anymore for any seed and batch size.

self.sampler = Sampler(config.vocab_size)

I have also observed this for LLaMA and LLaMA-2: for some models on Hugging Face, vocab_size corresponds to the actual number of tokens that should be sampled, while for others it doesn't. It depends on whether the number of tokens is already a multiple of 16 or whether padding is needed. There may also be models other than OPT and LLaMA where this happens.

from transformers import AutoTokenizer, PretrainedConfig

print(len(AutoTokenizer.from_pretrained('meta-llama/Llama-2-13b-hf'))) # 32000
print(PretrainedConfig.from_pretrained('meta-llama/Llama-2-13b-hf').vocab_size) # 32000

print(len(AutoTokenizer.from_pretrained('NousResearch/Nous-Hermes-Llama2-13b'))) # 32001
print(PretrainedConfig.from_pretrained('NousResearch/Nous-Hermes-Llama2-13b').vocab_size) # 32032

A fix in vLLM could be to obtain the number of tokens from the tokenizer instead of the config.json file.

@SiriusNEO
Contributor Author

@tju01 Thank you! So a larger batch size and sequence length just increase the probability that this error happens.

@xxm1668

xxm1668 commented Dec 14, 2023

How do I replace it?

@Tan-YiFan

@xxm1668 Modifying the vocab_size in ~/.cache/huggingface/hub/<model>/snapshots/<commit>/config.json would help.

Another solution is to modify vLLM itself. For example, if you are using an OPT model, change the vocab_size passed to the Sampler in the __init__ function of OPTForCausalLM: self.sampler = Sampler(config.vocab_size)
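
A rough sketch of the config.json workaround, assuming a locally cached model (the config path is a placeholder; substitute the snapshot directory of your cached model):

import json
from transformers import AutoTokenizer

config_path = "/path/to/.cache/huggingface/hub/<model>/snapshots/<commit>/config.json"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # the same model as the config

with open(config_path) as f:
    config = json.load(f)

config["vocab_size"] = len(tokenizer)  # e.g. 50265 for facebook/opt-125m

with open(config_path, "w") as f:
    json.dump(config, f, indent=2)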

@xxm1668

xxm1668 commented Dec 15, 2023

thx

@yhyu13

yhyu13 commented Dec 23, 2023

Will there be a fix for this bug?

We should essentially compare len(tokenizer) and config.vocab_size and pick the smaller value.
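
A minimal sketch of that check (an illustration, not the actual vLLM patch), using OPT from the original report:

from transformers import AutoConfig, AutoTokenizer

model_id = "facebook/opt-125m"
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Ids beyond len(tokenizer) are padding and cannot be decoded, so clamp to the smaller value.
sampler_vocab_size = min(config.vocab_size, len(tokenizer))
print(sampler_vocab_size)  # 50265 (config.vocab_size is 50272)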

@zhangmiaosen2000

zhangmiaosen2000 commented Dec 30, 2023

  1. I met the same problem. I cannot fix it by changing vocab_size in config.json because I am using CodeLlama, which uses vocab_parallel_embedding, and there is an assertion in it when loading weights. To use vocab parallelism, the model's vocab size also needs to be a multiple of the form 2^n * k.
  2. Hard-coding the value in LLaMAForCausalLM, e.g. self.sampler = Sampler(1111111), does solve the problem.
  3. A simple way to fix this bug: write another attribute whose value is len(tokenizer) (e.g., write tokenizer_vocab_size into the model's config when initializing the LLM), then use self.sampler = Sampler(config.tokenizer_vocab_size) instead; a rough sketch follows below.

Hope this information helps.
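
A hedged sketch of option 3; tokenizer_vocab_size is a made-up attribute name, and facebook/opt-125m stands in for CodeLlama:

from transformers import AutoConfig, AutoTokenizer

model_id = "facebook/opt-125m"  # placeholder; the same idea applies to CodeLlama
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Record the real tokenizer length on the config before constructing the model,
# so a model's __init__ could do: self.sampler = Sampler(config.tokenizer_vocab_size)
config.tokenizer_vocab_size = len(tokenizer)
print(config.tokenizer_vocab_size)  # 50265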

@cassanof
Contributor

cassanof commented Feb 9, 2024

This is still an issue...

@esmeetu
Collaborator

esmeetu commented Mar 21, 2024

Hi all. Since this issue only affects a few models, it would be better to add a warning log about it (#3500).

@andersonbcdefg

A warning log doesn't fix the issue with generation, though... this problem is very common in fine-tuned versions of models with added tokens, such as anything fine-tuned with ChatML prompts, for example NousResearch/Hermes-2-Pro-Mistral-7B. These fine-tuned models are unusable with vLLM because of this bug.

@andersonbcdefg

@simon-mo is there already a way to solve this problem without modifying the vLLM source code? I really would prefer not to have to maintain a fork of vLLM to use Nous-Hermes & other finetuned models...

@simon-mo
Collaborator

@esmeetu any suggestions for a good default way to support fine-tuned models?

@esmeetu esmeetu self-assigned this Mar 26, 2024
@esmeetu esmeetu reopened this Mar 26, 2024
@robertgshaw2-neuralmagic
Collaborator

robertgshaw2-neuralmagic commented Mar 26, 2024

Looking into this

from transformers import AutoConfig, AutoTokenizer

config = AutoConfig.from_pretrained("NousResearch/Hermes-2-Pro-Mistral-7B")
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Hermes-2-Pro-Mistral-7B")

print(config.vocab_size)  # >> 32032
print(len(tokenizer))     # >> 32002

So it seems like 30 tokens that are included in the model are missing from the tokenizer. Not sure how this is possible, but I can see why it would cause issues if the model predicts one of the 30 tokens that are not in the tokenizer's vocabulary.

@robertgshaw2-neuralmagic
Collaborator

robertgshaw2-neuralmagic commented Mar 26, 2024

I think the issue is that vocab_size is expanded to a nice multiple for the GPU during training. These extra tokens are never trained (since there is nothing for them in the dataset), so they are very unlikely to be sampled.

@simon-mo

I think an approach we could take is to expand the tokenizer with more pad tokens in this scenario (see the sketch after this list). This would allow us to:

  • keep vocab_size a nice multiple for the GPUs
  • gracefully handle the case where one of the "fake" tokens is predicted

Thoughts?
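
A rough sketch of that idea (an illustration of the approach, not the vLLM implementation; the placeholder token names are made up):

from transformers import AutoConfig, AutoTokenizer

model_id = "NousResearch/Hermes-2-Pro-Mistral-7B"
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Pad the tokenizer up to the model's vocab_size so that a "fake" token id
# decodes to a harmless placeholder instead of None.
n_missing = config.vocab_size - len(tokenizer)  # 30 for this model
if n_missing > 0:
    tokenizer.add_tokens([f"<extra_pad_{i}>" for i in range(n_missing)])
assert len(tokenizer) == config.vocab_size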

@Blubberblub

Blubberblub commented Mar 27, 2024

Had the same issue with the LeoLM/leo-mistral-hessianai-7b-chat model from Hugging Face. Vocab size: 32128, tokenizer length: 32002.

@youkaichao
Member

I took a look at how transformers deals with this problem. Their idea is simple: if we get a token id beyond the tokenizer's length, the decode step just treats the token as an empty string.

Here is a demo:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
print(len(tokenizer)) # 50265
tokenizer.decode([340]) # a token id that exists, i.e. ' news'
tokenizer.decode([34000000]) # a token id that does not exist, returns empty string ''
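
Building on that behavior, here is a rough sketch of the kind of guard that could be applied during incremental detokenization (an illustration, not the actual patch in #3685):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

def safe_convert(token_ids):
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    # Out-of-range ids come back as None; replace them with empty strings so
    # convert_tokens_to_string no longer raises the TypeError from the traceback above.
    tokens = [t if t is not None else "" for t in tokens]
    return tokenizer.convert_tokens_to_string(tokens)

print(repr(safe_convert([340])))       # ' news'
print(repr(safe_convert([34000000])))  # ''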

@esmeetu
Collaborator

esmeetu commented Mar 28, 2024

@youkaichao I agree with the transformers solution. But I currently can't reproduce this issue using large batch sizes. Could someone provide reproduction code for this? cc @simon-mo @andersonbcdefg

@youkaichao
Member

It's not easy to reproduce, but it can indeed happen, especially when we sample for many steps and draw a token id that is outside the tokenizer's range. Basically an index-out-of-range problem.

@esmeetu
Collaborator

esmeetu commented Mar 28, 2024

@andersonbcdefg @Blubberblub Could you try this PR #3685?

@Toska12138

Benchmarking Yi-34B-Chat with vLLM, I encountered the same error. The decoder needs to check whether the ids are inside the vocab. (See the image below; some ids are missing.)
[screenshot: decoded output in which several token ids are missing from the tokenizer vocabulary]

jikunshang pushed a commit to jikunshang/vllm that referenced this issue Sep 30, 2024
@Zero1002

So @Toska12138, how did you solve this problem for Yi-34B-Chat with vLLM?

@Toska12138

So @Toska12138, how did you solve this problem for Yi-34B-Chat with vLLM?

The naive way is to modify the tokenizer.
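
For example, a rough sketch of that workaround, assuming the Hugging Face tokenizer API (the model id and placeholder token names are only illustrative):

from transformers import AutoConfig, AutoTokenizer

model_id = "01-ai/Yi-34B-Chat"  # assumed model id, for illustration only
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Add placeholder tokens until the tokenizer covers every id the model can emit,
# then save the patched tokenizer and pass its path to vLLM's tokenizer argument.
n_missing = config.vocab_size - len(tokenizer)
if n_missing > 0:
    tokenizer.add_tokens([f"<pad_token_{i}>" for i in range(n_missing)])
tokenizer.save_pretrained("./yi-34b-chat-patched-tokenizer")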
