[Misc] Update gptq_marlin_24 to use vLLMParameters (vllm-project#7762)
Co-authored-by: Michael Goin <[email protected]>
2 people authored and triple-Mu committed Sep 4, 2024
1 parent b613f6e commit d5cd5b8
Showing 2 changed files with 50 additions and 54 deletions.
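In short: this change migrates GPTQMarlin24LinearMethod.create_weights from bare torch.nn.Parameter objects annotated after the fact via set_weight_attrs to vLLM's parameter subclasses (PackedvLLMParameter, ChannelQuantScaleParameter, GroupQuantScaleParameter, BasevLLMParameter), which receive their sharding and packing metadata plus a weight_loader callback at construction time. It also registers "GPTQMarlin24LinearMethod" in WEIGHT_LOADER_V2_SUPPORTED so the v2 loading path is used, and adds a process_weights_after_loading hook that re-wraps the weights as plain Parameters for torch.compile.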
vllm/model_executor/layers/linear.py (2 changes: 1 addition & 1 deletion)

@@ -23,7 +23,7 @@
 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
     "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
-    "MarlinLinearMethod", "QQQLinearMethod"
+    "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod"
 ]
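For orientation, a hedged sketch (not the literal vLLM source) of what membership in this list gates: the linear layers key off the quant method's class name when choosing which weight-loader callback to hand to create_weights. The helper below is hypothetical, for illustration only.

def pick_weight_loader(quant_method, loader_v1, loader_v2):
    # Hypothetical illustration: membership is keyed on the class name,
    # which is why the string "GPTQMarlin24LinearMethod" is added above.
    if type(quant_method).__name__ in WEIGHT_LOADER_V2_SUPPORTED:
        return loader_v2
    return loader_v1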
vllm/model_executor/layers/quantization/gptq_marlin_24.py (102 changes: 49 additions & 53 deletions)

@@ -8,7 +8,10 @@
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           ChannelQuantScaleParameter,
+                                           GroupQuantScaleParameter,
+                                           PackedvLLMParameter)
 from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
@@ -149,7 +152,7 @@ def create_weights(
         **extra_weight_attrs,
     ):
         del output_size  # Unused.
-
+        weight_loader = extra_weight_attrs["weight_loader"]
         if params_dtype != torch.float16:
             raise ValueError(
                 f"The params dtype must be float16, but got {params_dtype}")
@@ -187,87 +190,80 @@ def create_weights(
             "Each permutation group must reside on the same gpu")
 
         # Quantized 4Bit weights packed into Int32.
-        qweight = Parameter(
-            torch.empty(
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
                 input_size_per_partition // self.quant_config.tile_size // 2,
                 output_size_per_partition * self.quant_config.tile_size //
                 self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight,
-            {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-                "marlin_tile_size": self.quant_config.tile_size,
-            },
-        )
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            marlin_tile_size=self.quant_config.tile_size,
+            weight_loader=weight_loader)
 
         # Meta
-        meta = Parameter(
-            torch.empty(
-                input_size_per_partition // 8 // 2 // 2,
-                output_size_per_partition * 2,
-                device="cuda",
-                dtype=torch.int16,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            meta,
-            {
-                "input_dim": 0,
-                "packed_dim": 1,
-                "pack_factor": 1,
-                "output_dim": 1,
-                "marlin_tile_size": 2,
-            },
-        )
+        meta = PackedvLLMParameter(data=torch.empty(
+            input_size_per_partition // 8 // 2 // 2,
+            output_size_per_partition * 2,
+            device="cuda",
+            dtype=torch.int16,
+        ),
+                                   input_dim=0,
+                                   output_dim=1,
+                                   packed_dim=1,
+                                   packed_factor=1,
+                                   marlin_tile_size=2,
+                                   weight_loader=weight_loader)
 
         # Determine if channelwise or not
         input_groups = (1 if self.quant_config.group_size == -1 else
                         input_size_per_partition //
                         self.quant_config.group_size)
 
-        scales = Parameter(
+        weight_scale_args = {
+            "data":
             torch.empty(
                 input_groups,
                 output_size_per_partition,
                 device="cuda",
                 dtype=params_dtype,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            scales,
-            {
-                "input_dim": None if input_groups == 1 else 0,
-                "output_dim": 1,
-            },
-        )
+            "weight_loader":
+            weight_loader
+        }
+        if input_groups == 1:
+            scales = ChannelQuantScaleParameter(output_dim=1,
+                                                **weight_scale_args)
+        else:
+            scales = GroupQuantScaleParameter(output_dim=1,
+                                              input_dim=0,
+                                              **weight_scale_args)
 
         # Allocate workspace (Used for internal locking mechanism)
         max_workspace_size = (
             output_size_per_partition //
             self.quant_config.min_n_threads) * self.quant_config.max_parallel
-        workspace = Parameter(torch.zeros(max_workspace_size,
-                                          device="cuda",
-                                          dtype=torch.int),
-                              requires_grad=False)
+
+        workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
+                                                       device="cuda",
+                                                       dtype=torch.int),
+                                      weight_loader=weight_loader)
 
         layer.register_parameter("B_24", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("B_meta", meta)
-        set_weight_attrs(meta, extra_weight_attrs)
         layer.register_parameter("s", scales)
-        set_weight_attrs(scales, extra_weight_attrs)
         layer.register_parameter("workspace", workspace)
-        set_weight_attrs(workspace, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # required by torch.compile
+        layer.B_24 = Parameter(layer.B_24.data, requires_grad=False)
+        layer.s = Parameter(layer.s.data, requires_grad=False)
+        layer.B_meta = Parameter(layer.B_meta.data, requires_grad=False)
+        layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
 
     def apply(
         self,
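To make the new pattern concrete, here is a minimal sketch of constructing these parameter types, assuming a vLLM build that provides vllm.model_executor.parameter. The shapes, pack factor, and demo_loader below are illustrative stand-ins, not values taken from the change.

import torch
from torch.nn import Parameter

from vllm.model_executor.parameter import (BasevLLMParameter,
                                           ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
                                           PackedvLLMParameter)


def demo_loader(param, loaded_weight):
    # Hypothetical stand-in for the weight_loader a linear layer supplies.
    param.data.copy_(loaded_weight)


# Packed 4-bit weights: the metadata that set_weight_attrs used to attach
# (shard dims, packing, marlin tile size) now rides on the parameter itself.
qweight = PackedvLLMParameter(
    data=torch.empty(64, 2048, dtype=torch.int32),
    input_dim=0,
    output_dim=1,
    packed_dim=1,
    packed_factor=8,  # eight 4-bit values per int32
    marlin_tile_size=16,
    weight_loader=demo_loader)

# Scales: channelwise quantization (a single group) uses
# ChannelQuantScaleParameter; grouped quantization uses
# GroupQuantScaleParameter, mirroring the if/else in the diff.
channel_scales = ChannelQuantScaleParameter(
    data=torch.empty(1, 2048, dtype=torch.float16),
    output_dim=1,
    weight_loader=demo_loader)
group_scales = GroupQuantScaleParameter(
    data=torch.empty(16, 2048, dtype=torch.float16),
    input_dim=0,
    output_dim=1,
    weight_loader=demo_loader)

# The workspace carries no shard metadata, so the base class suffices.
workspace = BasevLLMParameter(
    data=torch.zeros(512, dtype=torch.int),
    weight_loader=demo_loader)

# After checkpoint loading, process_weights_after_loading re-wraps each
# tensor as a plain Parameter so torch.compile sees ordinary tensors:
qweight = Parameter(qweight.data, requires_grad=False)

The loader callback now travels with each parameter, so loading logic no longer depends on attributes patched in after construction.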
