Support loading checkpoints quantized using Autofp8 (#286)
Support loading checkpoints from https://huggingface.co/collections/neuralmagic/fp8-llms-for-vllm-666742ed2b78b7ac8df13127

Skip CUDA checks
Use scaled_fp8_quant instead of _scaled_mm
Fix weights and weight_scale for the Gaudi2 float8_e4m3fn range.
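
The changes below repeat one import-time pattern: when running on HPU, the CUDA FP8 quantization op is replaced by the Gaudi implementation from vllm-hpu-extension, so existing call sites keep working unchanged. A condensed sketch of that pattern (the "from vllm import _custom_ops as ops" import is assumed from the touched modules):

from vllm import _custom_ops as ops
from vllm.platforms import current_platform

if current_platform.is_hpu():
    # Swap in the Gaudi kernel; downstream code keeps calling ops.scaled_fp8_quant.
    from vllm_hpu_extension.ops import scaled_fp8_quant
    ops.scaled_fp8_quant = scaled_fp8_quant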

---------

Co-authored-by: Nir David <[email protected]>
Co-authored-by: Konrad Zawora <[email protected]>
3 people authored Sep 25, 2024
1 parent 45ee586 commit 29fb5ed
Showing 7 changed files with 64 additions and 23 deletions.
3 changes: 2 additions & 1 deletion requirements-hpu.txt
@@ -6,4 +6,5 @@ ray == 2.32.0
triton
pandas
tabulate
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@0a7adab

vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@0a7adab
4 changes: 4 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -13,6 +13,10 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform

if current_platform.is_hpu():
from vllm_hpu_extension.ops import scaled_fp8_quant
ops.scaled_fp8_quant = scaled_fp8_quant

logger = init_logger(__name__)
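
With the op patched at import time, the fused-MoE code needs no HPU-specific branches at its call sites. A hedged usage sketch, assuming the usual (input, scale) signature and dynamic per-tensor scaling of vLLM's scaled_fp8_quant; the tensors are made up:

import torch
from vllm import _custom_ops as ops

hidden_states = torch.randn(4, 128, dtype=torch.bfloat16)  # dummy activations
qinput, x_scale = ops.scaled_fp8_quant(hidden_states, scale=None)  # returns FP8 data and its scale
print(qinput.dtype, x_scale)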


vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -243,8 +243,10 @@ def _get_scheme_from_parts(
# TODO @dsikka: clean-up conditions
if is_activation_quantization_format(self.quant_format):
if self._is_fp8_w8a8(weight_quant, input_quant):
is_fp8_w8a8_supported = self._check_scheme_supported(
CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
is_fp8_w8a8_supported = current_platform.is_hpu() or \
self._check_scheme_supported(
CompressedTensorsW8A8Fp8.get_min_capability(),
error=False)
if is_fp8_w8a8_supported:
return CompressedTensorsW8A8Fp8(
strategy=weight_quant.strategy,
@@ -314,7 +316,8 @@ def get_scheme(

# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability())
if not current_platform.is_hpu():
self._check_scheme_supported(scheme.get_min_capability())

return scheme

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -13,6 +13,7 @@
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform
from vllm.utils import is_hip

__all__ = ["CompressedTensorsW8A8Fp8"]
@@ -23,7 +24,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.cutlass_fp8_supported = cutlass_fp8_supported()
self.cutlass_fp8_supported = not current_platform.is_hpu() and \
cutlass_fp8_supported()

@classmethod
def get_min_capability(cls) -> int:
24 changes: 16 additions & 8 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -28,6 +28,10 @@
from vllm.platforms import current_platform
from vllm.utils import is_hip, print_warning_once

if current_platform.is_hpu():
from vllm_hpu_extension.ops import scaled_fp8_quant
ops.scaled_fp8_quant = scaled_fp8_quant

ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)
@@ -116,14 +120,18 @@ class Fp8LinearMethod(LinearMethodBase):

def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
self.cutlass_fp8_supported = cutlass_fp8_supported()

# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm
if is_hip():
if current_platform.is_cuda_alike():
self.cutlass_fp8_supported = cutlass_fp8_supported()

# For GPUs that lack FP8 hardware support, we can leverage the
# Marlin kernel for fast weight-only FP8 quantization
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm
if is_hip():
self.use_marlin = False
else:
self.cutlass_fp8_supported = False
self.use_marlin = False

def create_weights(
40 changes: 32 additions & 8 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -10,6 +10,11 @@
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None

if current_platform.is_hpu():
import habana_frameworks.torch.utils.experimental as htexp
from vllm_hpu_extension.ops import scaled_fp8_quant
ops.scaled_fp8_quant = scaled_fp8_quant


def cutlass_fp8_supported() -> bool:
# cutlass is not supported on Rocm
@@ -25,7 +30,15 @@ def cutlass_fp8_supported() -> bool:
def per_tensor_dequantize(
tensor: torch.Tensor, inv_scale: Union[float,
torch.Tensor]) -> torch.Tensor:
fake_qweight = tensor.to(torch.float16)
dtype = torch.float16
device = tensor.device
if current_platform.is_hpu():
dtype = torch.bfloat16
if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
# dequantize on CPU to avoid NaN on Gaudi2
tensor = tensor.to('cpu')

fake_qweight = tensor.to(dtype).to(device)
dq_weight = fake_qweight * inv_scale
return dq_weight
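
The helper simply rematerializes the fake-quantized weight in a wider dtype and multiplies by the inverse scale; the Gaudi2 hop to CPU matches the in-code comment about avoiding NaNs when float8_e4m3fn data is handled on that device. A tiny worked example with illustrative values, in plain PyTorch rather than the vLLM helper:

import torch

w_fp8 = torch.tensor([0.5, -1.0]).to(torch.float8_e4m3fn)  # stored checkpoint codes
inv_scale = torch.tensor(2.0)
fake_qweight = w_fp8.to(torch.bfloat16)   # bf16 on HPU, fp16 elsewhere in the helper
print(fake_qweight * inv_scale)           # tensor([ 1., -2.])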

@@ -58,7 +71,10 @@ def requantize_with_max_scale(
logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
# Max scale to be used for requantization.
max_w_scale = weight_scale.max()

if current_platform.is_hpu() and htexp._get_device_type(
) == htexp.synDeviceType.synDeviceGaudi2:
max_w_scale = max_w_scale * (torch.finfo(torch.float8_e4m3fn).max /
torch.finfo(torch.float8_e4m3fnuz).max)
# QKV / MLP is fused in the on disk checkpoint if any of the
# weight scales are still set to the default since we initialize
# N weight scales for N shards but we only load 1 weight scale
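
The Gaudi2 branch above widens the per-tensor weight scale by the ratio of the two FP8 dynamic ranges, so the requantized weights land inside the narrower range used on that device; the ratio itself is exact, the interpretation is inferred from the code:

import torch

fn_max = torch.finfo(torch.float8_e4m3fn).max      # 448.0, range of the checkpoint format
fnuz_max = torch.finfo(torch.float8_e4m3fnuz).max  # 240.0, range used for Gaudi2 here
factor = fn_max / fnuz_max                         # ~1.8667
# Example: a stored weight_scale of 0.01 becomes ~0.0187, shrinking requantized
# codes from up to +/-448 down to at most +/-240.
print(factor, 0.01 * factor)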
@@ -129,12 +145,20 @@ def apply_fp8_linear(

if per_tensor_weights and per_tensor_activations:
# Fused GEMM_DQ
output = torch._scaled_mm(qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)
if current_platform.is_hpu():
# HPU does not support torch._scaled_mm (SW-197036)
output = torch.ops.hpu.fp8_gemm_v2(qinput, False, weight,
False, None, input.dtype,
x_scale, weight_scale, None,
False)
else:
output = torch._scaled_mm(qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)

# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
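
Both branches above compute the same fused dequantize-and-GEMM; the HPU branch passes the scales to the Gaudi fp8_gemm_v2 op because torch._scaled_mm is unavailable there (SW-197036). For reference, the math under the per-tensor-scale assumption, as a plain-PyTorch sketch rather than either kernel:

import torch

def scaled_mm_reference(qinput: torch.Tensor, weight: torch.Tensor,
                        x_scale: torch.Tensor, weight_scale: torch.Tensor,
                        out_dtype: torch.dtype, bias=None) -> torch.Tensor:
    # Dequantize both operands with their per-tensor scales, multiply, then
    # cast back to the activation dtype; the fused kernels do this in one pass.
    out = (qinput.float() * x_scale) @ (weight.float() * weight_scale)
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)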
3 changes: 1 addition & 2 deletions vllm/worker/habana_model_runner.py
@@ -587,8 +587,7 @@ def _set_gc_threshold(self) -> None:

def load_model(self) -> None:
import habana_frameworks.torch.core as htcore
if self.model_config.quantization == 'inc':
htcore.hpu_set_env()
htcore.hpu_set_env()
with HabanaMemoryProfiler() as m:
with HabanaMemoryProfiler() as m_getmodel:
self.model = get_model(model_config=self.model_config,
