Add vllm awq loading logic (#11987)
* [ADD] Add vllm awq loading logic

* [FIX] fix the module.linear_method path

* [FIX] fix quant_config path error
ACupofAir authored Sep 6, 2024
1 parent f008ea0 commit 56b8514
Showing 1 changed file with 132 additions and 2 deletions.
134 changes: 132 additions & 2 deletions python/llm/src/ipex_llm/transformers/convert.py
@@ -55,6 +55,7 @@

_IS_VLLM_AVAILABLE = None
_USE_VLLM = False
_USE_VLLM_AWQ = False
_VLLM_VERSION = None


@@ -143,7 +144,7 @@ def is_linear_module(module):
is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
if is_vllm_available():
# Only convert vllm modules
global _VLLM_VERSION
global _VLLM_VERSION, _USE_VLLM_AWQ
if _VLLM_VERSION is None:
_VLLM_VERSION = get_package_version('vllm')
from vllm.model_executor.layers.linear import (
@@ -180,6 +181,11 @@ def is_linear_module(module):
out_features = module.output_size
result = True
mp_group = None
# Detect AWQ quantization via the quant_method's quant_config
if (not _USE_VLLM_AWQ
and hasattr(module.quant_method, "quant_config")
and module.quant_method.quant_config.get_name() == "awq"):
_USE_VLLM_AWQ = True
invalidInputError(module.skip_bias_add is not True, "Currently, ipex-vllm does not"
" support linear layers with skip_bias_add argument")
if isinstance(module, RowParallelLinear) and tp_size >= 2:
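
The guard above flips the module-level _USE_VLLM_AWQ flag exactly once, and only when the layer's quant_method exposes a quant_config whose name is "awq". A minimal sketch of the same duck-typed check, using hypothetical stand-in classes rather than vLLM's real quantization objects:

class _FakeAWQConfig:                    # hypothetical stand-in for vLLM's AWQ config
    group_size = 128

    @staticmethod
    def get_name():
        return "awq"

class _FakeQuantMethod:                  # hypothetical stand-in for module.quant_method
    quant_config = _FakeAWQConfig()

_USE_VLLM_AWQ = False
quant_method = _FakeQuantMethod()
if (not _USE_VLLM_AWQ
        and hasattr(quant_method, "quant_config")
        and quant_method.quant_config.get_name() == "awq"):
    _USE_VLLM_AWQ = True
print(_USE_VLLM_AWQ)                     # True
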
@@ -284,6 +290,65 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
return new_linear


def convert_vllm_awq(module):
from ipex_llm.transformers.low_bit_linear import get_block_size
Q4_1 = get_block_size("asym_int4")

scales = module.scales
wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7],
dtype=torch.int32) * 4).unsqueeze(0)
# vLLM only supports loading 4-bit models, so this has already been checked
bits = 4
group_size = module.quant_method.quant_config.group_size

zeros = torch.bitwise_right_shift(
torch.unsqueeze(module.qzeros, 2).expand(-1, -1, 32 // bits),
wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
zeros = torch.bitwise_and(zeros, (2 ** bits) - 1)

g_id_map = None

zeros = zeros.reshape(scales.shape)

weight = torch.bitwise_right_shift(
torch.unsqueeze(module.qweight, 2).expand(-1, -1, 32 // bits),
wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
weight = torch.bitwise_and(weight, (2 ** bits) - 1)
weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2])

# convert weight to ggml format
weight = weight.reshape(weight.shape[0]//group_size, group_size, weight.shape[1])
weight = weight.permute(2, 0, 1).reshape(weight.shape[2], -1, 2, Q4_1//2)
weight = weight.transpose(2, 3)
weight = torch.bitwise_left_shift(weight,
torch.tensor([0, 4], dtype=torch.int8).reshape(1, 1, 1, 2))
weight = torch.bitwise_or(weight[:, :, :, 0], weight[:, :, :, 1]).contiguous()

# convert zeros to ggml format
zeros = zeros.reshape(-1, 1, zeros.shape[1]).permute(2, 0, 1)\
.unsqueeze(2)\
.expand(-1, -1, group_size//Q4_1, -1)\
.reshape(zeros.shape[1], -1, 1)\
.contiguous().to(torch.float16)

# convert scales to ggml format
scales = scales.reshape(-1, 1, scales.shape[1]).permute(2, 0, 1)\
.unsqueeze(2)\
.expand(-1, -1, group_size//Q4_1, -1)\
.reshape(scales.shape[-1], -1, 1)\
.contiguous().to(torch.float16)

m = -(zeros * scales)
d = scales

ggml_weight = torch.cat([d.view(torch.uint8),
m.view(torch.uint8),
weight.view(torch.uint8)], dim=-1)
ggml_weight = ggml_weight.reshape([-1])

return ggml_weight, g_id_map


def convert_gptq(module, awq=False, llm_awq=False, act_order=False):
from ipex_llm.transformers.low_bit_linear import get_block_size
Q4_1 = get_block_size("asym_int4")
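
In convert_vllm_awq above, the shift table torch.tensor([0, 4, 1, 5, 2, 6, 3, 7]) * 4 undoes AWQ's interleaved nibble packing with a single vectorized shift-and-mask. A small self-contained sketch, assuming the usual AWQ packing order [0, 2, 4, 6, 1, 3, 5, 7] inside each int32, showing that the table recovers the eight 4-bit values in logical order:

import torch

bits = 4
vals = torch.arange(8, dtype=torch.int32)             # logical 4-bit values 0..7
awq_order = [0, 2, 4, 6, 1, 3, 5, 7]                  # assumed AWQ nibble packing order
packed = torch.zeros(1, dtype=torch.int32)
for pos, idx in enumerate(awq_order):                 # place value `idx` into nibble `pos`
    packed |= vals[idx] << (pos * bits)

# Same unpacking as in convert_vllm_awq: shift by [0, 16, 4, 20, 8, 24, 12, 28], then mask.
wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7], dtype=torch.int32) * 4).unsqueeze(0)
unpacked = torch.bitwise_right_shift(
    torch.unsqueeze(packed.view(1, 1), 2).expand(-1, -1, 32 // bits),
    wf.unsqueeze(0)).to(torch.int8)
unpacked = torch.bitwise_and(unpacked, (2 ** bits) - 1)
print(unpacked.flatten().tolist())                    # [0, 1, 2, 3, 4, 5, 6, 7]
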
@@ -389,6 +454,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
FP16Linear, BF16Linear
from ipex_llm.transformers.embedding import CPUEmbedding, DiskEmbedding, LowBitEmbedding
has_been_replaced = False
global _USE_VLLM_AWQ

for name, module in model.named_children():
is_linear, linear_args = is_linear_module(module)
@@ -453,6 +519,70 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
if has_bias:
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
.to(device)
elif _USE_VLLM_AWQ:
# The user is loading an AWQ-quantized model through vLLM
from ipex_llm.transformers.low_bit_linear import vLLMLowBitLinear
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
has_bias = module.bias is not None and module.bias.abs().sum() != 0
if isinstance(module, ParallelLMHead):
new_linear = LowBitLinear(
in_features,
out_features,
qtype=qtype,
bias=has_bias,
mp_group=mp_group,
enable_xetla=enable_xetla,
optimize_lm_head=False,
act_order=act_order,
enable_scale_search=enable_scale_search,
)
device = module.weight.data.device
cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
full_module_name,
imatrix_data,
model_config)
# Copy the weights
paramsLowBit = FP4Params(data=module.weight.data,
requires_grad=False,
quantized=False,
_shape=None,
convert_shape_only=convert_shape_only,
qtype=cur_qtype,
imatrix=cur_imatrix,
in_features=in_features,
enable_xetla=enable_xetla,
enable_scale_search=enable_scale_search).to(device)
else:
new_linear = vLLMLowBitLinear(
in_features,
out_features,
qtype=qtype,
bias=has_bias,
mp_group=mp_group,
enable_xetla=enable_xetla,
optimize_lm_head=False,
act_order=act_order,
enable_scale_search=enable_scale_search,
)
device = module.qweight.data.device
invalidInputError(device.type != "meta",
"converting from meta device is not supported")
weight, g_idx_map = convert_vllm_awq(module)
if act_order:
new_linear.g_idx_map = g_idx_map
# Copy the weights
paramsLowBit = FP4Params(data=weight,
requires_grad=False,
quantized=True,
_shape=(out_features, in_features),
convert_shape_only=convert_shape_only,
qtype=qtype,
enable_xetla=enable_xetla,
enable_scale_search=enable_scale_search).to(device)
new_linear._parameters['weight'] = paramsLowBit
if has_bias:
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
.to(device)
elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
if in_features % 64 != 0:
# now our kernel requires in_features is a multiple of 64
@@ -1871,4 +2001,4 @@ def safe_bmm_fwd(*args, **kwargs):
minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
model.generate = MethodType(minicpmv_generate, model)

return model
return model
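
The buffer returned by convert_vllm_awq follows ggml's asym_int4 (Q4_1) layout: each block of 32 weights stores a half-precision scale d, a half-precision minimum m, and 16 bytes of packed nibbles, and dequantizes as d * q + m. Choosing d = scales and m = -(zeros * scales) therefore reproduces AWQ's (q - zero) * scale, which is why the FP4Params in the AWQ branch can be built with quantized=True. A short numeric sketch of that identity and of the 20-byte block size (pure PyTorch, nothing from ipex-llm assumed):

import torch

torch.manual_seed(0)
q = torch.randint(0, 16, (32,), dtype=torch.int8).float()   # one block of 4-bit codes
scale = torch.rand(1) * 0.1                                  # AWQ per-group scale
zero = torch.randint(0, 16, (1,)).float()                    # AWQ per-group zero point

awq_dequant = (q - zero) * scale      # AWQ's definition of the dequantized weight
d, m = scale, -(zero * scale)         # the d / m written into the Q4_1 block
assert torch.allclose(awq_dequant, d * q + m)

# Per-block storage in the concatenated ggml buffer:
# 2 bytes (fp16 d) + 2 bytes (fp16 m) + 32 * 4 bits of nibbles = 20 bytes per 32 weights.
print(2 + 2 + 32 * 4 // 8)            # 20
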
