Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support loading checkpoints quantized using Autofp8 #286

Merged
merged 28 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
a6f8dee
Inc on vLLM - Split qk and v calculations
nirda7 Aug 6, 2024
23e931b
Support loading checkpoints quantized using Autofp8
Yantom1 Sep 16, 2024
363de3c
ruff fixes
Yantom1 Sep 16, 2024
e4fc78b
ruff fixes
Yantom1 Sep 16, 2024
d165c6e
isort fixes
Yantom1 Sep 16, 2024
6f0016b
ruff format
Yantom1 Sep 16, 2024
7f587eb
Update habana_model_runner.py
Yantom1 Sep 16, 2024
c204f3f
isort fixes
Yantom1 Sep 16, 2024
2e00486
yapf fixes
Yantom1 Sep 16, 2024
0f40204
revert commit
Yantom1 Sep 17, 2024
cd24505
Merge branch 'habana_main' into yan_autofp8
Yantom1 Sep 17, 2024
343b533
Revert "Inc on vLLM - Split qk and v calculations"
Yantom1 Sep 18, 2024
8657c4c
formnat.sh
Yantom1 Sep 18, 2024
6b485fb
delete ops.py
Yantom1 Sep 18, 2024
2e603ea
fix imports
Yantom1 Sep 18, 2024
a7a036a
isort fix
Yantom1 Sep 18, 2024
2b4a196
update vllm-hpu-extension commit hash
Yantom1 Sep 19, 2024
454acc9
pr fix
Yantom1 Sep 23, 2024
26d8321
Merge branch 'habana_main' into yan_autofp8
Yantom1 Sep 24, 2024
e92abd6
Update fp8.py
Yantom1 Sep 24, 2024
f150851
Update fused_moe.py
Yantom1 Sep 24, 2024
c7dcbbc
Merge branch 'habana_main' into yan_autofp8
Yantom1 Sep 24, 2024
3e8762e
Update compressed_tensors.py
Yantom1 Sep 25, 2024
5726801
Update compressed_tensors_w8a8_fp8.py
Yantom1 Sep 25, 2024
426e8e1
Update llama.py
Yantom1 Sep 25, 2024
4cf34f4
Update compressed_tensors_w8a8_fp8.py
Yantom1 Sep 25, 2024
f58d4c1
Update vllm/model_executor/layers/quantization/fp8.py
Yantom1 Sep 25, 2024
db9affe
Update fp8.py
Yantom1 Sep 25, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion requirements-hpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ ray == 2.32.0
triton
pandas
tabulate
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@0a7adab

vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@0a7adab
4 changes: 4 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform

if current_platform.is_hpu():
from vllm_hpu_extension.ops import scaled_fp8_quant
ops.scaled_fp8_quant = scaled_fp8_quant

logger = init_logger(__name__)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,10 @@ def _get_scheme_from_parts(
# TODO @dsikka: clean-up conditions
if is_activation_quantization_format(self.quant_format):
if self._is_fp8_w8a8(weight_quant, input_quant):
is_fp8_w8a8_supported = self._check_scheme_supported(
CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
is_fp8_w8a8_supported = current_platform.is_hpu() or \
self._check_scheme_supported(
CompressedTensorsW8A8Fp8.get_min_capability(),
error=False)
if is_fp8_w8a8_supported:
return CompressedTensorsW8A8Fp8(
strategy=weight_quant.strategy,
Expand Down Expand Up @@ -314,7 +316,8 @@ def get_scheme(

# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability())
if not current_platform.is_hpu():
self._check_scheme_supported(scheme.get_min_capability())

return scheme

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform
from vllm.utils import is_hip

__all__ = ["CompressedTensorsW8A8Fp8"]
Expand All @@ -23,7 +24,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.cutlass_fp8_supported = cutlass_fp8_supported()
self.cutlass_fp8_supported = not current_platform.is_hpu() and \
cutlass_fp8_supported()

@classmethod
def get_min_capability(cls) -> int:
Expand Down
24 changes: 16 additions & 8 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
from vllm.platforms import current_platform
from vllm.utils import is_hip, print_warning_once

if current_platform.is_hpu():
from vllm_hpu_extension.ops import scaled_fp8_quant
ops.scaled_fp8_quant = scaled_fp8_quant

ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = init_logger(__name__)
Expand Down Expand Up @@ -116,14 +120,18 @@ class Fp8LinearMethod(LinearMethodBase):

def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
self.cutlass_fp8_supported = cutlass_fp8_supported()

# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm
if is_hip():
if torch.cuda.is_available():
Yantom1 marked this conversation as resolved.
Show resolved Hide resolved
self.cutlass_fp8_supported = cutlass_fp8_supported()

# For GPUs that lack FP8 hardware support, we can leverage the
# Marlin kernel for fast weight-only FP8 quantization
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm
if is_hip():
self.use_marlin = False
else:
self.cutlass_fp8_supported = False
self.use_marlin = False

def create_weights(
Expand Down
40 changes: 32 additions & 8 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None

if current_platform.is_hpu():
import habana_frameworks.torch.utils.experimental as htexp
from vllm_hpu_extension.ops import scaled_fp8_quant
ops.scaled_fp8_quant = scaled_fp8_quant


def cutlass_fp8_supported() -> bool:
# cutlass is not supported on Rocm
Expand All @@ -25,7 +30,15 @@ def cutlass_fp8_supported() -> bool:
def per_tensor_dequantize(
tensor: torch.Tensor, inv_scale: Union[float,
torch.Tensor]) -> torch.Tensor:
fake_qweight = tensor.to(torch.float16)
dtype = torch.float16
device = tensor.device
if current_platform.is_hpu():
dtype = torch.bfloat16
if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
#dequant on cpu to avoid nan on gaudi2
tensor = tensor.to('cpu')

fake_qweight = tensor.to(dtype).to(device)
dq_weight = fake_qweight * inv_scale
return dq_weight

Expand Down Expand Up @@ -58,7 +71,10 @@ def requantize_with_max_scale(
logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
# Max scale to be used for requanitzation.
max_w_scale = weight_scale.max()

if current_platform.is_hpu() and htexp._get_device_type(
) == htexp.synDeviceType.synDeviceGaudi2:
max_w_scale = max_w_scale * (torch.finfo(torch.float8_e4m3fn).max /
torch.finfo(torch.float8_e4m3fnuz).max)
# QKV / MLP is fused in the on disk checkpoint if any of the
# weight scales are still set to the default since we initialize
# N weight scales for N shards but we only load 1 weight scale
Expand Down Expand Up @@ -129,12 +145,20 @@ def apply_fp8_linear(

if per_tensor_weights and per_tensor_activations:
# Fused GEMM_DQ
output = torch._scaled_mm(qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)
if current_platform.is_hpu():
#hpu does not support torch._scaled_mm (SW-197036)
output = torch.ops.hpu.fp8_gemm_v2(qinput, False, weight,
False, None, input.dtype,
x_scale, weight_scale, None,
False)
else:
output = torch._scaled_mm(qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)

# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
Expand Down
3 changes: 1 addition & 2 deletions vllm/worker/habana_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,8 +586,7 @@ def _set_gc_threshold(self) -> None:

def load_model(self) -> None:
import habana_frameworks.torch.core as htcore
if self.model_config.quantization == 'inc':
htcore.hpu_set_env()
htcore.hpu_set_env()
with HabanaMemoryProfiler() as m:
with HabanaMemoryProfiler() as m_getmodel:
self.model = get_model(model_config=self.model_config,
Expand Down
Loading