[Bug]: Running on a single machine with multiple GPUs error #9875

Open · 1 task done
Wiselnn570 opened this issue Oct 31, 2024 · 6 comments
Labels
bug Something isn't working

Comments

Wiselnn570 commented Oct 31, 2024

Your current environment

Name: vllm
Version: 0.6.3.post2.dev171+g890ca360

Model Input Dumps

No response

🐛 Describe the bug

I used the interface from this vLLM repository to load the model and ran eval scripts on VLMEvalKit (https://github.com/open-compass/VLMEvalKit)

torchrun --nproc-per-node=8 run.py --data Video-MME --model Qwen2_VL-M-RoPE-80k

for evaluation, but I got the following error:

RuntimeError: world_size (8) is not equal to tensor_model_parallel_size (1) x pipeline_model_parallel_size (1). 

Could you please advise on how to resolve this?
Here is the interface call:

from vllm import LLM

llm = LLM(
    "/mnt/hwfile/mllm/weixilin/cache/Qwen2-VL-7B-Instruct",
    max_model_len=100000,
    limit_mm_per_prompt={"video": 10},
)

raise RuntimeError(

It seems the error occurs at this assertion, so what change should I make to satisfy it? Thanks.

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
Wiselnn570 added the bug label on Oct 31, 2024
DarkLight1337 (Member) commented Oct 31, 2024

vLLM has its own multiprocessing setup for TP/PP. You should avoid using torchrun with vLLM.
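
For reference, a minimal sketch of that setup (my reading of the comment above, not an official recipe): launch a single Python process and let vLLM shard the model across the 8 GPUs itself via tensor parallelism, so that world_size equals tensor_parallel_size × pipeline_parallel_size.

```python
from vllm import LLM

# Single-process launch (plain `python`, no torchrun); vLLM spawns its own
# workers, so world_size (8) == tensor_parallel_size (8) x pipeline_parallel_size (1).
llm = LLM(
    "/mnt/hwfile/mllm/weixilin/cache/Qwen2-VL-7B-Instruct",
    max_model_len=100000,
    limit_mm_per_prompt={"video": 10},
    tensor_parallel_size=8,
)
```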

DarkLight1337 (Member)

cc @youkaichao

youkaichao (Member)

Yeah, we don't support torchrun, but it would be good to provide some scripts to run multiple vLLM instances behind a proxy server using litellm.
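
As a rough sketch of that idea (ports, GPU count, and model path below are illustrative assumptions; the litellm proxy configuration itself is not shown): start one OpenAI-compatible vLLM server per GPU and put a proxy/load balancer in front of them.

```python
import os
import subprocess

MODEL = "/mnt/hwfile/mllm/weixilin/cache/Qwen2-VL-7B-Instruct"  # path from the issue above

# One data-parallel vLLM server per GPU (hypothetical ports 8000-8007);
# a litellm proxy can then route requests across these endpoints.
procs = []
for gpu in range(8):
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu))
    procs.append(subprocess.Popen(
        ["python", "-m", "vllm.entrypoints.openai.api_server",
         "--model", MODEL, "--port", str(8000 + gpu)],
        env=env,
    ))

for p in procs:
    p.wait()
```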

youkaichao (Member)

@Wiselnn570 contribution welcome!

youkaichao (Member)

Wiselnn570 (Author)

@youkaichao @DarkLight1337 Sure, I'd be glad to contribute to this community when I have time! One more question: recently I ran into an issue while modifying the positional encoding in the mrope_input_positions section of the Qwen2-VL code, and I haven't been able to resolve it. In short, I want to explore the model's performance when extrapolating to a 60k context on the Qwen2-VL 7B model, using video data for testing. I tried replacing this section (`MRotaryEmbedding.get_input_positions()`) with vanilla RoPE (that is, placing image, video, and text tokens all on the main diagonal of the M-RoPE), which pushes the maximum value of mrope_input_positions up to approximately 59k, but it eventually led to the error below.
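
To make the comparison concrete, here is a simplified, made-up sketch of the two position-id schemes (not the actual vLLM code; the helper below is hypothetical):

```python
import torch

def vanilla_rope_positions(seq_len: int) -> torch.Tensor:
    # "Main diagonal": text, image, and video tokens all advance the same
    # position counter on every M-RoPE section, so the maximum position id
    # grows with the sequence length (~59k for a ~60k-token context).
    return torch.arange(seq_len).expand(3, -1)  # shape (3, seq_len)

# By contrast, the original M-RoPE reuses position ids across spatially and
# temporally adjacent video patches, which is why its maximum position stays
# around 300 even for the same ~60k-token input.
```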

../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [8320,0,0], thread: [64,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [8320,0,0], thread: [65,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [8320,0,0], thread: [66,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [8320,0,0], thread: [67,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [8320,0,0], thread: [68,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [8320,0,0], thread: [69,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
...
INFO 11-03 18:58:15 model_runner_base.py:120] Writing input of failed execution to /tmp/err_execute_model_input_20241103-185815.pkl...
WARNING 11-03 18:58:15 model_runner_base.py:143] Failed to pickle inputs of failed execution: CUDA error: device-side assert triggered
WARNING 11-03 18:58:15 model_runner_base.py:143] CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
WARNING 11-03 18:58:15 model_runner_base.py:143] For debugging consider passing CUDA_LAUNCH_BLOCKING=1
WARNING 11-03 18:58:15 model_runner_base.py:143] Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
WARNING 11-03 18:58:15 model_runner_base.py:143] 
RuntimeError: Error in model execution: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

I have already tested the original M-RoPE, which produces correct output with a 60k context, and there the maximum mrope_input_positions value is only around 300. So I am wondering whether the position values are simply too large and exceed the valid index range. How should I modify the code to support vanilla RoPE (or some other 3D positional encoding whose position values are large) for evaluation? Thanks!
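
One plausible explanation (my assumption, not confirmed in this thread): the rotary embedding's cached cos/sin table only covers positions up to the model's max_position_embeddings, so any position id beyond that length indexes out of bounds, which would match the IndexKernel assertions above. A quick sanity check along those lines, with illustrative numbers:

```python
import torch

max_position_embeddings = 32768                # assumed length of the cos/sin cache
positions = torch.arange(60000).expand(3, -1)  # vanilla-RoPE-style ids, max ~59,999

max_pos = int(positions.max())
if max_pos >= max_position_embeddings:
    raise ValueError(
        f"position id {max_pos} exceeds the cos/sin cache length "
        f"{max_position_embeddings}; the cache (or rope scaling / max_model_len) "
        "must cover the largest position id."
    )
```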

P.S. I noticed that the function `_compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, ...)` was called several times before inference on my provided video test data, and I'm wondering whether this might be related.
