[Kernel][Model] logits_soft_cap for Gemma2 with flashinfer #6051
Conversation
We should check the flashinfer version and raise if it's too old.
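A minimal sketch of such a version guard (the distribution name and the minimum version below are assumptions, not the actual requirement):

```python
# Hypothetical version guard; the minimum version is a placeholder, not the real cutoff.
from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version

MIN_FLASHINFER = Version("0.0.8")  # assumed minimum with logits_soft_cap support

try:
    installed = Version(version("flashinfer"))  # assumes the distribution is named "flashinfer"
except PackageNotFoundError:
    raise RuntimeError("flashinfer is not installed.")

if installed < MIN_FLASHINFER:
    raise RuntimeError(
        f"flashinfer {installed} is too old; >= {MIN_FLASHINFER} is required "
        "for logits_soft_cap.")
```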
vllm/worker/model_runner.py
Outdated
logger.warning("Please use Flashinfer backend for models with" | ||
"logits_soft_cap (i.e., Gemma-2)." | ||
" Otherwise, the output might be wrong." | ||
" Set Flashinfer backend by " | ||
"export VLLM_ATTENTION_BACKEND=FLASHINFER.") |
we should just raise an exception IMO.
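A self-contained sketch of what that could look like, pulled out of model_runner.py for clarity (the helper name and message wording are illustrative; the condition mirrors the snippet quoted below):

```python
def assert_flashinfer_for_soft_cap(logits_soft_cap, backend_name: str) -> None:
    """Raise instead of warn when a soft cap is set but FlashInfer is not the backend."""
    if logits_soft_cap is not None and backend_name != "flashinfer":
        raise ValueError(
            "Please use the FlashInfer backend for models with logits_soft_cap "
            "(i.e., Gemma-2); otherwise the output may be wrong. Set it with "
            "export VLLM_ATTENTION_BACKEND=FLASHINFER.")

# Would raise: soft cap configured but a non-FlashInfer backend selected.
# assert_flashinfer_for_soft_cap(30.0, "flash-attn")
```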
Co-authored-by: Simon Mo <[email protected]>
vllm/worker/model_runner.py
Outdated
logits_soft_cap = getattr(self.model_config.hf_config,
                          'final_logit_softcapping', None)
if logits_soft_cap is not None and self.attn_backend.get_name(
) != "flashinfer":
Could I check if logits_soft_cap is supposed to be the attn_logit_softcapping value instead? The two values are different in the Gemma2 config:
"attn_logit_softcapping": 50.0,
"final_logit_softcapping": 30.0,
@yongxb Nice catch! final_logit_softcapping is used to cap the final logits before sampling. @LiuXiaoxuanPKU Could you please fix this?
I think this warning can be removed to avoid confusion: vllm/vllm/model_executor/models/gemma2.py Line 140 in 7cd2ebb
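For context, both values drive the same tanh-style capping, just at different points in the model; a minimal PyTorch sketch using the numbers from the config above:

```python
import torch

def soft_cap(x: torch.Tensor, cap: float) -> torch.Tensor:
    # Gemma-2 style soft capping: squashes values into (-cap, cap).
    return cap * torch.tanh(x / cap)

scores = torch.randn(4, 4) * 100
capped_attn = soft_cap(scores, 50.0)   # attn_logit_softcapping: applied to attention scores
capped_final = soft_cap(scores, 30.0)  # final_logit_softcapping: applied to LM-head logits before sampling
```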
I am able to run and reproduce the reported MMLU scores for both 9b and 27b models 👍 However, if I don't disable CUDA graph, vLLM will crash with this error:
Thanks for reporting! Could you give me a minimal reproducible example, since I can run gemma-2 with flashinfer cudagraph on my end? Thanks!
I am using the
I tried the script and data on H100, and it seems to work. Could you report your environment? Flashinfer only supports GPUs with compute capability greater than 8.0 (https://developer.nvidia.com/cuda-gpus). Not sure if that might be the problem.
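A quick check the reporter could run (standard PyTorch call; 8.0 is the threshold quoted above):

```python
import torch

# Report the compute capability of the current CUDA device.
major, minor = torch.cuda.get_device_capability()
print(f"Compute capability: {major}.{minor}")
if (major, minor) < (8, 0):
    print("Below 8.0, so flashinfer is not expected to work on this GPU.")
```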
I am using H100 with CUDA 12.5. Can you try syncing your branch to the latest? #4412 might be related (it refactors ...)
Yes, it's a merge conflict. Just fixed, please try again. Thanks!
Thanks for the fix. It works now, w/ or w/o CUDA graph.
Add logits_soft_cap for flashinfer, which is needed by the Gemma2 model; also add a simple gemma2 test.
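A minimal usage sketch with this change in place (the model id and prompt are just examples; the environment variable matches the message quoted above):

```python
import os
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # select the FlashInfer backend

from vllm import LLM, SamplingParams

llm = LLM(model="google/gemma-2-9b-it")  # Gemma-2 uses logit soft capping
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```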