diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index 6eb7ff72fb11d..d5472f97a1c50 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -11,14 +11,18 @@
     CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
     CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor,
     CompressedTensorsWNA16)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    QuantizationType)


 @pytest.mark.parametrize("model_args", [
-    ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor"),
-    ("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel"),
+    ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor",
+     QuantizationType.INT, 2560),
+    ("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel",
+     QuantizationType.INT, 2560),
 ])
 def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
-    model_path, strategy = model_args
+    model_path, strategy, quant_type, shape_0 = model_args
     with vllm_runner(model_path, enforce_eager=True) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
@@ -34,17 +38,23 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
                           CompressedTensorsLinearMethod)
         assert isinstance(down_proj.quant_method,
                           CompressedTensorsLinearMethod)
-        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor)

         assert qkv_proj.scheme.strategy == strategy

-        assert qkv_proj.weight.dtype is torch.int8
-        assert o_proj.weight.dtype is torch.int8
-        assert gate_up_proj.weight.dtype is torch.int8
+        expected_type = (torch.int8 if quant_type == QuantizationType.INT else
+                         torch.float8_e4m3fn)
+
+        assert qkv_proj.weight.dtype is expected_type
+        assert o_proj.weight.dtype is expected_type
+        assert gate_up_proj.weight.dtype is expected_type

         if qkv_proj.scheme.strategy == "tensor":
-            assert qkv_proj.weight_scale.shard_splitter is not None
-            assert qkv_proj.weight_scale.logical_widths is not None
+            # Make sure it is a channelwise buffer
+            # After running process_weights_after_loading
+            assert len(qkv_proj.weight_scale.shape) == 2
+            assert qkv_proj.weight_scale.shape[0] == shape_0
+            assert qkv_proj.weight_scale.shape[1] == 1
+            assert qkv_proj.weight_scale.dtype is torch.float32
         assert qkv_proj.input_scale.dtype is torch.float32
diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py
index 74d21ead042ed..4d76ae7072f3d 100644
--- a/tests/quantization/test_fp8.py
+++ b/tests/quantization/test_fp8.py
@@ -9,6 +9,23 @@
 from vllm._custom_ops import scaled_fp8_quant
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod

+MODELS = [
+    "neuralmagic/Meta-Llama-3-8B-Instruct-FP8",
+    "nm-testing/Phi-3-mini-128k-instruct-FP8",
+]
+
+
+@pytest.mark.skipif(not is_quant_method_supported("fp8"),
+                    reason="FP8 is not supported on this GPU type.")
+@pytest.mark.parametrize("model", MODELS)
+def test_model_load_and_run(vllm_runner, model: str):
+    with vllm_runner(model) as llm:
+        # note: this does not test accuracy, just that we can run through
+        # see lm-eval tests for accuracy
+        outputs = llm.generate_greedy(prompts=["Hello my name is"],
+                                      max_tokens=10)
+        print(outputs[0][1])
+

 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
                     reason="FP8 is not supported on this GPU type.")
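
Aside (not part of the patch): the new `test_model_load_and_run` goes through the repo's `vllm_runner` fixture. A rough standalone equivalent using vLLM's public `LLM` API is sketched below; it assumes an FP8-capable GPU and enough memory for the 8B checkpoint, and only checks that the quantized model loads and generates, just like the test.

```python
# Rough standalone sketch of the new FP8 smoke test (illustrative only).
# Greedy decoding of a few tokens checks load-and-run, not accuracy; accuracy
# is covered by the lm-eval tests mentioned in the test comment above.
from vllm import LLM, SamplingParams

llm = LLM(model="neuralmagic/Meta-Llama-3-8B-Instruct-FP8")
sampling = SamplingParams(temperature=0.0, max_tokens=10)  # greedy, 10 tokens

outputs = llm.generate(["Hello my name is"], sampling)
print(outputs[0].outputs[0].text)
```
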
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index d221fecd66ff1..3cc257834033a 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -41,6 +41,29 @@ def adjust_bitsandbytes_shard(param: Parameter,
     return quantized_size, quantized_offset


+def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
+    """For fused modules (QKV and MLP) we have an array of length
+    N that holds 1 scale for each "logical" matrix. So the param
+    is an array of length N. The loaded_weight corresponds to
+    one of the shards on disk. Here, we slice the param based on
+    the shard_id for loading.
+    """
+    qkv_idxs = {"q": 0, "k": 1, "v": 2}
+
+    if isinstance(shard_id, str):
+        shard_id = qkv_idxs[shard_id]
+    elif not isinstance(shard_id, int):
+        raise ValueError(f"Unknown Shard Id {shard_id}")
+
+    # AutoFP8 scales do not have a shape
+    # compressed-tensors scales do have a shape
+    if len(loaded_weight.shape) != 0:
+        assert loaded_weight.shape[0] == 1
+        loaded_weight = loaded_weight[0]
+
+    return param[shard_id], loaded_weight
+
+
 class LinearMethodBase(QuantizeMethodBase):
     """Base class for different (maybe quantized) linear methods."""
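
To make the new helper concrete, here is a toy, standalone illustration (not the loader's actual call path) of how a fused QKV scale array is filled: the param holds one slot per logical matrix, and each shard's scalar scale from disk lands in the slot selected by `shard_id`.

```python
# Toy illustration of adjust_scalar_to_fused_array (added in the hunk above).
# The fused param has one scale slot per logical matrix (q, k, v); each shard
# on disk contributes a single scalar that is routed into its slot.
import torch

from vllm.model_executor.layers.linear import adjust_scalar_to_fused_array

param = torch.zeros(3)  # scale slots for q, k, v

# compressed-tensors checkpoints store the per-tensor scale with shape (1,)
slot, value = adjust_scalar_to_fused_array(param, torch.tensor([0.25]), "k")
slot.copy_(value)

# AutoFP8 checkpoints store a 0-dim scalar
slot, value = adjust_scalar_to_fused_array(param, torch.tensor(0.5), "v")
slot.copy_(value)

print(param)  # tensor([0.0000, 0.2500, 0.5000])
```
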
@@ -358,37 +381,15 @@ def weight_loader(self,
         output_dim = getattr(param, "output_dim", None)
         # Special case for AQLM codebooks.
         is_metadata = getattr(param, "is_metadata", False)
-
-        param_shard_splitter = getattr(param, "shard_splitter", None)
-
-        if output_dim is not None and param_shard_splitter is not None:
-            raise NotImplementedError(
-                "We do not currently support output_dim != None and "
-                "shard_splitter != None for a parameter. Please open an issue."
-            )
-        # If a parameter has defined a shard_splitter to be used for
-        # the weight, it should be applied before the weight is
-        # loaded/copied to the parameter. The shard_splitter applies
-        # logic by using the loaded_shard_id to ensure that the loaded
-        # param is loaded to the correct location
-        # within the parameter defined by the linear method.
-        if loaded_shard_id is None and param_shard_splitter is not None:
-            raise NotImplementedError(
-                "We do not currently support loaded_shard_id == None and "
-                "shard_splitter != None for a parameter. Please open an issue."
-            )
-
-        # Special case for Fp8 scales.
-        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
-                                           None)
+        # Special case for per-tensor scale to load scalar into fused array.
+        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

         if loaded_shard_id is None:
             # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
-                # If fp8 + scale, need to send to each shard.
-                if fp8_scales_shard_indexer is not None:
-                    param_data, loaded_weight = fp8_scales_shard_indexer(
-                        param_data, loaded_weight, loaded_shard_id)
+                if needs_scalar_to_array is not None:
+                    param_data, loaded_weight = adjust_scalar_to_fused_array(
+                        param_data, loaded_weight, 0)

                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
@@ -450,15 +451,9 @@ def weight_loader(self,
             shard_offset = loaded_shard_id * shard_size
             param_data = param_data.narrow(0, shard_offset, shard_size)

-        # If a param_shard_splitter is defined by the LinearMethod, use it.
-        elif param_shard_splitter is not None:
-            logical_widths = getattr(param, "logical_widths", None)
-            param_data, loaded_weight = param_shard_splitter(
-                param_data, loaded_weight, loaded_shard_id, logical_widths)
-
-        # Special case for Fp8 scales.
-        elif fp8_scales_shard_indexer is not None:
-            param_data, loaded_weight = fp8_scales_shard_indexer(
+        # Special case for per-tensor scales in fused case.
+        elif needs_scalar_to_array:
+            param_data, loaded_weight = adjust_scalar_to_fused_array(
                 param_data, loaded_weight, loaded_shard_id)

         else:
@@ -548,36 +543,15 @@ def weight_loader(self,
         # Special case for AQLM codebooks.
         is_metadata = getattr(param, "is_metadata", False)

-        param_shard_splitter = getattr(param, "shard_splitter", None)
-
-        if output_dim is not None and param_shard_splitter is not None:
-            raise NotImplementedError(
-                "We do not currently support output_dim != None and "
-                "shard_splitter != None for a parameter. Please open an issue."
-            )
-        # If a parameter has defined a shard_splitter to be used for
-        # the weight, it should be applied before the weight is
-        # loaded/copied to the parameter. The shard_splitter applies
-        # logic by using the loaded_shard_id to ensure that the loaded
-        # param is loaded to the correct location
-        # within the parameter defined by the linear method.
-        if loaded_shard_id is None and param_shard_splitter is not None:
-            raise NotImplementedError(
-                "We do not currently support loaded_shard_id == None and "
-                "shard_splitter != None for a parameter. Please open an issue."
-            )
-
-        # Special case for Fp8 scales.
-        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
-                                           None)
+        # Special case for per-tensor scales in fused case.
+        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

         if loaded_shard_id is None:
             # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
-                # If fp8 + scale, need to send to each shard.
-                if fp8_scales_shard_indexer is not None:
-                    param_data, loaded_weight = fp8_scales_shard_indexer(
-                        param_data, loaded_weight, loaded_shard_id)
+                if needs_scalar_to_array is not None:
+                    param_data, loaded_weight = adjust_scalar_to_fused_array(
+                        param_data, loaded_weight, 0)

                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
@@ -667,15 +641,9 @@ def weight_loader(self,
             shard_index = ["q", "k", "v"].index(loaded_shard_id)
             param_data = param_data.narrow(0, shard_index * shard_size,
                                            shard_size)
-        # If a param_shard_splitter is defined by the LinearMethod, use it.
-        elif param_shard_splitter is not None:
-            logical_widths = getattr(param, "logical_widths", None)
-            param_data, loaded_weight = param_shard_splitter(
-                param_data, loaded_weight, loaded_shard_id, logical_widths)
-
-        # Special case for Fp8 scales.
-        elif fp8_scales_shard_indexer is not None:
-            param_data, loaded_weight = fp8_scales_shard_indexer(
+        # Special case for per-tensor scales in fused case.
+        elif needs_scalar_to_array:
+            param_data, loaded_weight = adjust_scalar_to_fused_array(
                 param_data, loaded_weight, loaded_shard_id)
         else:
             ignore_warning = getattr(param, "ignore_warning", False)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index a451427ec93f2..664eac3f9e97b 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -186,6 +186,9 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
     def __init__(self, quantization_config: CompressedTensorsConfig):
         self.quantization_config = quantization_config

+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        return layer.scheme.process_weights_after_loading(layer)
+
     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
                        output_partition_sizes: List[int], input_size: int,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
index 3a5904208656e..119f6cd91bb0c 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
@@ -31,3 +31,11 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):

         """
         raise NotImplementedError
+
+    @abstractmethod
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        """
+        Called after weight loading is complete for any cleanup that
+        needs to occur.
+        """
+        raise NotImplementedError
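
For orientation: with the new abstract hook, each scheme decides what (if anything) to run once weights are on the device, and `CompressedTensorsLinearMethod.process_weights_after_loading` simply forwards to `layer.scheme`. A hypothetical scheme showing the contract is sketched below; `ToyScheme` is not part of vLLM, and its method bodies are placeholders.

```python
# Hypothetical scheme illustrating the new hook; ToyScheme is not part of vLLM.
import torch

from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_scheme import (  # noqa: E501
    CompressedTensorsScheme)


class ToyScheme(CompressedTensorsScheme):

    def create_weights(self, layer: torch.nn.Module, **kwargs):
        ...  # register weight/scale parameters on `layer`

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
        ...  # run the quantized matmul

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Any fix-up that needs the fully loaded tensors goes here, e.g.
        # reshaping or re-packing scales before the first forward pass.
        if hasattr(layer, "weight_scale"):
            layer.weight_scale.data = layer.weight_scale.data.float()
```
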
""" + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + pass + def create_weights(self, layer: torch.nn.Module, output_partition_sizes: List[int], input_size_per_partition: int, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py index 607029c819ddb..3c07d6b6fe5c1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py @@ -29,6 +29,9 @@ def __init__(self, raise ValueError( "group_size must be given when using strategy group") + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + pass + def create_weights(self, layer: torch.nn.Module, input_size: int, output_partition_sizes: List[int], input_size_per_partition: int, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py index efed79ec7a11c..49779057659f0 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py @@ -15,70 +15,63 @@ class CompressedTensorsW8A8(CompressedTensorsScheme): def __init__(self, strategy: str): self.strategy = strategy - def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: - if isinstance(shard_id, int): - return shard_id - - assert isinstance(shard_id, str) - qkv_idxs = {"q": 0, "k": 1, "v": 2} - assert shard_id in qkv_idxs - return qkv_idxs[shard_id] - - def scales_shard_splitter( - self, param: torch.Tensor, loaded_weight: torch.Tensor, - shard_id: Union[str, int], - logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - shard_id = self._shard_id_as_int(shard_id) - offset = sum(logical_widths[:shard_id]) - size = logical_widths[shard_id] - # update loaded weight with copies for broadcast. - loaded_weight = loaded_weight.repeat(size) - return param[offset:offset + size], loaded_weight + # Cutlass kernels support only per-tensor and per-channel cases. + # So if we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), we convert to the per-channel case. + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if (self.strategy == QuantizationStrategy.TENSOR + and len(self.logical_widths) > 1): + + # Load the N per-tensor scales into the channelwise buffer. 
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py
index efed79ec7a11c..49779057659f0 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py
@@ -15,70 +15,63 @@ class CompressedTensorsW8A8(CompressedTensorsScheme):
     def __init__(self, strategy: str):
         self.strategy = strategy

-    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
-        if isinstance(shard_id, int):
-            return shard_id
-
-        assert isinstance(shard_id, str)
-        qkv_idxs = {"q": 0, "k": 1, "v": 2}
-        assert shard_id in qkv_idxs
-        return qkv_idxs[shard_id]
-
-    def scales_shard_splitter(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Union[str, int],
-            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        shard_id = self._shard_id_as_int(shard_id)
-        offset = sum(logical_widths[:shard_id])
-        size = logical_widths[shard_id]
-        # update loaded weight with copies for broadcast.
-        loaded_weight = loaded_weight.repeat(size)
-        return param[offset:offset + size], loaded_weight
+    # Cutlass kernels support only per-tensor and per-channel cases.
+    # So if we have a fused module (QKV, MLP) with per tensor scales (thus N
+    # scales being passed to the kernel), we convert to the per-channel case.
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if (self.strategy == QuantizationStrategy.TENSOR
+                and len(self.logical_widths) > 1):
+
+            # Load the N per-tensor scales into the channelwise buffer.
+            weight_scale_channel = torch.empty(
+                (sum(self.logical_widths), 1),
+                dtype=torch.float32,
+                device=layer.weight_scale.device)
+            start = 0
+            for idx, logical_width in enumerate(self.logical_widths):
+                end = start + logical_width
+                weight_scale_channel[start:end, :] = layer.weight_scale[idx]
+                start = end
+
+            layer.weight_scale = Parameter(weight_scale_channel,
+                                           requires_grad=False)

     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
+        self.logical_widths = output_partition_sizes

-        is_tensor_partitioned = len(output_partition_sizes) != 1
-        weight_scale_dim = sum(output_partition_sizes) if (
-            is_tensor_partitioned
-            or self.strategy == QuantizationStrategy.CHANNEL) else 1
-
-        shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
+        # WEIGHT SCALE
+        shape: Union[Tuple[int], Tuple[int, int]]
         if self.strategy == QuantizationStrategy.CHANNEL:
-            shape = (weight_scale_dim, 1)
+            shape = (sum(self.logical_widths), 1)
+        else:
+            shape = (len(self.logical_widths), )

         weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
                                  requires_grad=False)
-
         layer.register_parameter("weight_scale", weight_scale)
-        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            set_weight_attrs(weight_scale, {
+                "weight_loader": weight_loader,
+                "output_dim": 0,
+            })
+        else:
+            set_weight_attrs(weight_scale, {
+                "weight_loader": weight_loader,
+                "needs_scalar_to_array": True,
+            })

+        # WEIGHT
         weight = Parameter(torch.empty(sum(output_partition_sizes),
                                        input_size_per_partition,
                                        dtype=torch.int8),
                            requires_grad=False)
-
         layer.register_parameter("weight", weight)
-        set_weight_attrs(
-            weight, {
-                "input_dim": 1,
-                "output_dim": 0,
-                "weight_loader": weight_loader,
-                "logical_widths": output_partition_sizes
-            })
-
-        # Don't need a shard_splitter for channel-wise quantization
-        # Use the default loading method
-        if self.strategy == QuantizationStrategy.CHANNEL:
-            set_weight_attrs(weight_scale, {
-                "output_dim": 0,
-            })
-        else:
-            set_weight_attrs(
-                weight_scale, {
-                    "logical_widths": output_partition_sizes,
-                    "shard_splitter": self.scales_shard_splitter,
-                })
+        set_weight_attrs(weight, {
+            "input_dim": 1,
+            "output_dim": 0,
+            "weight_loader": weight_loader,
+        })
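
A standalone numeric sketch of the conversion above, with toy widths rather than real model sizes: the N per-tensor scales of a fused module are broadcast row-wise into a single per-output-channel column, the only non-per-tensor layout the CUTLASS epilogue has to handle. For the TinyLlama checkpoints in the updated test, sum(logical_widths) for the fused QKV projection is presumably 2048 + 256 + 256, which is where the expected `shape[0] == 2560` comes from.

```python
# Standalone sketch (toy widths) of the per-tensor -> per-channel expansion
# performed in process_weights_after_loading above.
import torch

logical_widths = [8, 8, 16]                    # toy q/k/v output widths
weight_scale = torch.tensor([0.1, 0.2, 0.4])   # one per-tensor scale per shard

weight_scale_channel = torch.empty((sum(logical_widths), 1))
start = 0
for idx, width in enumerate(logical_widths):
    end = start + width
    weight_scale_channel[start:end, :] = weight_scale[idx]
    start = end

assert weight_scale_channel.shape == (32, 1)
assert torch.all(weight_scale_channel[8:16] == 0.2)  # k rows carry k's scale
```
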
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
index 7707ea6ee94bc..2243260053ef5 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
@@ -29,6 +29,9 @@ def __init__(self,
             raise ValueError(
                 "group_size must be given when using strategy group")

+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        pass
+
     def create_weights(self, layer: torch.nn.Module, input_size: int,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 1c760566c28d7..df6fe4c3d572e 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Union

 import torch
 from torch.nn import Module
@@ -98,7 +98,6 @@ class Fp8LinearMethod(LinearMethodBase):
     """

     def __init__(self, quant_config: Fp8Config):
-        self.fused_module_in_checkpoint = False
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()

@@ -114,12 +113,10 @@ def _create_scale_param(
                           requires_grad=False)
         scale[:] = torch.finfo(torch.float8_e4m3fn).min
         layer.register_parameter(scale_name, scale)
-        set_weight_attrs(
-            scale, {
-                **extra_weight_attrs,
-                "fp8_scales_shard_indexer":
-                self.scales_shard_indexer,
-            })
+        set_weight_attrs(scale, {
+            **extra_weight_attrs,
+            "needs_scalar_to_array": True,
+        })

     def create_weights(
         self,
@@ -170,26 +167,6 @@ def create_weights(
             output_partition_sizes=output_partition_sizes,
             **extra_weight_attrs)

-    def scales_shard_indexer(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Optional[Union[str,
-                                     int]]) -> Tuple[torch.Tensor, torch.Tensor]:
-        qkv_idxs = {"q": 0, "k": 1, "v": 2}
-
-        if shard_id is None:
-            shard_id = 0
-            self.fused_module_in_checkpoint = True
-        elif isinstance(shard_id, int):
-            pass
-        elif isinstance(shard_id, str):
-            if shard_id not in qkv_idxs:
-                raise ValueError(f"Unknown shard_id: {shard_id}")
-            shard_id = qkv_idxs[shard_id]
-        else:
-            ValueError(f"Shard id must be int or str but got {type(shard_id)}")
-
-        return param[shard_id], loaded_weight
-
     def process_weights_after_loading(self, layer: Module) -> None:
         if (not hasattr(layer, "process_after_load")
                 or not layer.process_after_load):
@@ -212,7 +189,17 @@ def process_weights_after_loading(self, layer: Module) -> None:
             # Loop over logical weights, requantizing with single scale.
             max_w_scale = layer.weight_scale.max()

-            if not self.fused_module_in_checkpoint:
+            # QKV / MLP is fused in the on disk checkpoint if any of the
+            # weight scales are still set to the default since we initialize
+            # N weight scales for N shards but we only load 1 weight scale
+            # from disk in this case. As a result, we skip dequant -> requant
+            # since we already have quantized QKV together.
+            # Sample Model with fused checkpoint:
+            # * nm-testing/Phi-3-mini-128k-instruct-FP8
+            unfused_module_in_checkpoint = (
+                layer.weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min)
+
+            if unfused_module_in_checkpoint:
                 start = 0
                 for idx, logical_width in enumerate(layer.logical_widths):
                     end = start + logical_width