From 8f0a9ca890a125f2b0fef49ba042ecf5b37830a8 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Mon, 4 Nov 2024 18:57:44 -0500
Subject: [PATCH] [Bugfix] Respect modules_to_not_convert within awq_marlin
 (#9895)

Signed-off-by: mgoin
---
 .../layers/quantization/awq_marlin.py | 35 +++++++++++++------
 1 file changed, 24 insertions(+), 11 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 95ec12daeeeb5..ea69bee45f8d9 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -9,7 +9,9 @@
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
@@ -36,13 +38,18 @@ class AWQMarlinConfig(QuantizationConfig):
         8: scalar_types.uint8,
     }
 
-    def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
-                 lm_head_quantized: bool) -> None:
+    def __init__(self,
+                 weight_bits: int,
+                 group_size: int,
+                 zero_point: bool,
+                 lm_head_quantized: bool,
+                 modules_to_not_convert: Optional[List[str]] = None) -> None:
         self.pack_factor = 32 // weight_bits  # packed into int32
         self.group_size = group_size
-        self.has_zp = has_zp
+        self.zero_point = zero_point
         self.lm_head_quantized = lm_head_quantized
         self.weight_bits = weight_bits
+        self.modules_to_not_convert = modules_to_not_convert or []
 
         if self.weight_bits not in self.TYPE_MAP:
             raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
@@ -52,13 +59,14 @@ def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
 
         verify_marlin_supported(self.quant_type,
                                 group_size=self.group_size,
-                                has_zp=self.has_zp)
+                                has_zp=self.zero_point)
 
     def __repr__(self) -> str:
         return (f"AWQMarlinConfig(quant_type={self.quant_type}, "
                 f"group_size={self.group_size}, "
-                f"has_zp={self.has_zp}, "
-                f"lm_head_quantized={self.lm_head_quantized})")
+                f"zero_point={self.zero_point}, "
+                f"lm_head_quantized={self.lm_head_quantized}, "
+                f"modules_to_not_convert={self.modules_to_not_convert})")
 
     @classmethod
     def get_name(cls) -> str:
@@ -80,10 +88,13 @@ def get_config_filenames(cls) -> List[str]:
     def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
         weight_bits = cls.get_from_keys(config, ["bits"])
         group_size = cls.get_from_keys(config, ["group_size"])
-        has_zp = cls.get_from_keys(config, ["zero_point"])
+        zero_point = cls.get_from_keys(config, ["zero_point"])
         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False)
-        return cls(weight_bits, group_size, has_zp, lm_head_quantized)
+        modules_to_not_convert = cls.get_from_keys_or(
+            config, ["modules_to_not_convert"], None)
+        return cls(weight_bits, group_size, zero_point, lm_head_quantized,
+                   modules_to_not_convert)
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
@@ -109,6 +120,8 @@ def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         if (isinstance(layer, LinearBase) or
                 (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
+            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
             return AWQMarlinLinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return AWQMoEMethod(self)
@@ -123,7 +136,7 @@ def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
         quant_method = quant_config.get("quant_method", "").lower()
         num_bits = quant_config.get("bits")
         group_size = quant_config.get("group_size")
-        has_zp = quant_config.get("zero_point")
+        zero_point = quant_config.get("zero_point")
 
         if not current_platform.is_cuda():
             return False
@@ -132,7 +145,7 @@ def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
             return False
 
         # If we cannot find the info needed in the config, cannot convert.
-        if (num_bits is None or group_size is None or has_zp is None):
+        if (num_bits is None or group_size is None or zero_point is None):
             return False
 
         if num_bits not in cls.TYPE_MAP:
@@ -140,7 +153,7 @@ def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
             return False
 
         return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
                                       group_size=group_size,
-                                      has_zp=has_zp)
+                                      has_zp=zero_point)
 
 
 class AWQMarlinLinearMethod(LinearMethodBase):