From 5fa825008edb39d8bf2bde89511b97584c95b854 Mon Sep 17 00:00:00 2001
From: Dipika
Date: Tue, 3 Sep 2024 15:56:50 +0000
Subject: [PATCH 1/2] fix loading for unfused pathway

---
 vllm/model_executor/layers/linear.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 1163cc727762d..97d5498a52d7d 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -14,8 +14,10 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           PackedColumnParameter,
                                            PackedvLLMParameter,
-                                           PerTensorScaleParameter)
+                                           PerTensorScaleParameter,
+                                           RowvLLMParameter)
 from vllm.model_executor.utils import set_weight_attrs
 
 logger = init_logger(__name__)
@@ -574,8 +576,8 @@ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
             # Special case for Quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
-            if isinstance(param, PackedvLLMParameter
-                          ) and param.packed_dim == param.output_dim:
+            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
+                                  )) and param.packed_dim == param.output_dim:
                 shard_size, shard_offset = \
                     param.adjust_shard_indexes_for_packing(
                         shard_size=shard_size, shard_offset=shard_offset)
@@ -594,9 +596,10 @@ def weight_loader_v2(self,
                 param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                 shard_id=0)
                 return
-            elif type(param) is BasevLLMParameter:
+            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                 param.load_merged_column_weight(loaded_weight=loaded_weight)
                 return
+            # TODO: @dsikka - move to parameter.py
             self._load_fused_module_from_checkpoint(param, loaded_weight)
             return
 
@@ -724,8 +727,8 @@ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
             # Special case for Quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
-            if isinstance(param, PackedvLLMParameter
-                          ) and param.packed_dim == param.output_dim:
+            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
+                                  )) and param.packed_dim == param.output_dim:
                 shard_size, shard_offset = \
                     param.adjust_shard_indexes_for_packing(
                         shard_size=shard_size, shard_offset=shard_offset)
@@ -741,12 +744,12 @@ def weight_loader_v2(self,
                          loaded_shard_id: Optional[str] = None):
         if loaded_shard_id is None:  # special case for certain models
             if isinstance(param, PerTensorScaleParameter):
-                param.load_merged_column_weight(loaded_weight=loaded_weight,
-                                                shard_id=0)
+                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
                 return
-            elif type(param) is BasevLLMParameter:
-                param.load_merged_column_weight(loaded_weight=loaded_weight)
+            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
+                param.load_qkv_weight(loaded_weight=loaded_weight)
                 return
+            # TODO: @dsikka - move to parameter.py
             self._load_fused_module_from_checkpoint(param, loaded_weight)
             return
 
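Note on PATCH 1/2: when a checkpoint stores a merged module (e.g. qkv_proj or gate_up_proj) as a single tensor, weight_loader_v2 is called with loaded_shard_id=None and must split the tensor itself via _load_fused_module_from_checkpoint; this patch widens the type checks so RowvLLMParameter and PackedColumnParameter instances take the same path as their existing counterparts instead of falling through. The sketch below mirrors only the splitting logic with stand-in classes; PackedParam, split_fused and the pack factor of 8 are illustrative assumptions, not vLLM's actual parameter API.

# Toy sketch of the fused-checkpoint splitting that weight_loader_v2 delegates
# to _load_fused_module_from_checkpoint. Class and function names here are
# hypothetical; only the narrow/adjust pattern follows the patch above.
from dataclasses import dataclass
from typing import List, Tuple

import torch


@dataclass
class PackedParam:
    """Stand-in for a packed (quantized) column parameter; not vLLM's class."""
    data: torch.Tensor
    output_dim: int = 0
    packed_dim: int = 0
    pack_factor: int = 8  # e.g. eight 4-bit values per 32-bit word (assumed)

    def adjust_shard_indexes_for_packing(self, shard_size: int,
                                         shard_offset: int) -> Tuple[int, int]:
        # Rescale logical indexes to physical indexes along the packed dim.
        return shard_size // self.pack_factor, shard_offset // self.pack_factor


def split_fused(param: PackedParam, loaded_weight: torch.Tensor,
                output_sizes: List[int]) -> List[torch.Tensor]:
    """Split one fused checkpoint tensor into per-shard views."""
    shards = []
    offset = 0
    for size in output_sizes:
        shard_size, shard_offset = size, offset
        # Packed parameters store several logical elements per physical one,
        # so indexes are rescaled before narrowing (the branch patched above).
        if param.packed_dim == param.output_dim:
            shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                shard_size=shard_size, shard_offset=shard_offset)
        shards.append(
            loaded_weight.narrow(param.output_dim, shard_offset, shard_size))
        offset += size
    return shards

For a packed parameter, each physical element along the packed dimension holds several logical values, which is why shard offset and size have to be rescaled before narrowing; non-packed parameters skip that branch and are narrowed with the logical sizes directly.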
From 18fca34159c0feda3d1332f23c94d4100fa01805 Mon Sep 17 00:00:00 2001
From: Dipika
Date: Wed, 28 Aug 2024 21:47:13 +0000
Subject: [PATCH 2/2] update fbgemm fp8

---
 vllm/model_executor/layers/linear.py          |  2 +-
 .../layers/quantization/fbgemm_fp8.py         | 34 ++++++++++++-------
 .../layers/quantization/utils/w8a8_utils.py   | 27 ---------------
 3 files changed, 22 insertions(+), 41 deletions(-)

diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 97d5498a52d7d..8d7141584df8e 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -26,7 +26,7 @@
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", "AWQLinearMethod",
     "GPTQMarlinLinearMethod", "Fp8LinearMethod", "MarlinLinearMethod",
     "QQQLinearMethod", "GPTQMarlin24LinearMethod",
-    "TPUInt8LinearMethod"
+    "TPUInt8LinearMethod", "FBGEMMFp8LinearMethod"
 ]
 
 
diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
index e7c3859967c71..3ccf1af9eb898 100644
--- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py
+++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -15,8 +15,9 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, create_per_channel_scale_param)
-from vllm.model_executor.utils import set_weight_attrs
+    apply_fp8_linear)
+from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
+                                           ModelWeightParameter)
 from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
@@ -85,6 +86,7 @@ def create_weights(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        weight_loader = extra_weight_attrs.get("weight_loader")
         del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)
 
@@ -95,20 +97,21 @@
         layer.orig_dtype = params_dtype
 
         # WEIGHT
-        weight = Parameter(torch.empty(output_size_per_partition,
-                                       input_size_per_partition,
-                                       dtype=torch.float8_e4m3fn),
-                           requires_grad=False)
+        weight = ModelWeightParameter(data=torch.empty(
+            output_size_per_partition,
+            input_size_per_partition,
+            dtype=torch.float8_e4m3fn),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
         layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {
-            "input_dim": 1,
-            "output_dim": 0,
-            **extra_weight_attrs,
-        })
 
         # WEIGHT SCALE
-        weight_scale = create_per_channel_scale_param(output_partition_sizes,
-                                                      **extra_weight_attrs)
+        weight_scale = ChannelQuantScaleParameter(data=torch.empty(
+            (sum(output_partition_sizes), 1), dtype=torch.float32),
+                                                  output_dim=0,
+                                                  weight_loader=weight_loader)
+        weight_scale[:] = torch.finfo(torch.float32).min
         layer.register_parameter("weight_scale", weight_scale)
 
         # INPUT SCALE UPPER BOUND
@@ -118,6 +121,11 @@
         layer.input_scale_ub = input_scale_ub
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        # required by torch.compile
+        layer.weight_scale = Parameter(layer.weight_scale.data,
+                                       requires_grad=False)
+        layer.weight = Parameter(layer.weight.data, requires_grad=False)
+
         weight = layer.weight
         layer.weight = Parameter(weight.t(), requires_grad=False)
 
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index 6cc1c65ddfa82..a54e3cae73b14 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -1,10 +1,8 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
-from torch.nn import Parameter
 
 from vllm import _custom_ops as ops
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import is_hip
 
@@ -38,31 +36,6 @@ def all_close_1d(x: torch.Tensor) -> bool:
     return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
 
 
-def create_per_tensor_scale_param(
-    output_partition_sizes: List[int],
-    **extra_weight_attrs,
-) -> Parameter:
-    scale = Parameter(torch.empty(len(output_partition_sizes),
-                                  dtype=torch.float32),
-                      requires_grad=False)
-    scale[:] = torch.finfo(torch.float32).min
-    set_weight_attrs(scale, {
-        "needs_scalar_to_array": True,
-        **extra_weight_attrs
-    })
-    return scale
-
-
-def create_per_channel_scale_param(output_partition_sizes: List[int],
-                                   **extra_weight_attrs) -> Parameter:
-    scale = Parameter(torch.empty((sum(output_partition_sizes), 1),
-                                  dtype=torch.float32),
-                      requires_grad=False)
-    scale[:] = torch.finfo(torch.float32).min
-    set_weight_attrs(scale, {"output_dim": 0, **extra_weight_attrs})
-    return scale
-
-
 def convert_to_channelwise(
         weight_scale: torch.Tensor,
         logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
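Note on PATCH 2/2: FBGEMMFp8LinearMethod now builds ModelWeightParameter and ChannelQuantScaleParameter objects that carry their own loading metadata (input_dim, output_dim, weight_loader) instead of plain Parameters annotated through set_weight_attrs, which is why it can be added to the supported-methods list amended in linear.py above and why the two scale-creation helpers removed from w8a8_utils.py are no longer needed. The "required by torch.compile" comment suggests the re-wrapping in process_weights_after_loading is there so that only plain nn.Parameter objects remain after loading. Below is a minimal sketch of that load-then-flatten pattern, assuming a toy MyScaleParam subclass and ToyFp8Linear module that are not part of vLLM.

# Minimal sketch of the two-phase pattern: a custom Parameter subclass carries
# loading metadata, then is re-wrapped as a plain Parameter once loading is
# finished. Names are illustrative, not vLLM's actual API.
import torch
from torch.nn import Module, Parameter


class MyScaleParam(Parameter):
    """Toy per-channel scale that remembers which dim is the output dim."""

    def __new__(cls, data: torch.Tensor, output_dim: int = 0):
        param = super().__new__(cls, data, requires_grad=False)
        param.output_dim = output_dim
        return param


class ToyFp8Linear(Module):
    """Toy fp8 linear layer using the same create/flatten sequence."""

    def __init__(self, out_features: int, in_features: int):
        super().__init__()
        weight = torch.empty(out_features, in_features,
                             dtype=torch.float8_e4m3fn)
        scale = torch.full((out_features, 1), torch.finfo(torch.float32).min)
        self.weight = Parameter(weight, requires_grad=False)
        self.weight_scale = MyScaleParam(scale, output_dim=0)

    def process_weights_after_loading(self) -> None:
        # Strip the custom subclass and transpose the weight once, mirroring
        # the patch: downstream code then only sees standard Parameters.
        self.weight_scale = Parameter(self.weight_scale.data,
                                      requires_grad=False)
        self.weight = Parameter(self.weight.data.t(), requires_grad=False)

A loader can fill weight_scale (and read output_dim) while the subclass metadata is still attached, then call process_weights_after_loading once the checkpoint has been fully consumed.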