Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix FP8 torch._scaled_mm fallback for torch>2.5 with CUDA<12.4 #10095

Merged
merged 2 commits into from
Nov 7, 2024

Conversation

mgoin
Copy link
Collaborator

@mgoin mgoin commented Nov 6, 2024

We need to pass a dummy tensor as input_scale now. There was a special case for ROCm in apply_fp8_linear since they moved to torch 2.5, but we did not update this case to include CUDA now that we have also upgrade pytorch versions.

Issue was found in the vLLM slack:

[rank0]:   File "/home/ray/anaconda3/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py", line 135, in apply_weights
[rank0]:     return apply_fp8_linear(
[rank0]:            ^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ray/anaconda3/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/w8a8_utils.py", line 176, in apply_fp8_linear
[rank0]:     output = torch._scaled_mm(qinput,
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: TypeError: _scaled_mm(): argument 'scale_a' must be Tensor, not NoneType

Easy to replicate by setting cutlass_fp8_supported to be False and running any FP8-dynamic model like:

vllm serve neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic

With this PR and forcing off cutlass, I have a successful eval:

vllm (pretrained=neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3146|±  |0.0128|
|     |       |strict-match    |     5|exact_match|↑  |0.3146|±  |0.0128|

Copy link

github-actions bot commented Nov 6, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks!

@comaniac comaniac added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 6, 2024
@comaniac comaniac changed the title [Bugfix] Fix FP8 torch._scaled_mm fallback for torch>2.5 [Bugfix] Fix FP8 torch._scaled_mm fallback for torch>2.5 with CUDA 12.4- Nov 6, 2024
@mgoin mgoin changed the title [Bugfix] Fix FP8 torch._scaled_mm fallback for torch>2.5 with CUDA 12.4- [Bugfix] Fix FP8 torch._scaled_mm fallback for torch>2.5 with CUDA<12.4 Nov 6, 2024
Comment on lines -10 to -11
TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() \
if current_platform.is_rocm() else None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know why it was checking for is_rocm here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The assumption being made was is_rocm == torch 2.5. If we really wanted to we could check the torch version, but I think we can just assume torch 2.5 as a floor now

@hanming-lu
Copy link

Thanks!

Signed-off-by: mgoin <[email protected]>
@comaniac comaniac enabled auto-merge (squash) November 7, 2024 00:51
@comaniac comaniac merged commit 4ab3256 into vllm-project:main Nov 7, 2024
50 checks passed
spliii pushed a commit to spliii/vllm that referenced this pull request Nov 7, 2024
JC1DA pushed a commit to JC1DA/vllm that referenced this pull request Nov 11, 2024
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants