[Bugfix] Fix FP8 torch._scaled_mm fallback for torch>2.5 with CUDA<12.4 #10095
Conversation
Signed-off-by: mgoin <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
LGTM. Thanks!
TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() \
    if current_platform.is_rocm() else None
Do you know why it was checking for is_rocm here?
The assumption being made was is_rocm == torch 2.5. If we really wanted to, we could check the torch version, but I think we can just assume torch 2.5 as a floor now.
Thanks!
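For reference, the change being discussed is essentially to drop the ROCm-only guard now that torch 2.5 is the floor on both platforms. A minimal sketch (illustrative, not necessarily the exact diff):

```python
# Before: the identity scale was only created on ROCm, leaving None on CUDA.
# TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() \
#     if current_platform.is_rocm() else None

# After: torch 2.5 is assumed everywhere, so always create the dummy scale.
TORCH_DEVICE_IDENTITY = torch.ones(1).cuda()
```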
Signed-off-by: mgoin <[email protected]>
We need to pass a dummy tensor as input_scale now. There was a special case for ROCm in apply_fp8_linear since the ROCm build moved to torch 2.5 first, but we did not update this case to also cover CUDA now that we have upgraded the PyTorch version there as well.
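Concretely, the torch fallback ends up looking something like the sketch below. This is an illustrative simplification, not the exact vLLM code; the function name, shapes, and per-token/per-channel layout are assumptions. The identity tensor is handed to torch._scaled_mm as the scale, and the real scales are applied afterwards:

```python
import torch

# Per-tensor identity scale; torch >= 2.5 no longer accepts None for
# scale_a/scale_b in torch._scaled_mm.
TORCH_DEVICE_IDENTITY = torch.ones(1, device="cuda")


def scaled_mm_fallback(qinput: torch.Tensor,        # FP8 activations, row-major [M, K]
                       weight: torch.Tensor,        # FP8 weights, column-major [K, N]
                       x_scale: torch.Tensor,       # per-token scales [M, 1], float32
                       weight_scale: torch.Tensor,  # per-channel scales [1, N], float32
                       out_dtype: torch.dtype) -> torch.Tensor:
    # Pass the dummy identity tensor as the input scale so _scaled_mm accepts
    # the call, then apply the real scales to the fp32 result.
    output = torch._scaled_mm(qinput,
                              weight,
                              scale_a=TORCH_DEVICE_IDENTITY,
                              scale_b=TORCH_DEVICE_IDENTITY,
                              out_dtype=torch.float32)
    return (output * x_scale * weight_scale).to(out_dtype)
```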
The issue was originally reported in the vLLM Slack.
Easy to replicate by setting `cutlass_fp8_supported` to `False` and running any FP8-dynamic model. With this PR and CUTLASS forced off, I have a successful eval.
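To sanity-check the underlying torch requirement outside vLLM, a minimal torch-only probe can look roughly like this (a sketch assuming an FP8-capable GPU and torch >= 2.5, where _scaled_mm expects scale_a/scale_b and returns a single tensor):

```python
import torch

M, K, N = 16, 32, 16
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)      # row-major activations
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()  # column-major weights
identity = torch.ones(1, device="cuda")                           # the "dummy" per-tensor scale

# On torch >= 2.5 this call fails if the scales are omitted or None,
# which is what the un-updated CUDA fallback ran into.
out = torch._scaled_mm(a, b, scale_a=identity, scale_b=identity,
                       out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([16, 16])
```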