
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [8320,0,0], thread: [64,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed. #9965

Closed
Wiselnn570 opened this issue Nov 3, 2024 · 5 comments

Comments


Wiselnn570 commented Nov 3, 2024

Recently, I ran into an issue while modifying the positional encoding in the mrope_input_positions section of the Qwen2-VL code, and I have not been able to resolve it. In short, I want to explore the model's performance when extrapolating to a 60k context on Qwen2-VL 7B, using video data for testing. I replaced the call to MRotaryEmbedding.get_input_positions with vanilla RoPE (that is, placing image, video, and text tokens all on the main diagonal of M-RoPE), which pushed the maximum value in mrope_input_positions up to approximately 59k, and this eventually led to the error below.

../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [8320,0,0], thread: [64,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [8320,0,0], thread: [65,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [8320,0,0], thread: [66,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [8320,0,0], thread: [67,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [8320,0,0], thread: [68,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [8320,0,0], thread: [69,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
...
INFO 11-03 18:58:15 model_runner_base.py:120] Writing input of failed execution to /tmp/err_execute_model_input_20241103-185815.pkl...
WARNING 11-03 18:58:15 model_runner_base.py:143] Failed to pickle inputs of failed execution: CUDA error: device-side assert triggered
WARNING 11-03 18:58:15 model_runner_base.py:143] CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
WARNING 11-03 18:58:15 model_runner_base.py:143] For debugging consider passing CUDA_LAUNCH_BLOCKING=1
WARNING 11-03 18:58:15 model_runner_base.py:143] Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
WARNING 11-03 18:58:15 model_runner_base.py:143] 
RuntimeError: Error in model execution: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

I have already tested the original M-RoPE, which outputs correctly with a 60k context; there the maximum mrope_input_positions value is only around 300. So I suspect the position values are too large and produce an out-of-bounds index. How should I modify things to support vanilla RoPE (or some other 3D positional encoding whose position values are quite large) for evaluation? Thanks!

P.S. I noticed that _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, ...) was called several times before inference ran on my video test data, and I'm wondering whether this might be related.

Environment:
Name: vllm
Version: 0.6.3.post2.dev171+g890ca360

Originally posted by @Wiselnn570 in #9875 (comment)

DarkLight1337 (Member) commented:

cc @WoosukKwon

ywang96 (Member) commented Nov 4, 2024

Also cc @fyabc: if you can provide any insight on this, it would be much appreciated, since M-RoPE is only used for Qwen2-VL!

Wiselnn570 (Author) commented Nov 4, 2024

> Also cc @fyabc: if you can provide any insight on this, it would be much appreciated, since M-RoPE is only used for Qwen2-VL!

@ywang96 Yes, and the nature of this CUDA error makes it hard for me to debug. Specifically, I implemented get_input_vanilla_positions by following get_input_positions, and replaced the call to MRotaryEmbedding.get_input_positions with get_input_vanilla_positions at the call site. The program returns results normally when the number of input frames is small, because the maximum positional value stays low. However, when testing with 48k, 64k, or 80k context lengths, this error occurs.

fyabc (Contributor) commented Nov 6, 2024

@Wiselnn570 Hi, can you add export CUDA_LAUNCH_BLOCKING=1 and provide a detailed traceback of this CUDA error? I suspect the error is raised when reading cos_sin_cache in RotaryEmbedding.

Also, can you check the max_position_embeddings value in your model config? You should increase this value when running long queries.
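(For illustration, one way to inspect this value, assuming the standard Hugging Face config and the public Qwen2-VL-7B-Instruct checkpoint:)

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
# The rotary cos/sin cache holds this many positions, so it must be larger
# than the biggest position your modified encoding produces (~59k here).
print(cfg.max_position_embeddings)
```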

Wiselnn570 (Author) commented:

> @Wiselnn570 Hi, can you add export CUDA_LAUNCH_BLOCKING=1 and provide a detailed traceback of this CUDA error? I suspect the error is raised when reading cos_sin_cache in RotaryEmbedding.
>
> Also, can you check the max_position_embeddings value in your model config? You should increase this value when running long queries.

Thank you; your suggestion was spot-on. Increasing the size of the cos_sin_cache resolved the issue.
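For anyone who lands here with the same assert, a minimal standalone sketch (assumed head_dim and cache size, not vLLM's actual RotaryEmbedding code) of why positions beyond the cos/sin cache length trigger the IndexKernel.cu out-of-bounds assertion:

```python
import torch


def build_cos_sin_cache(head_dim: int, max_positions: int,
                        base: float = 10000.0) -> torch.Tensor:
    """Standard rotary cos/sin table: one row per position up to max_positions."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_positions, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)
    return torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # [max_positions, head_dim]


cache = build_cos_sin_cache(head_dim=128, max_positions=32768)
positions = torch.tensor([59_000])
# cache[positions] would read row 59000 of a 32768-row table; on the GPU this
# surfaces as the "index out of bounds" device-side assert shown above.
# Building the cache with a larger max_positions makes the lookup valid.
```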
