[Kernel][FP8] Initial support with dynamic per-tensor scaling #4118

Merged: 16 commits, Apr 20, 2024
24 changes: 24 additions & 0 deletions tests/quantization/test_fp8.py
@@ -0,0 +1,24 @@
"""Tests whether FP8 computation is enabled correctly.

Run `pytest tests/quantization/test_fp8.py --forked`.
"""
import pytest
import torch

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]


@pytest.mark.skipif(
capability < QUANTIZATION_METHODS["fp8"].get_min_capability(),
reason="FP8 is not supported on this GPU type.")
def test_load_fp16_model(vllm_runner) -> None:
llm = vllm_runner("facebook/opt-125m", quantization="fp8")

model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
fc1 = model.model.decoder.layers[0].fc1
assert isinstance(fc1.linear_method, Fp8LinearMethod)
assert fc1.weight.dtype == torch.float8_e4m3fn
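
For reference, the skip condition above encodes the GPU's compute capability as major * 10 + minor and compares it against the method's minimum. A small illustrative check (the GPU examples are mine, not from the PR):

import torch

# (9, 0) on an H100 gives 9 * 10 + 0 = 90, which meets
# QUANTIZATION_METHODS["fp8"].get_min_capability() == 90;
# (8, 0) on an A100 gives 80, so the test above is skipped there.
major, minor = torch.cuda.get_device_capability()
print(major * 10 + minor)
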
9 changes: 5 additions & 4 deletions vllm/entrypoints/llm.py
@@ -42,10 +42,11 @@ class LLM:
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
quantization: The method used to quantize the model weights. Currently,
we support "awq", "gptq" and "squeezellm". If None, we first check
the `quantization_config` attribute in the model config file. If
that is None, we assume the model weights are not quantized and use
`dtype` to determine the data type of the weights.
we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
If None, we first check the `quantization_config` attribute in the
model config file. If that is None, we assume the model weights are
not quantized and use `dtype` to determine the data type of
the weights.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id.
tokenizer_revision: The specific tokenizer version to use. It can be a
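
For reference, a minimal usage sketch of the documented option through this entrypoint (model, prompt, and sampling values are placeholders; per the PR description, any common FP16/BF16 checkpoint should work):

from vllm import LLM, SamplingParams

# The checkpoint is loaded in FP16/BF16 and the linear weights are quantized
# to FP8 with dynamically computed per-tensor scales after loading.
llm = LLM(model="facebook/opt-125m", quantization="fp8")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
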
8 changes: 8 additions & 0 deletions vllm/model_executor/layers/linear.py
@@ -3,6 +3,7 @@

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter

from vllm.distributed import (divide, get_tensor_model_parallel_rank,
@@ -48,6 +49,13 @@ def apply_weights(self,
Expects create_weights to have been called before on the layer."""
raise NotImplementedError

def process_weights_after_loading(self, layer: nn.Module) -> None:
"""Process the weight after loading.

This can be used, for example, to transpose weights for computation.
"""
return


class UnquantizedLinearMethod(LinearMethodBase):
"""Linear method without quantization.
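
The new hook defaults to a no-op, so existing linear methods are unaffected; a quantization method overrides it to post-process weights once after checkpoint loading. A hypothetical illustration (not part of this PR; the real consumer is Fp8LinearMethod below):

from torch import nn

from vllm.model_executor.layers.linear import UnquantizedLinearMethod


class ContiguousLinearMethod(UnquantizedLinearMethod):
    """Hypothetical: make the weight contiguous once after loading."""

    def process_weights_after_loading(self, layer: nn.Module) -> None:
        if hasattr(layer, "weight"):
            layer.weight.data = layer.weight.data.contiguous()
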
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -3,12 +3,14 @@
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import FP8Config
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig

QUANTIZATION_METHODS = {
"awq": AWQConfig,
"fp8": FP8Config,
"gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig,
"marlin": MarlinConfig,
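
With the registry entry added, the new method resolves like the existing ones; a quick illustrative lookup:

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

fp8_cls = QUANTIZATION_METHODS["fp8"]
print(fp8_cls.get_name())            # "fp8"
print(fp8_cls.get_min_capability())  # 90 (SM 90, i.e. Hopper)
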
138 changes: 138 additions & 0 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -0,0 +1,138 @@
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)


class FP8Config(QuantizationConfig):
"""Config class for FP8."""

@classmethod
def get_name(cls) -> str:
return "fp8"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16, torch.half]

@classmethod
def get_min_capability(cls) -> int:
# TODO: PyTorch 2.3.0+ is required to run FP8 on
# SM 89 (e.g. Ada) GPUs. Specifically, this PR has to
# be included: https://github.com/pytorch/pytorch/pull/118881
return 90

@classmethod
def get_config_filenames(cls) -> List[str]:
return []

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "FP8Config":
return cls()

def get_linear_method(self) -> "Fp8LinearMethod":
return Fp8LinearMethod(self)

def get_scaled_act_names(self) -> List[str]:
return []


class Fp8LinearMethod(LinearMethodBase):
"""Linear method for FP8.
Currently, only common FP16/BF16 model checkpoints are supported. The weight
scaling factor is initialized after the model weights are loaded.

Limitations:
1. Only per-tensor quantization is supported, due to torch._scaled_mm.
2. Only the float8_e4m3fn data type is supported, due to a limitation of
torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)

Args:
quant_config: The quantization config.
"""

def __init__(self, quant_config: FP8Config):
self.quant_config = quant_config

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_size_per_partition: int,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
set_weight_attrs(weight, extra_weight_attrs)

w_scale = Parameter(
torch.empty(1, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("weight_scaling_factor", w_scale)

def process_weights_after_loading(self, layer: Module) -> None:
# Although the linear_method is propagated to all layers,
# only linear layers invoke "create_weights". So we check
# whether "weight_scaling_facor" is registered to determine
# whether the layer is a linear layer that requires quantization.
if not hasattr(layer, "weight_scaling_factor"):
return

qweight, weight_scale = per_tensor_quantize(layer.weight)
# torch._scaled_mm requires column-major in the second
# input (weight), so we transpose the quantized weight.
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scaling_factor.data.copy_(weight_scale)

def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qinput, x_scale = per_tensor_quantize(x)
output, _ = torch._scaled_mm(
qinput,
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scaling_factor,
bias=bias,
)
return output


def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Quantize a tensor using a dynamically computed per-tensor scaling factor.

Args:
tensor: The input tensor.

Returns:
The FP8-quantized tensor and the inverse scale (both torch.Tensor).
"""
finfo = torch.finfo(torch.float8_e4m3fn)
# Calculate the scale as dtype max divided by absmax.
# Since .abs() creates a new tensor, we use aminmax to get
# the min and max first and then calculate the absmax.
min_val, max_val = tensor.aminmax()
amax = min_val.abs().max(max_val.abs())
scale = finfo.max / amax.clamp(min=1e-12)
# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
# Return both the float8 data and the inverse scale (as float32),
# as both are required as inputs to torch._scaled_mm.
qweight = qweight.to(torch.float8_e4m3fn)
scale = scale.float().reciprocal()
return qweight, scale
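
A standalone numeric sketch of the scheme implemented above (not vLLM code; assumes a CUDA build of PyTorch with float8_e4m3fn support): quantize with scale = finfo.max / amax, then reconstruct with the returned inverse scale.

import torch

x = torch.randn(4, 8, dtype=torch.float16, device="cuda")
finfo = torch.finfo(torch.float8_e4m3fn)

# Dynamic per-tensor scale: dtype max divided by the tensor's absmax.
amax = x.abs().max()
scale = finfo.max / amax.clamp(min=1e-12)

# Saturating cast to FP8, plus the inverse scale used by torch._scaled_mm.
q = (x * scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
inv_scale = scale.float().reciprocal()

# Approximate reconstruction; the error is bounded by FP8 e4m3 resolution.
x_hat = q.to(torch.float16) * inv_scale
print((x - x_hat).abs().max())
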
4 changes: 4 additions & 0 deletions vllm/model_executor/model_loader/loader.py
@@ -228,6 +228,10 @@ def load_model(self, *, model_config: ModelConfig,
model,
"fall_back_to_pt_during_load",
True)), )
for _, module in model.named_modules():
linear_method = getattr(module, "linear_method", None)
if linear_method is not None:
linear_method.process_weights_after_loading(module)
return model.eval()


9 changes: 8 additions & 1 deletion vllm/model_executor/model_loader/weight_utils.py
@@ -134,11 +134,18 @@ def get_quant_config(model_config: ModelConfig,
tqdm_class=DisabledTqdm)
else:
hf_folder = model_name_or_path

possible_config_filenames = quant_cls.get_config_filenames()

# If the quantization config is not found, use the default config.
if not possible_config_filenames:
return quant_cls()

config_files = glob.glob(os.path.join(hf_folder, "*.json"))

quant_config_files = [
f for f in config_files if any(
f.endswith(x) for x in quant_cls.get_config_filenames())
f.endswith(x) for x in possible_config_filenames)
]
if len(quant_config_files) == 0:
raise ValueError(
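
The net effect for FP8: FP8Config.get_config_filenames() returns an empty list, so get_quant_config() now falls back to a default-constructed config instead of requiring a quantization JSON in the checkpoint directory. A tiny illustrative check:

from vllm.model_executor.layers.quantization.fp8 import FP8Config

assert FP8Config.get_config_filenames() == []
quant_config = FP8Config()  # what the fallback branch above constructs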