Skip to content

Commit

Permalink
[Misc] GPTQ Activation Ordering (vllm-project#8135)
Browse files Browse the repository at this point in the history
Signed-off-by: Amit Garg <[email protected]>
  • Loading branch information
kylesayrs authored and garg-amit committed Oct 28, 2024
1 parent 42b962a commit e37b38e
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 15 deletions.
1 change: 1 addition & 0 deletions tests/weight_loading/models.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,8 @@ def _get_scheme_from_parts(
return CompressedTensorsWNA16(
num_bits=weight_quant.num_bits,
strategy=weight_quant.strategy,
group_size=weight_quant.group_size)
group_size=weight_quant.group_size,
actorder=weight_quant.actorder)

# Detect If Activation Quantization.
# TODO @dsikka: clean-up conditions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
ActivationOrdering)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, replace_tensor, verify_marlin_supported,
marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
PackedvLLMParameter,
RowvLLMParameter)
from vllm.scalar_type import scalar_types

__all__ = ["CompressedTensorsWNA16"]
Expand All @@ -28,11 +32,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
def __init__(self,
strategy: str,
num_bits: int,
group_size: Optional[int] = None):
group_size: Optional[int] = None,
actorder: Optional[ActivationOrdering] = None):

self.pack_factor = 32 // num_bits
self.strategy = strategy
self.group_size = -1 if group_size is None else group_size
self.has_g_idx = actorder == ActivationOrdering.GROUP

if self.group_size == -1 and self.strategy != "channel":
raise ValueError("Marlin kernels require group quantization or "
Expand Down Expand Up @@ -64,12 +70,10 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
output_size_per_partition = sum(output_partition_sizes)

# If group_size is -1, we are in channelwise case.
channelwise = (self.group_size == -1)
group_size = self.group_size if self.group_size != -1 else input_size
row_parallel = (input_size != input_size_per_partition)
# In the case of channelwise quantization, we need to replicate the
# scales across all gpus.
partition_scales = (row_parallel and not channelwise)
partition_scales = not marlin_repeat_scales_on_all_ranks(
self.has_g_idx, self.group_size, row_parallel)

verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
Expand Down Expand Up @@ -123,6 +127,16 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_shape", weight_shape)

# group index (for activation reordering)
if self.has_g_idx:
weight_g_idx = RowvLLMParameter(data=torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight_g_idx", weight_g_idx)

layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size
Expand All @@ -137,9 +151,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.workspace = marlin_make_workspace(
layer.output_size_per_partition, device)

# Act-order not supported in compressed-tensors yet, so set to empty.
layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
# Handle sorting for activation reordering if needed.
if self.has_g_idx:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
replace_tensor(layer, "weight_g_idx", g_idx)
else:
layer.weight_g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

# No zero-point
layer.weight_zp = marlin_make_empty_g_idx(device)
Expand All @@ -159,9 +178,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
replace_tensor(layer, "weight_packed", marlin_qweight)

# Permute scales from compressed-tensors format to marlin format.
# scale is required on all partitions if activation reordering
marlin_scales = marlin_permute_scales(
layer.weight_scale,
size_k=layer.input_size_per_partition,
size_k=(layer.input_size
if self.has_g_idx else layer.input_size_per_partition),
size_n=layer.output_size_per_partition,
group_size=layer.group_size)
replace_tensor(layer, "weight_scale", marlin_scales)
Expand All @@ -174,7 +195,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
weight=layer.weight_packed,
weight_scale=layer.weight_scale,
weight_zp=layer.weight_zp,
g_idx=layer.g_idx,
g_idx=layer.weight_g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
wtype=self.quant_type,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import re
from enum import Enum
from typing import Any, Dict, Iterable, Optional
from typing import Any, Dict, Iterable, Optional, Union

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from torch.nn import Module

from vllm.model_executor.layers.quantization.utils.quant_utils import (
Expand Down Expand Up @@ -40,6 +40,19 @@ class QuantizationStrategy(str, Enum):
TOKEN = "token"


class ActivationOrdering(str, Enum):
"""
Enum storing strategies for activation ordering
Group: reorder groups and weight\n
Weight: only reorder weight, not groups. Slightly lower latency and
accuracy compared to group actorder\n
"""

GROUP = "group"
WEIGHT = "weight"


class QuantizationArgs(BaseModel):
"""
User facing arguments used to define a quantization config
Expand All @@ -58,6 +71,8 @@ class QuantizationArgs(BaseModel):
observed with every sample. Defaults to False for static
quantization. Note that enabling dynamic quantization
will change the default observer to a memoryless one
:param actorder: whether to apply group quantization in decreasing order of
activation. Defaults to None for arbitrary ordering
"""

num_bits: int = 8
Expand All @@ -67,6 +82,7 @@ class QuantizationArgs(BaseModel):
strategy: Optional[QuantizationStrategy] = None
block_structure: Optional[str] = None
dynamic: bool = False
actorder: Union[ActivationOrdering, bool, None] = None
observer: str = Field(
default="minmax",
description=("The class to use to compute the quantization param - "
Expand All @@ -79,6 +95,16 @@ class QuantizationArgs(BaseModel):
"Observers constructor excluding quantization range or symmetry"),
)

@field_validator("actorder", mode="before")
def validate_actorder(cls, value) -> Optional[ActivationOrdering]:
if isinstance(value, bool):
return ActivationOrdering.GROUP if value else None

if isinstance(value, str):
return ActivationOrdering(value.lower())

return value


def is_activation_quantization_format(format: str) -> bool:
_ACTIVATION_QUANTIZATION_FORMATS = [
Expand Down

0 comments on commit e37b38e

Please sign in to comment.