[Bugfix] Respect modules_to_not_convert within awq_marlin (vllm-project#9895)

Signed-off-by: mgoin <[email protected]>
mgoin authored Nov 4, 2024
1 parent 2094062 commit 8f0a9ca
Showing 1 changed file with 24 additions and 11 deletions.
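For orientation (not part of the commit), the sketch below shows how an AWQ quantization config that lists modules to leave unquantized would now be parsed. The values are hypothetical; the key names ("bits", "group_size", "zero_point", "lm_head", "modules_to_not_convert") are the ones AWQMarlinConfig.from_config reads, and building the config assumes a CUDA machine that passes the Marlin support checks run in __init__.

# Sketch only: hypothetical config values; the key names match what
# AWQMarlinConfig.from_config reads after this change.
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig

hf_quant_config = {
    "bits": 4,
    "group_size": 128,
    "zero_point": True,
    # Hypothetical entry: keep, e.g., a vision tower in its original dtype.
    "modules_to_not_convert": ["visual"],
}

quant_config = AWQMarlinConfig.from_config(hf_quant_config)
print(quant_config.modules_to_not_convert)  # ['visual']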
vllm/model_executor/layers/quantization/awq_marlin.py (35 changes: 24 additions & 11 deletions)
@@ -9,7 +9,9 @@
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
@@ -36,13 +36,18 @@ class AWQMarlinConfig(QuantizationConfig):
         8: scalar_types.uint8,
     }
 
-    def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
-                 lm_head_quantized: bool) -> None:
+    def __init__(self,
+                 weight_bits: int,
+                 group_size: int,
+                 zero_point: bool,
+                 lm_head_quantized: bool,
+                 modules_to_not_convert: Optional[List[str]] = None) -> None:
         self.pack_factor = 32 // weight_bits  # packed into int32
         self.group_size = group_size
-        self.has_zp = has_zp
+        self.zero_point = zero_point
         self.lm_head_quantized = lm_head_quantized
         self.weight_bits = weight_bits
+        self.modules_to_not_convert = modules_to_not_convert or []
 
         if self.weight_bits not in self.TYPE_MAP:
             raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
@@ -52,13 +59,14 @@ def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
 
         verify_marlin_supported(self.quant_type,
                                 group_size=self.group_size,
-                                has_zp=self.has_zp)
+                                has_zp=self.zero_point)
 
     def __repr__(self) -> str:
         return (f"AWQMarlinConfig(quant_type={self.quant_type}, "
                 f"group_size={self.group_size}, "
-                f"has_zp={self.has_zp}, "
-                f"lm_head_quantized={self.lm_head_quantized})")
+                f"zero_point={self.zero_point}, "
+                f"lm_head_quantized={self.lm_head_quantized}, "
+                f"modules_to_not_convert={self.modules_to_not_convert})")
 
     @classmethod
     def get_name(cls) -> str:
@@ -80,10 +88,13 @@ def get_config_filenames(cls) -> List[str]:
     def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
         weight_bits = cls.get_from_keys(config, ["bits"])
         group_size = cls.get_from_keys(config, ["group_size"])
-        has_zp = cls.get_from_keys(config, ["zero_point"])
+        zero_point = cls.get_from_keys(config, ["zero_point"])
         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False)
-        return cls(weight_bits, group_size, has_zp, lm_head_quantized)
+        modules_to_not_convert = cls.get_from_keys_or(
+            config, ["modules_to_not_convert"], None)
+        return cls(weight_bits, group_size, zero_point, lm_head_quantized,
+                   modules_to_not_convert)
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
@@ -109,6 +120,8 @@ def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         if (isinstance(layer, LinearBase) or
                 (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
+            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
             return AWQMarlinLinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return AWQMoEMethod(self)
@@ -123,7 +136,7 @@ def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
         quant_method = quant_config.get("quant_method", "").lower()
         num_bits = quant_config.get("bits")
         group_size = quant_config.get("group_size")
-        has_zp = quant_config.get("zero_point")
+        zero_point = quant_config.get("zero_point")
 
         if not current_platform.is_cuda():
             return False
@@ -132,15 +145,15 @@ def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
             return False
 
         # If we cannot find the info needed in the config, cannot convert.
-        if (num_bits is None or group_size is None or has_zp is None):
+        if (num_bits is None or group_size is None or zero_point is None):
             return False
 
         if num_bits not in cls.TYPE_MAP:
             return False
 
         return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
                                       group_size=group_size,
-                                      has_zp=has_zp)
+                                      has_zp=zero_point)
 
 
 class AWQMarlinLinearMethod(LinearMethodBase):
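For context, a second sketch (again not part of the commit) of how the new branch in get_quant_method behaves once such a config is loaded. The module list and layer prefixes below are made up; the skip decision itself is delegated to is_layer_skipped_awq, which the patch imports from quantization/awq.py.

# Sketch only: hypothetical module list and layer prefixes.
from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq

modules_to_not_convert = ["visual"]  # hypothetical, mirroring the sketch above

for prefix in ("visual.blocks.0.attn.qkv_proj",       # hypothetical prefix
               "model.layers.0.self_attn.qkv_proj"):  # hypothetical prefix
    if is_layer_skipped_awq(prefix, modules_to_not_convert):
        # For a skipped prefix, get_quant_method now returns
        # UnquantizedLinearMethod(), so the layer keeps unquantized weights.
        print(prefix, "-> UnquantizedLinearMethod")
    else:
        # Every other linear layer still gets the AWQ-Marlin kernels.
        print(prefix, "-> AWQMarlinLinearMethod")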
