[Kernel][FP8] Initial support with dynamic per-tensor scaling (vllm-project#4118)

This PR provides initial support for FP8 computation. It is inspired by HuggingFace TGI: huggingface/text-generation-inference#1726

This feature can be enabled with --quantization fp8 or -q fp8 when launching an engine.
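For example, via the offline LLM API (a minimal sketch, not part of this diff; the model name is illustrative and an FP8-capable GPU such as H100 is required):

from vllm import LLM, SamplingParams

# Weights are loaded in FP16/BF16 and quantized to FP8 once loading finishes.
llm = LLM(model="mistralai/Mistral-7B-v0.1", quantization="fp8")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)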

Algorithm:
We still load the model checkpoint in FP16/BF16. After the weights are loaded, Fp8LinearMethod computes the per-tensor scaling factor for the weights and quantizes them accordingly; this scaling factor is then stored for later use. The per-tensor scaling factor for activations, in contrast, is computed dynamically in every forward pass.
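As a tiny numeric illustration of the per-tensor scheme (assumed values, not taken from this PR):

import torch

finfo = torch.finfo(torch.float8_e4m3fn)  # finfo.max == 448.0 for E4M3
absmax = 2.0                              # assumed absmax of a weight tensor
scale = finfo.max / absmax                # 224.0, applied before the cast to FP8
inv_scale = 1.0 / scale                   # ~0.00446, stored and passed to the FP8 matmul
print(scale, inv_scale)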

Initial Results:
Tested so far with Mistral-7B on 1x H100, with prompt length ~5 and decoding length 128:

BF16: 1.47 s
FP8: 1.66 s

I'll try larger models next and look for further performance bottlenecks. In the meantime, you're welcome to try this code.
comaniac authored and alexeykondrat committed May 1, 2024
1 parent 3783064 commit 96f4a02
Showing 7 changed files with 189 additions and 5 deletions.
24 changes: 24 additions & 0 deletions tests/quantization/test_fp8.py
@@ -0,0 +1,24 @@
"""Tests whether FP8 computation is enabled correctly.
Run `pytest tests/quantization/test_fp8.py --forked`.
"""
import pytest
import torch

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]


@pytest.mark.skipif(
    capability < QUANTIZATION_METHODS["fp8"].get_min_capability(),
    reason="FP8 is not supported on this GPU type.")
def test_load_fp16_model(vllm_runner) -> None:
    llm = vllm_runner("facebook/opt-125m", quantization="fp8")

    model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
    fc1 = model.model.decoder.layers[0].fc1
    assert isinstance(fc1.linear_method, Fp8LinearMethod)
    assert fc1.weight.dtype == torch.float8_e4m3fn
9 changes: 5 additions & 4 deletions vllm/entrypoints/llm.py
@@ -42,10 +42,11 @@ class LLM:
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
-           we support "awq", "gptq" and "squeezellm". If None, we first check
-           the `quantization_config` attribute in the model config file. If
-           that is None, we assume the model weights are not quantized and use
-           `dtype` to determine the data type of the weights.
+           we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
+           If None, we first check the `quantization_config` attribute in the
+           model config file. If that is None, we assume the model weights are
+           not quantized and use `dtype` to determine the data type of
+           the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
8 changes: 8 additions & 0 deletions vllm/model_executor/layers/linear.py
@@ -3,6 +3,7 @@

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter

from vllm.distributed import (divide, get_tensor_model_parallel_rank,
@@ -48,6 +49,13 @@ def apply_weights(self,
        Expects create_weights to have been called before on the layer."""
        raise NotImplementedError

    def process_weights_after_loading(self, layer: nn.Module) -> None:
        """Process the weight after loading.
        This can be used, for example, to transpose weights for computation.
        """
        return


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -3,12 +3,14 @@
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import FP8Config
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig

QUANTIZATION_METHODS = {
"awq": AWQConfig,
"fp8": FP8Config,
"gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig,
"marlin": MarlinConfig,
138 changes: 138 additions & 0 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -0,0 +1,138 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


class FP8Config(QuantizationConfig):
    """Config class for FP8."""

    @classmethod
    def get_name(cls) -> str:
        return "fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # TODO: PyTorch 2.3.0+ is required to run FP8 on
        # SM 89 (e.g. Ada) GPUs. Specifically, this PR has to
        # be included: https://github.com/pytorch/pytorch/pull/118881
        return 90

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "FP8Config":
        return cls()

    def get_linear_method(self) -> "Fp8LinearMethod":
        return Fp8LinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []


class Fp8LinearMethod(LinearMethodBase):
    """Linear method for FP8.
    We now support common FP16/BF16 model checkpoints ONLY. The weight
    scaling factor will be initialized after the model weights are loaded.
    Limitations:
    1. Only support per-tensor quantization due to torch._scaled_mm support.
    2. Only support float8_e4m3fn data type due to the limitation of
       torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: FP8Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_size_per_partition: int,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        set_weight_attrs(weight, extra_weight_attrs)

        w_scale = Parameter(
            torch.empty(1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("weight_scaling_factor", w_scale)

    def process_weights_after_loading(self, layer: Module) -> None:
        # Although the linear_method is propagated to all layers,
        # only linear layers invoke "create_weights". So we check
        # whether "weight_scaling_factor" is registered to determine
        # whether the layer is a linear layer that requires quantization.
        if not hasattr(layer, "weight_scaling_factor"):
            return

        qweight, weight_scale = per_tensor_quantize(layer.weight)
        # torch._scaled_mm requires column-major in the second
        # input (weight), so we transpose the quantized weight.
        layer.weight = Parameter(qweight.t(), requires_grad=False)
        layer.weight_scaling_factor.data.copy_(weight_scale)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qinput, x_scale = per_tensor_quantize(x)
        output, _ = torch._scaled_mm(
            qinput,
            layer.weight,
            out_dtype=x.dtype,
            scale_a=x_scale,
            scale_b=layer.weight_scaling_factor,
            bias=bias,
        )
        return output


def per_tensor_quantize(tensor: torch.Tensor) -> tuple[torch.Tensor, float]:
    """Quantize a tensor using per-tensor static scaling factor.
    Args:
        tensor: The input tensor.
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Calculate the scale as dtype max divided by absmax.
    # Since .abs() creates a new tensor, we use aminmax to get
    # the min and max first and then calculate the absmax.
    min_val, max_val = tensor.aminmax()
    amax = min_val.abs().max(max_val.abs())
    scale = finfo.max / amax.clamp(min=1e-12)
    # Scale and clamp the tensor to bring it within the
    # representable range of the float8 data type
    # (as the default cast is unsaturated).
    qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
    # Return both the float8 data and the inverse scale (as float),
    # as both are required as inputs to torch._scaled_mm.
    qweight = qweight.to(torch.float8_e4m3fn)
    scale = scale.float().reciprocal()
    return qweight, scale
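As a quick sanity check of per_tensor_quantize above (a minimal sketch, not part of this commit; it only needs a PyTorch build with float8_e4m3fn support and runs on CPU):

import torch

from vllm.model_executor.layers.quantization.fp8 import per_tensor_quantize

w = torch.randn(16, 32)                      # toy weight tensor
qweight, inv_scale = per_tensor_quantize(w)  # FP8 data plus inverse scale
w_dequant = qweight.to(torch.float32) * inv_scale
print((w - w_dequant).abs().max())           # small, bounded by FP8 rounding error

The dequantized values match the original up to FP8 rounding error, which is the accuracy property the weight path above relies on.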
4 changes: 4 additions & 0 deletions vllm/model_executor/model_loader/loader.py
@@ -228,6 +228,10 @@ def load_model(self, *, model_config: ModelConfig,
                                           model,
                                           "fall_back_to_pt_during_load",
                                           True)), )
            for _, module in model.named_modules():
                linear_method = getattr(module, "linear_method", None)
                if linear_method is not None:
                    linear_method.process_weights_after_loading(module)
        return model.eval()


9 changes: 8 additions & 1 deletion vllm/model_executor/model_loader/weight_utils.py
@@ -134,11 +134,18 @@ def get_quant_config(model_config: ModelConfig,
                                      tqdm_class=DisabledTqdm)
    else:
        hf_folder = model_name_or_path

    possible_config_filenames = quant_cls.get_config_filenames()

    # If the quantization config is not found, use the default config.
    if not possible_config_filenames:
        return quant_cls()

    config_files = glob.glob(os.path.join(hf_folder, "*.json"))

    quant_config_files = [
        f for f in config_files if any(
-           f.endswith(x) for x in quant_cls.get_config_filenames())
+           f.endswith(x) for x in possible_config_filenames)
    ]
    if len(quant_config_files) == 0:
        raise ValueError(
