Commit e20cdc1

Yard1 committed Apr 11, 2024
1 parent a740d2b

Showing 1 changed file with 7 additions and 4 deletions.

vllm/model_executor/layers/quantization/gptq.py

@@ -80,12 +80,15 @@ class ExllamaState(Enum):
 class GPTQLinearMethod(LinearMethodBase):
     """Linear method for GPTQ.
 
+    Note this linear method holds its own state.
+
     Args:
         quant_config: The GPTQ quantization config.
     """
 
     def __init__(self, quant_config: GPTQConfig):
         self.quant_config = quant_config
+        self.exllama_state = ExllamaState.UNINITIALIZED
 
     def create_weights(
         self,
@@ -191,7 +194,7 @@ def create_weights(
         layer.register_parameter("scales", scales)
         set_weight_attrs(scales, extra_weight_attrs)
 
-        layer.exllama_state = exllama_state
+        self.exllama_state = exllama_state
 
     def apply_weights(self,
                       layer: torch.nn.Module,
@@ -202,18 +205,18 @@ def apply_weights(self,
         reshaped_x = x.reshape(-1, x.shape[-1])
         # exllama needs to shuffle the weight after the weight is loaded
         # here we do the shuffle on first forward pass
-        if layer.exllama_state == ExllamaState.UNINITIALIZED:
+        if self.exllama_state == ExllamaState.UNINITIALIZED:
             if self.quant_config.desc_act:
                 layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
             else:
                 layer.g_idx.data = torch.empty((0, ),
                                                device=layer.g_idx.device)
-            layer.exllama_state = ExllamaState.READY
+            self.exllama_state = ExllamaState.READY
             ops.gptq_shuffle(layer.qweight, layer.g_idx,
                              self.quant_config.weight_bits)
         output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
                                layer.scales, layer.g_idx,
-                               layer.exllama_state == ExllamaState.READY,
+                               self.exllama_state == ExllamaState.READY,
                                self.quant_config.weight_bits)
         if bias is not None:
             output.add_(bias)
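
For context, the pattern behind this change is easiest to see in isolation: the exllama shuffle must run after the quantized weights have been loaded from the checkpoint, so it is deferred to the first forward pass and guarded by a state flag that, after this commit, lives on the linear-method instance (`self.exllama_state`) rather than on the layer module. Below is a minimal, self-contained sketch of that pattern. `LazyShuffleLinearMethod` and the plain row permutation are hypothetical stand-ins for illustration, not vLLM's `GPTQLinearMethod` or the `ops.gptq_shuffle` kernel.

    # Sketch of "one-time weight shuffle on first forward", with the
    # has-the-shuffle-run flag held on the linear-method instance.
    from enum import Enum

    import torch


    class ExllamaState(Enum):
        UNINITIALIZED = 0
        READY = 1


    class LazyShuffleLinearMethod:

        def __init__(self):
            # Mirrors the added __init__ line: state starts UNINITIALIZED.
            self.exllama_state = ExllamaState.UNINITIALIZED

        def apply_weights(self, layer: torch.nn.Module,
                          x: torch.Tensor) -> torch.Tensor:
            if self.exllama_state == ExllamaState.UNINITIALIZED:
                # Runs exactly once, after the checkpoint weights are in
                # place on `layer` (stand-in for ops.gptq_shuffle).
                perm = torch.argsort(layer.g_idx)
                layer.weight.data = layer.weight.data[perm]
                layer.g_idx.data = layer.g_idx.data[perm]
                self.exllama_state = ExllamaState.READY
            return x @ layer.weight.t()


    # Usage: the first call shuffles, later calls go straight to the matmul.
    layer = torch.nn.Module()
    layer.weight = torch.nn.Parameter(torch.randn(4, 8), requires_grad=False)
    layer.g_idx = torch.nn.Parameter(torch.randperm(4), requires_grad=False)

    method = LazyShuffleLinearMethod()
    out_first = method.apply_weights(layer, torch.randn(2, 8))   # shuffles
    out_second = method.apply_weights(layer, torch.randn(2, 8))  # no shuffle

This is also what the added docstring line ("Note this linear method holds its own state.") calls out: as written in the diff, the method object is stateful, so if a single instance served several layers, the first forward would flip the flag to READY and the remaining layers would never be shuffled. The pairing is effectively one method instance per layer.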
