From e20cdc1671830e0bdd295cc24de12868867e1c1e Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Thu, 11 Apr 2024 11:04:56 -0700
Subject: [PATCH] Review comment

---
 vllm/model_executor/layers/quantization/gptq.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py
index 4b0ebb19a576f..37163606f9e48 100644
--- a/vllm/model_executor/layers/quantization/gptq.py
+++ b/vllm/model_executor/layers/quantization/gptq.py
@@ -80,12 +80,15 @@ class ExllamaState(Enum):
 class GPTQLinearMethod(LinearMethodBase):
     """Linear method for GPTQ.
 
+    Note this linear method holds its own state.
+
     Args:
         quant_config: The GPTQ quantization config.
     """
 
     def __init__(self, quant_config: GPTQConfig):
         self.quant_config = quant_config
+        self.exllama_state = ExllamaState.UNINITIALIZED
 
     def create_weights(
         self,
@@ -191,7 +194,7 @@ def create_weights(
         layer.register_parameter("scales", scales)
         set_weight_attrs(scales, extra_weight_attrs)
 
-        layer.exllama_state = exllama_state
+        self.exllama_state = exllama_state
 
     def apply_weights(self,
                       layer: torch.nn.Module,
@@ -202,18 +205,18 @@ def apply_weights(self,
         reshaped_x = x.reshape(-1, x.shape[-1])
         # exllama needs to shuffle the weight after the weight is loaded
         # here we do the shuffle on first forward pass
-        if layer.exllama_state == ExllamaState.UNINITIALIZED:
+        if self.exllama_state == ExllamaState.UNINITIALIZED:
             if self.quant_config.desc_act:
                 layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
             else:
                 layer.g_idx.data = torch.empty((0, ),
                                                device=layer.g_idx.device)
-            layer.exllama_state = ExllamaState.READY
+            self.exllama_state = ExllamaState.READY
             ops.gptq_shuffle(layer.qweight, layer.g_idx,
                              self.quant_config.weight_bits)
         output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
                                layer.scales, layer.g_idx,
-                               layer.exllama_state == ExllamaState.READY,
+                               self.exllama_state == ExllamaState.READY,
                                self.quant_config.weight_bits)
         if bias is not None:
            output.add_(bias)
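
Not part of the patch: a minimal, dependency-free sketch of the pattern the diff applies, namely tracking the lazy first-forward-pass initialization on the method instance (`self.exllama_state`) instead of on the layer module. `LazyShuffleMethod`, `_shuffle_weights`, and `FakeLayer` are hypothetical stand-ins for `GPTQLinearMethod`, `ops.gptq_shuffle`, and the real `torch.nn.Module`; only the state transitions mirror the patched code.

import enum


class ExllamaState(enum.Enum):
    UNUSED = enum.auto()
    UNINITIALIZED = enum.auto()
    READY = enum.auto()


class LazyShuffleMethod:
    """Holds its own state, like GPTQLinearMethod after this patch."""

    def __init__(self):
        # Mirrors the new assignment in GPTQLinearMethod.__init__.
        self.exllama_state = ExllamaState.UNINITIALIZED

    def _shuffle_weights(self, layer):
        # Hypothetical stand-in for ops.gptq_shuffle(layer.qweight, ...).
        layer.shuffled = True

    def apply_weights(self, layer):
        # One-time setup runs on the first forward pass only, as in the
        # patched apply_weights.
        if self.exllama_state == ExllamaState.UNINITIALIZED:
            self._shuffle_weights(layer)
            self.exllama_state = ExllamaState.READY
        return self.exllama_state == ExllamaState.READY


class FakeLayer:
    shuffled = False


layer = FakeLayer()
method = LazyShuffleMethod()
assert method.apply_weights(layer)   # first call performs the shuffle
assert layer.shuffled
assert method.apply_weights(layer)   # later calls skip it

Note the implicit assumption, which the added docstring line ("Note this linear method holds its own state.") calls out: each layer must get its own method instance, since an instance shared across layers would run the one-time shuffle for only the first of them.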