
[Bug]: vllm stuck when using prompt_token_ids and setting prompt_logprobs #5872

Closed
xinyangz opened this issue Jun 26, 2024 · 9 comments · Fixed by #6223
Labels
bug Something isn't working

Comments

@xinyangz

xinyangz commented Jun 26, 2024

Your current environment

PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.29.5
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.10.192-183.736.amzn2.x86_64-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 525.85.12
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             96
On-line CPU(s) list:                0-95
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
CPU family:                         6
Model:                              85
Thread(s) per core:                 2
Core(s) per socket:                 24
Socket(s):                          2
Stepping:                           7
BogoMIPS:                           5999.99
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2
 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni
 pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf
_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap cl
flushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          1.5 MiB (48 instances)
L1i cache:                          1.5 MiB (48 instances)
L2 cache:                           48 MiB (48 instances)
L3 cache:                           71.5 MiB (2 instances)
NUMA node(s):                       2
NUMA node0 CPU(s):                  0-23,48-71
NUMA node1 CPU(s):                  24-47,72-95
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit:        KVM: Mitigation: VMX unsupported
Vulnerability L1tf:                 Mitigation; PTE Inversion
Vulnerability Mds:                  Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown:             Mitigation; PTI
Vulnerability Mmio stale data:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:             Vulnerable
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Vulnerable
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, STIBP disabled, RSB filling
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] torch==2.3.0
[pip3] transformers==4.41.2
[pip3] triton==2.3.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.5.0.post1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    CPU Affinity    NUMA Affinity
GPU0     X      NV12    NV12    NV12    NV12    NV12    NV12    NV12    0-23,48-71      0
GPU1    NV12     X      NV12    NV12    NV12    NV12    NV12    NV12    0-23,48-71      0
GPU2    NV12    NV12     X      NV12    NV12    NV12    NV12    NV12    0-23,48-71      0
GPU3    NV12    NV12    NV12     X      NV12    NV12    NV12    NV12    0-23,48-71      0
GPU4    NV12    NV12    NV12    NV12     X      NV12    NV12    NV12    24-47,72-95     1
GPU5    NV12    NV12    NV12    NV12    NV12     X      NV12    NV12    24-47,72-95     1
GPU6    NV12    NV12    NV12    NV12    NV12    NV12     X      NV12    24-47,72-95     1
GPU7    NV12    NV12    NV12    NV12    NV12    NV12    NV12     X      24-47,72-95     1

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

🐛 Describe the bug

The Issue

When using the LLM class with both prompt_token_ids and prompt_logprobs set, I have found that vLLM sometimes gets stuck. A minimal reproducing example follows:

import numpy as np
from vllm import LLM, SamplingParams

model = "mistral-community/Mistral-7B-v0.2"

MAX_RND_TOKEN_ID = 20000
N_REQUESTS = 512
N_SEQ_LEN = 100

np.random.seed(1)
dummy_requests = np.random.randint(1, MAX_RND_TOKEN_ID, size=(N_REQUESTS, N_SEQ_LEN), dtype=np.int32)
dummy_requests = [e.tolist() for e in dummy_requests]

llm = LLM(model=model, gpu_memory_utilization=0.65, tensor_parallel_size=1)
sampling_params = SamplingParams(temperature=0, max_tokens=1, prompt_logprobs=1)

llm.generate(prompt_token_ids=dummy_requests, sampling_params=sampling_params)

Running with the official Docker image:

docker run --gpus all --shm-size=10g --rm -e CUDA_VISIBLE_DEVICES=0 -v "$(pwd):/app" --entrypoint python3 vllm/vllm-openai:v0.5.0.post1 /app/vllm_reproduce.py

The generation gets stuck. Note that this does not happen every time; with a relatively small N_REQUESTS, the generation sometimes runs just fine.

If we detokenize the ids and pass raw text as input, or if we set prompt_logprobs=None, the code does not get stuck.

The Analysis

I've done some initial analysis. The code appears to get stuck at the detokenization stage of the output_processor.

# ...(more stack traces omitted)...
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 694, in _process_model_outputs
[rank0]:     self.output_processor.process_prompt_logprob(seq_group, outputs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/output_processor/single_step.py", line 65, in process_prompt_logprob
[rank0]:     self.detokenizer.decode_prompt_logprobs_inplace(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/transformers_utils/detokenizer.py", line 60, in decode_prompt_logprobs_inplace
[rank0]:     new_read_offset) = detokenize_incrementally(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/transformers_utils/detokenizer.py", line 289, in detokenize_incrementally
[rank0]:     new_text = tokenizer.convert_tokens_to_string(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_fast.py", line 619, in convert_tokens_to_string
[rank0]:     return self.backend_tokenizer.decoder.decode(tokens)
[rank0]: KeyboardInterrupt

The decode_prompt_logprobs_inplace function (https://github.com/vllm-project/vllm/blob/v0.5.0.post1/vllm/transformers_utils/detokenizer.py#L24-L87) seems suspicious to me.

I checked two things: first, printing len(prev_tokens); second, instrumenting the advancement logic to see whether the condition token_id == all_token_ids[token_position] was ever met.

For examples that run fine, prev_tokens is usually empty and the condition is rarely True. For an example where vLLM got stuck, I saw this:

process_prompt_logprob, seq_group_request_id 81
all_token_ids[token_position] in prompt_logprobs_for_token True
token_position 0 len(prev_tokens) None
all_token_ids[token_position] in prompt_logprobs_for_token False
token_position 1 len(prev_tokens) 1
all_token_ids[token_position] in prompt_logprobs_for_token False
token_position 2 len(prev_tokens) 2
all_token_ids[token_position] in prompt_logprobs_for_token False
token_position 3 len(prev_tokens) 4
all_token_ids[token_position] in prompt_logprobs_for_token False
token_position 4 len(prev_tokens) 8
all_token_ids[token_position] in prompt_logprobs_for_token False
token_position 5 len(prev_tokens) 16
all_token_ids[token_position] in prompt_logprobs_for_token False
token_position 6 len(prev_tokens) 32
all_token_ids[token_position] in prompt_logprobs_for_token False
token_position 7 len(prev_tokens) 64
all_token_ids[token_position] in prompt_logprobs_for_token False
token_position 8 len(prev_tokens) 128
all_token_ids[token_position] in prompt_logprobs_for_token False
token_position 9 len(prev_tokens) 256
all_token_ids[token_position] in prompt_logprobs_for_token False
token_position 10 len(prev_tokens) 512
all_token_ids[token_position] in prompt_logprobs_for_token False
token_position 11 len(prev_tokens) 1024
all_token_ids[token_position] in prompt_logprobs_for_token False
token_position 12 len(prev_tokens) 2048
all_token_ids[token_position] in prompt_logprobs_for_token False
token_position 13 len(prev_tokens) 4096
all_token_ids[token_position] in prompt_logprobs_for_token False
token_position 14 len(prev_tokens) 8192

It seems prev_tokens is growing unexpectedly, doubling at every token position.
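The doubling in the log above suggests the detokenizer state is being re-seeded with its own accumulated contents at every position instead of advancing by one token. A toy simulation of both behaviors (this is illustrative pure Python, not vLLM code; all names here are made up):

```python
# Toy illustration of the growth pattern in the log above: if the
# incremental-detokenization state is appended to itself at every
# position instead of advancing by one token, its length doubles
# each step: 1, 2, 4, 8, 16, ...
def buggy_state_growth(steps):
    prev_tokens = ["tok"]          # stand-in for the detokenizer state
    lengths = [len(prev_tokens)]
    for _ in range(steps):
        # Bug pattern: old state re-appended to itself rather than
        # extended by just the newly decoded token.
        prev_tokens = prev_tokens + list(prev_tokens)
        lengths.append(len(prev_tokens))
    return lengths

def correct_state_growth(steps):
    prev_tokens = ["tok"]
    lengths = [len(prev_tokens)]
    for _ in range(steps):
        prev_tokens.append("tok")  # advance by exactly one token
        lengths.append(len(prev_tokens))
    return lengths
```

With exponential state growth, each call to convert_tokens_to_string has to decode an ever-larger token list, which would explain why generation appears to hang rather than crash.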

Overall, the detokenization logic seems confusing to me, especially the parts where prev_tokens is updated. I am not sure why the problem does not occur when not using prompt_token_ids, or whether the issue has been observed previously. This should be a fairly common use case for evaluations (e.g., in lm-evaluation-harness).

@zifeitong
Contributor

zifeitong commented Jun 26, 2024

Hi,

Can you give #5846 a try? I think it would fix this bug.

Thanks!

@xinyangz
Author

> Hi,
>
> Can you give #5846 a try? I think it would fix this bug.
>
> Thanks!

I can confirm it is fixed! Thank you.

@aurickq
Contributor

aurickq commented Jun 28, 2024

Unfortunately, I'm still seeing this issue using lm_eval after applying #5846, but now it triggers only when submitting more than one prompt at a time. Attached is a minimal reproduction script.

offline_inference.py.txt

@zifeitong
Contributor

> Unfortunately, I'm still seeing this issue using lm_eval after applying #5846, but now it triggers only when submitting more than one prompt at a time. Attached is a minimal reproduction script.
>
> offline_inference.py.txt

Thanks for the reproducer.

I find that I can only trigger the endless loop when both enable_chunked_prefill=True and max_tokens=1 are set (unfortunately, that is also what lm_eval uses).

It appears #5919 and #5846 fix two different bugs.

@Xuekai-Zhu

I have the same problem. I have confirmed that #5846 does not fix this bug!

@Xuekai-Zhu

I found an alternative approach: if you only require logprobs, you can decode the tokens yourself. Simply comment out the following two lines in vllm/engine/output_processor/single_step.py:

# self.detokenizer.decode_prompt_logprobs_inplace(
#     seq_group, prompt_logprobs)

Full function:

def process_prompt_logprob(self, seq_group: SequenceGroup,
                           outputs: List[SequenceGroupOutput]) -> None:
    assert len(outputs) == 1, ("Single step should only has 1 output.")
    output = outputs[0]
    prompt_logprobs = output.prompt_logprobs
    if (prompt_logprobs is not None
            and seq_group.sampling_params.detokenize and self.detokenizer):
        # self.detokenizer.decode_prompt_logprobs_inplace(
        #     seq_group, prompt_logprobs)
        if not seq_group.prompt_logprobs:
            # The first prompt token's logprob is None because it doesn't
            # have tokens that are precedent.
            seq_group.prompt_logprobs = [None]
        seq_group.prompt_logprobs.extend(prompt_logprobs)
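With the in-place detokenization commented out, each returned prompt_logprobs entry keeps token ids as keys, so the caller has to map ids to text themselves. A minimal sketch of that post-processing (this helper is our own, not part of vLLM; id_to_token is a toy table standing in for a real tokenizer's id-to-token mapping):

```python
# Toy id-to-token table; real code would use the model's tokenizer.
id_to_token = {1: "<s>", 450: "The", 11794: "Hello"}

def decode_prompt_logprobs(prompt_logprobs):
    """Replace token-id keys with token strings. The leading None entry
    (the first prompt token has no logprob) is passed through unchanged."""
    decoded = []
    for entry in prompt_logprobs:
        if entry is None:
            decoded.append(None)
        else:
            decoded.append({id_to_token.get(tid, f"<unk:{tid}>"): lp
                            for tid, lp in entry.items()})
    return decoded
```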

@zifeitong
Contributor

zifeitong commented Jul 3, 2024

> I found an alternative approach: if you only require logprobs, you can decode the tokens yourself. Simply comment out the following two lines in vllm/engine/output_processor/single_step.py:
>
>     # self.detokenizer.decode_prompt_logprobs_inplace(
>     #     seq_group, prompt_logprobs)
>
> Full function:
>
>     def process_prompt_logprob(self, seq_group: SequenceGroup,
>                                outputs: List[SequenceGroupOutput]) -> None:
>         assert len(outputs) == 1, ("Single step should only has 1 output.")
>         output = outputs[0]
>         prompt_logprobs = output.prompt_logprobs
>         if (prompt_logprobs is not None
>                 and seq_group.sampling_params.detokenize and self.detokenizer):
>             # self.detokenizer.decode_prompt_logprobs_inplace(
>             #     seq_group, prompt_logprobs)
>             if not seq_group.prompt_logprobs:
>                 # The first prompt token's logprob is None because it doesn't
>                 # have tokens that are precedent.
>                 seq_group.prompt_logprobs = [None]
>             seq_group.prompt_logprobs.extend(prompt_logprobs)

If you don't need the decoded text, just set SamplingParams.detokenize to False.

@Xuekai-Zhu

Xuekai-Zhu commented Jul 3, 2024

@zifeitong But if you set detokenize=False and prompt_logprobs=1 in SamplingParams, you cannot get the prompt logprobs; the returned prompt_logprobs is None.

detokenize=False and prompt_logprobs=1 conflict!

sampling params:

sampling_params = SamplingParams(
    max_tokens=1,
    # min_tokens=args.min_length,
    prompt_logprobs=1,
    detokenize=False,
    # n=1, temperature=0.7, top_p=0.8, repetition_penalty=1.05, top_k=20,
)

output:

prompt_logprobs=None, outputs=[CompletionOutput(index=0, text='', token_ids=[11794], cumulative_logprob=-5.925021171569824, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1719989851.6168363, last_token_time=1719989851.6168363, first_scheduled_time=1719989851.9300385, first_token_time=1719989852.3293676, time_in_queue=0.3132021427154541, finished_time=1719989852.3296342), lora_request=None
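Since this vLLM version silently returns prompt_logprobs=None for the combination above instead of raising, a caller can guard against it with a small pre-flight check before calling generate. This check is our own workaround, not part of vLLM; the parameter names merely mirror the SamplingParams fields discussed here:

```python
# Our own pre-flight guard (not part of vLLM): reject the parameter
# combination that, in this vLLM version, silently yields
# prompt_logprobs=None in the output instead of raising an error.
def check_prompt_logprob_params(detokenize, prompt_logprobs):
    if prompt_logprobs is not None and not detokenize:
        raise ValueError(
            "prompt_logprobs was requested with detokenize=False; in this "
            "vLLM version the output's prompt_logprobs will silently be "
            "None. Set detokenize=True (or drop prompt_logprobs).")
```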

@dddddjcole

> @zifeitong But if you set detokenize=False and prompt_logprobs=1 in SamplingParams, you cannot get the prompt logprobs; the returned prompt_logprobs is None.
>
> detokenize=False and prompt_logprobs=1 conflict!
>
> sampling params:
>
>     sampling_params = SamplingParams(
>         max_tokens=1,
>         # min_tokens=args.min_length,
>         prompt_logprobs=1,
>         detokenize=False,
>         # n=1, temperature=0.7, top_p=0.8, repetition_penalty=1.05, top_k=20,
>     )
>
> output:
>
>     prompt_logprobs=None, outputs=[CompletionOutput(index=0, text='', token_ids=[11794], cumulative_logprob=-5.925021171569824, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1719989851.6168363, last_token_time=1719989851.6168363, first_scheduled_time=1719989851.9300385, first_token_time=1719989852.3293676, time_in_queue=0.3132021427154541, finished_time=1719989852.3296342), lora_request=None

Hello, has this problem been solved yet? I want to run a perplexity evaluation with sampling_kwargs = SamplingParams(temperature=0, prompt_logprobs=0, max_tokens=1), but I find that it keeps stalling partway through, and GPU utilization is quite low.
