[Bugfix] Fix FP8 torch._scaled_mm fallback for torch>2.5 with CUDA<12.4 (vllm-project#10095)

Signed-off-by: mgoin <[email protected]>
Signed-off-by: Loc Huynh <[email protected]>
mgoin authored and JC1DA committed Nov 11, 2024
1 parent f9026cc commit 3b59c00
Showing 1 changed file with 2 additions and 4 deletions.
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -7,8 +7,7 @@
 
 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
-TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() \
-    if current_platform.is_rocm() else None
+TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
 
 
 def cutlass_fp8_supported() -> bool:
@@ -166,8 +165,7 @@ def apply_fp8_linear(
 
             # Making sure the dummy tensor is on the same device as the weight
             global TORCH_DEVICE_IDENTITY
-            if (TORCH_DEVICE_IDENTITY is not None
-                    and TORCH_DEVICE_IDENTITY.device != weight.device):
+            if TORCH_DEVICE_IDENTITY.device != weight.device:
                 TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
 
             # GEMM
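For context: before this change, TORCH_DEVICE_IDENTITY was allocated (with .cuda()) only on ROCm and stayed None elsewhere, so CUDA builds without cutlass FP8 support had no scale tensor to pass once PyTorch 2.5 made scale_a/scale_b mandatory in torch._scaled_mm. The sketch below shows how such a dummy identity scale is typically fed to the unfused fallback GEMM. It is a minimal illustration, not code from this commit; the helper name fallback_scaled_mm and the assumed FP8 input layouts are mine.

import torch

# Dummy per-tensor scale: torch._scaled_mm requires scale_a/scale_b in torch >= 2.5.
# Allocated on CPU so that importing the module does not need a GPU.
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)


def fallback_scaled_mm(qinput: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # qinput: row-major FP8 activations, weight: column-major FP8 weights
    # (assumed layouts; _scaled_mm enforces its own requirements).
    global TORCH_DEVICE_IDENTITY
    # Lazily move the dummy scale to the weight's device, as in the diff above.
    if TORCH_DEVICE_IDENTITY.device != weight.device:
        TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)

    # Run the GEMM with identity scales and return float32 output; the caller
    # is expected to apply the real activation/weight scales afterwards
    # (the unfused dequantization path).
    return torch._scaled_mm(qinput,
                            weight,
                            scale_a=TORCH_DEVICE_IDENTITY,
                            scale_b=TORCH_DEVICE_IDENTITY,
                            out_dtype=torch.float32)

The unconditional CPU allocation plus the lazy .to(weight.device) move keeps module import GPU-free while adding only a single device check per call.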
