Skip to content

Commit

Permalink
[Misc] Update fbgemmfp8 to use vLLMParameters (vllm-project#7972)
Browse files Browse the repository at this point in the history
Co-authored-by: Michael Goin <[email protected]>
  • Loading branch information
dsikka and mgoin authored Sep 4, 2024
1 parent 61f4a93 commit e16fa99
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 41 deletions.
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod", "GPTQLinearMethod"
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod"
]


Expand Down
34 changes: 21 additions & 13 deletions vllm/model_executor/layers/quantization/fbgemm_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, create_per_channel_scale_param)
from vllm.model_executor.utils import set_weight_attrs
apply_fp8_linear)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter)
from vllm.platforms import current_platform

logger = init_logger(__name__)
Expand Down Expand Up @@ -85,6 +86,7 @@ def create_weights(
params_dtype: torch.dtype,
**extra_weight_attrs,
):
weight_loader = extra_weight_attrs.get("weight_loader")
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)

Expand All @@ -95,20 +97,21 @@ def create_weights(
layer.orig_dtype = params_dtype

# WEIGHT
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
requires_grad=False)
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
"input_dim": 1,
"output_dim": 0,
**extra_weight_attrs,
})

# WEIGHT SCALE
weight_scale = create_per_channel_scale_param(output_partition_sizes,
**extra_weight_attrs)
weight_scale = ChannelQuantScaleParameter(data=torch.empty(
(sum(output_partition_sizes), 1), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)

# INPUT SCALE UPPER BOUND
Expand All @@ -118,6 +121,11 @@ def create_weights(
layer.input_scale_ub = input_scale_ub

def process_weights_after_loading(self, layer: Module) -> None:
# required by torch.compile
layer.weight_scale = Parameter(layer.weight_scale.data,
requires_grad=False)
layer.weight = Parameter(layer.weight.data, requires_grad=False)

weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)

Expand Down
27 changes: 0 additions & 27 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import Parameter

from vllm import _custom_ops as ops
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import is_hip

Expand Down Expand Up @@ -38,31 +36,6 @@ def all_close_1d(x: torch.Tensor) -> bool:
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))


def create_per_tensor_scale_param(
output_partition_sizes: List[int],
**extra_weight_attrs,
) -> Parameter:
scale = Parameter(torch.empty(len(output_partition_sizes),
dtype=torch.float32),
requires_grad=False)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {
"needs_scalar_to_array": True,
**extra_weight_attrs
})
return scale


def create_per_channel_scale_param(output_partition_sizes: List[int],
**extra_weight_attrs) -> Parameter:
scale = Parameter(torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
requires_grad=False)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"output_dim": 0, **extra_weight_attrs})
return scale


def convert_to_channelwise(
weight_scale: torch.Tensor,
logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
Expand Down

0 comments on commit e16fa99

Please sign in to comment.