diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml new file mode 100644 index 0000000000000..42936fbfbe7d4 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-FP8W8 -b auto -l 1000 -f 5 -t 1 +model_name: "nm-testing/Qwen2-1.5B-Instruct-FP8W8" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.578 + - name: "exact_match,flexible-extract" + value: 0.585 +limit: 1000 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt index 1d1b0ed38671d..109692395acf6 100644 --- a/.buildkite/lm-eval-harness/configs/models-small.txt +++ b/.buildkite/lm-eval-harness/configs/models-small.txt @@ -5,3 +5,4 @@ Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml +Qwen2-1.5B-Instruct-FP8W8.yaml diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index c4d0c9cb981da..39d00bd5733ff 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -10,7 +10,8 @@ W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensorsScheme, CompressedTensorsUnquantized, CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, - CompressedTensorsW8A8Int8, CompressedTensorsWNA16) + CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, + CompressedTensorsWNA16) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( CompressionFormat, QuantizationArgs, QuantizationStrategy, QuantizationType, find_matched_target, is_activation_quantization_format, @@ -100,14 +101,18 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": def get_config_filenames(cls) -> List[str]: return [] - def _check_scheme_supported(self, min_capability: int): + def _check_scheme_supported(self, + min_capability: int, + error: bool = True) -> bool: capability = current_platform.get_device_capability() capability = capability[0] * 10 + capability[1] - if capability < min_capability: + supported = capability >= min_capability + if error and not supported: raise RuntimeError( "Quantization scheme is not supported for ", f"the current GPU. Min capability: {min_capability}. ", f"Current capability: {capability}.") + return supported def _is_static_tensor_w8a8(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: @@ -170,6 +175,29 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel, # All conditions satisfied. return True + def _is_fp8_w8a16(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + # Confirm weights quantized. + if weight_quant is None: + return False + + # Confirm we have floating points. + if weight_quant.type != QuantizationType.FLOAT: + return False + + # Confirm weight scheme is supported. 
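+        # ("Supported" mirrors the weight-side checks in _is_fp8_w8a8;
+        # input_quant is not inspected here, so this path also matches
+        # checkpoints whose activations are left unquantized.)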
+ is_symmetric_weight = weight_quant.symmetric + is_static_weight = not weight_quant.dynamic + is_per_tensor_or_channel_weight = (weight_quant.strategy in [ + QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL + ]) + if not (is_symmetric_weight and is_static_weight + and is_per_tensor_or_channel_weight): + return False + + # All conditions satisfied. + return True + def _is_wNa16_group_channel(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: input_quant_none = input_quant is None @@ -204,9 +232,23 @@ def _get_scheme_from_parts( # Detect If Activation Quantization. if is_activation_quantization_format(self.quant_format): if self._is_fp8_w8a8(weight_quant, input_quant): - return CompressedTensorsW8A8Fp8( + is_fp8_w8a8_supported = self._check_scheme_supported( + CompressedTensorsW8A8Fp8.get_min_capability(), error=False) + if is_fp8_w8a8_supported: + return CompressedTensorsW8A8Fp8( + strategy=weight_quant.strategy, + is_static_input_scheme=(not input_quant.dynamic)) + else: + return CompressedTensorsW8A16Fp8( + strategy=weight_quant.strategy, + is_static_input_scheme=(input_quant + and not input_quant.dynamic)) + + if self._is_fp8_w8a16(weight_quant, input_quant): + return CompressedTensorsW8A16Fp8( strategy=weight_quant.strategy, - is_static_input_scheme=(not input_quant.dynamic)) + is_static_input_scheme=(input_quant + and not input_quant.dynamic)) if self._is_static_tensor_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Int8( @@ -257,11 +299,10 @@ def get_scheme( targets=self.target_scheme_map.keys()) # Find the quant_scheme - scheme = self.target_scheme_map[matched_target] - - return self._get_scheme_from_parts( - weight_quant=scheme["weights"], - input_quant=scheme["input_activations"]) + scheme_dict = self.target_scheme_map[matched_target] + scheme = self._get_scheme_from_parts( + weight_quant=scheme_dict["weights"], + input_quant=scheme_dict["input_activations"]) # Raise error if device does not support the scheme # (e.g. fp8 needs ada lovelace) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index dd94c49827f62..ca9e286ce5b2d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -4,6 +4,7 @@ CompressedTensorsW4A16Sparse24) from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8 +from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8 from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS, CompressedTensorsWNA16) @@ -11,6 +12,7 @@ "CompressedTensorsScheme", "CompressedTensorsUnquantized", "CompressedTensorsWNA16", + "CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24", "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8", diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py index d5f37b47bb87e..b4bab33e1fb1d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py @@ -12,8 +12,9 @@ class CompressedTensorsScheme(ABC): of different quantization schemes supported by CompressedTensors. 
""" + @classmethod @abstractmethod - def get_min_capability(self) -> int: + def get_min_capability(cls) -> int: """ Get minimum device capability. """ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py index 6203f02d25e90..b7ba29ddc9840 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py @@ -18,7 +18,8 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme): in a linear transformation. """ - def get_min_capability(self) -> int: + @classmethod + def get_min_capability(cls) -> int: # volta and up return 70 diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py index eec523d00372c..b8ffb22d7a89d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py @@ -29,7 +29,8 @@ def __init__(self, raise ValueError( "group_size must be given when using strategy group") - def get_min_capability(self) -> int: + @classmethod + def get_min_capability(cls) -> int: # ampere + up return 80 diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py new file mode 100644 index 0000000000000..eeb7c042e1d1f --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py @@ -0,0 +1,105 @@ +from typing import Callable, List, Optional + +import torch + +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + QuantizationStrategy) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise, create_per_channel_scale_param, + create_per_tensor_scale_param) +from vllm.model_executor.utils import set_weight_attrs + +__all__ = ["CompressedTensorsW8A16Fp8"] + +SUPPORTED_STRATEGIES = [ + QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR +] + + +class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): + + def __init__(self, strategy: str, is_static_input_scheme: bool): + self.strategy = strategy + self.is_static_input_scheme = is_static_input_scheme + + @classmethod + def get_min_capability(cls) -> int: + # ampere and up + return 80 + + # W8A8-Fp8 kernels support only per-tensor and per-channel cases. + # So if we have a fused module (QKV, MLP) with per tensor scales, + # we expand each scale to its shard's channels. 
+ def process_weights_after_loading(self, layer) -> None: + if self.strategy == QuantizationStrategy.TENSOR: + ws_channelwise = convert_to_channelwise(layer.weight_scale, + layer.logical_widths) + layer.weight_scale = torch.nn.Parameter(ws_channelwise, + requires_grad=False) + + # Weights must be transposed for marlin + layer.weight = torch.nn.Parameter(layer.weight.t(), + requires_grad=False) + + prepare_fp8_layer_for_marlin(layer, strategy="channel") + + def create_weights(self, layer: torch.nn.Module, input_size: int, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # WEIGHT + weight = torch.nn.Parameter(torch.empty(output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + requires_grad=False) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, { + "input_dim": 1, + "output_dim": 0, + "weight_loader": weight_loader, + }) + + # WEIGHT SCALE + layer_kwargs = {"weight_loader": weight_loader} + if self.strategy == QuantizationStrategy.CHANNEL: + weight_scale = create_per_channel_scale_param( + output_partition_sizes, **layer_kwargs) + elif self.strategy == QuantizationStrategy.TENSOR: + weight_scale = create_per_tensor_scale_param( + output_partition_sizes, **layer_kwargs) + else: + raise ValueError( + f"Unsupported weight strategy={self.strategy}, " + f"supported strategies are {SUPPORTED_STRATEGIES}") + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE (to deal with converted checkpoints) + if self.is_static_input_scheme: + input_scale = create_per_tensor_scale_param( + output_partition_sizes, **layer_kwargs) + layer.register_parameter("input_scale", input_scale) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + return apply_fp8_marlin_linear(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 51156a3bc07af..cc9d71db140c2 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -23,7 +23,8 @@ def __init__(self, strategy: str, is_static_input_scheme: bool): self.is_static_input_scheme = is_static_input_scheme self.cutlass_fp8_supported = cutlass_fp8_supported() - def get_min_capability(self) -> int: + @classmethod + def get_min_capability(cls) -> int: # lovelace and up return 89 @@ -77,19 +78,20 @@ def create_weights(self, layer: torch.nn.Module, }) # WEIGHT SCALE + layer_kwargs = {"weight_loader": weight_loader} if self.strategy == QuantizationStrategy.CHANNEL: weight_scale = create_per_channel_scale_param( - output_partition_sizes, weight_loader=weight_loader) + output_partition_sizes, **layer_kwargs) else: assert self.strategy == QuantizationStrategy.TENSOR weight_scale = 
create_per_tensor_scale_param( - output_partition_sizes, weight_loader=weight_loader) + output_partition_sizes, **layer_kwargs) layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE if self.is_static_input_scheme: input_scale = create_per_tensor_scale_param( - output_partition_sizes, weight_loader=weight_loader) + output_partition_sizes, **layer_kwargs) layer.register_parameter("input_scale", input_scale) def apply_weights(self, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index e81496c89ac7f..3a80863d3abbe 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -19,7 +19,8 @@ def __init__(self, strategy: str, is_static_input_scheme: bool): self.strategy = strategy self.is_static_input_scheme = is_static_input_scheme - def get_min_capability(self) -> int: + @classmethod + def get_min_capability(cls) -> int: # turing and up return 75 @@ -68,19 +69,19 @@ def create_weights(self, layer: torch.nn.Module, # WEIGHT SCALE layer_kwargs = {"weight_loader": weight_loader} if self.strategy == QuantizationStrategy.CHANNEL: - scale = create_per_channel_scale_param(output_partition_sizes, - **layer_kwargs) + weight_scale = create_per_channel_scale_param( + output_partition_sizes, **layer_kwargs) else: assert self.strategy == QuantizationStrategy.TENSOR - scale = create_per_tensor_scale_param(output_partition_sizes, - **layer_kwargs) - layer.register_parameter("weight_scale", scale) + weight_scale = create_per_tensor_scale_param( + output_partition_sizes, **layer_kwargs) + layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE if self.is_static_input_scheme: - scale = create_per_tensor_scale_param(output_partition_sizes, - **layer_kwargs) - layer.register_parameter("input_scale", scale) + input_scale = create_per_tensor_scale_param( + output_partition_sizes, **layer_kwargs) + layer.register_parameter("input_scale", input_scale) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index e4cf0c0b5d95b..996cba315c556 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -42,7 +42,8 @@ def __init__(self, group_size=self.group_size, is_sym=True) - def get_min_capability(self) -> int: + @classmethod + def get_min_capability(cls) -> int: # ampere and up return 80 diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 3a4f2a49a3497..6649b317ca838 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -18,8 +18,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - all_close_1d, apply_fp8_linear, create_per_tensor_scale_param, - cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale) + all_close_1d, 
apply_fp8_linear, convert_to_channelwise, + create_per_tensor_scale_param, cutlass_fp8_supported, + per_tensor_dequantize, requantize_with_max_scale) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import print_warning_once @@ -179,19 +180,29 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.input_scale = None - # If checkpoint is fp8, requantize the separately quantized logical - # weights into a single fp8 weight with a single weight scale. + # If checkpoint is fp8, handle that there are N scales for N + # shards in a fused module else: - # Dequant -> Quant with max scale. - max_w_scale, weight = requantize_with_max_scale( - weight=layer.weight, - weight_scale=layer.weight_scale, - logical_widths=layer.logical_widths, - ) + # If using marlin (w8a16), kernel uses channelwise weights, + # so extend the weight scales to be channelwise. + if self.use_marlin: + weight = layer.weight + weight_scale = convert_to_channelwise(layer.weight_scale, + layer.logical_widths) + + # If using w8a8, torch._scaled_mm needs per tensor, so + # requantize the logical shards as a single weight. + else: + # Dequant -> Quant with max scale so we can run per tensor. + weight_scale, weight = requantize_with_max_scale( + weight=layer.weight, + weight_scale=layer.weight_scale, + logical_widths=layer.logical_widths, + ) # Update layer with new values. layer.weight = Parameter(weight.t(), requires_grad=False) - layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) if self.quant_config.activation_scheme == "static": layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index c878939580f10..5f9d8658a342f 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -46,7 +46,8 @@ def apply_fp8_marlin_linear( return output.reshape(out_shape) -def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None: +def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, + strategy: str = "tensor") -> None: print_warning_once( "Your GPU does not have native support for FP8 computation but " "FP8 quantization is being used. Weight-only FP8 compression will " @@ -74,16 +75,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None: layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) # WEIGHT SCALES - # Currently Marlin doesn't support per-tensor scales, so we - # expand it to channelwise - is_channelwise = (len(layer.weight_scale.shape) > 0 - and layer.weight_scale.shape[0] == part_size_n) - if is_channelwise: - scales = layer.weight_scale - else: - scales = layer.weight_scale.repeat(1, part_size_n) - scales = scales.to(layer.orig_dtype).to(device) - + scales = layer.weight_scale.to(layer.orig_dtype) # Permute scales marlin_scales = marlin_permute_scales(s=scales, size_k=part_size_k,