From ef038f7797855412dd3679cdc043222bb610edd7 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Wed, 6 Nov 2024 19:54:13 -0500
Subject: [PATCH] [Bugfix] Fix FP8 torch._scaled_mm fallback for torch>2.5
 with CUDA<12.4 (#10095)

Signed-off-by: mgoin
Signed-off-by: Mozhou
---
 vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index 445117ac99a34..ec73533126ab6 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -7,8 +7,7 @@
 
 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
-TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() \
-    if current_platform.is_rocm() else None
+TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
 
 
 def cutlass_fp8_supported() -> bool:
@@ -166,8 +165,7 @@ def apply_fp8_linear(
 
         # Making sure the dummy tensor is on the same device as the weight
         global TORCH_DEVICE_IDENTITY
-        if (TORCH_DEVICE_IDENTITY is not None
-                and TORCH_DEVICE_IDENTITY.device != weight.device):
+        if TORCH_DEVICE_IDENTITY.device != weight.device:
             TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
 
         # GEMM
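
Note (not part of the patch): before this change, `TORCH_DEVICE_IDENTITY` was allocated only on ROCm (via `.cuda()` at import time) and was `None` everywhere else, so on CUDA builds older than 12.4, where the cutlass FP8 path is unavailable and the `torch._scaled_mm` fallback is used, the dummy scale could be missing even though `_scaled_mm` requires explicit scale tensors from PyTorch 2.5 onward. The sketch below is a minimal, CPU-runnable illustration of the allocate-on-CPU, move-once pattern the patch adopts; `get_identity_scale` is a hypothetical helper for illustration, not a function in vllm.

```python
import torch

# Mirrors the patched module-level constant: allocated on CPU at import
# time (no .cuda() call, so the module imports cleanly on machines
# without a GPU) and lazily moved to the weight's device on first use.
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)


def get_identity_scale(weight: torch.Tensor) -> torch.Tensor:
    """Return the dummy identity scale on the same device as `weight`.

    Hypothetical helper showing the same move-once pattern the patch
    applies inline in apply_fp8_linear.
    """
    global TORCH_DEVICE_IDENTITY
    if TORCH_DEVICE_IDENTITY.device != weight.device:
        # Move once; subsequent calls with same-device weights are free.
        TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
    return TORCH_DEVICE_IDENTITY


if __name__ == "__main__":
    w = torch.randn(4, 4)  # stand-in for a layer weight tensor
    scale = get_identity_scale(w)
    assert scale.device == w.device and scale.item() == 1.0
```

The resulting one-element tensor is what the fallback passes as the now-mandatory scale argument to `torch._scaled_mm` when the real scaling has already been applied elsewhere: `torch.ones(1)` acts as a multiplicative identity, so the GEMM result is unchanged.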