Add vllm awq loading logic (#11987)
* [ADD] Add vllm awq loading logic

* [FIX] fix the module.linear_method path

* [FIX] fix quant_config path error
ACupofAir authored and gc-fu committed Sep 10, 2024
1 parent 058b83c commit f5c55cd
Showing 1 changed file with 68 additions and 0 deletions.
python/llm/src/ipex_llm/transformers/convert.py (68 additions, 0 deletions)
@@ -181,6 +181,11 @@ def is_linear_module(module):
        out_features = module.output_size
        result = True
        mp_group = None
        # Detect vLLM AWQ quantization via the quant_config attached to the module's quant_method
        if (not _USE_VLLM_AWQ
                and hasattr(module.quant_method, "quant_config")
                and module.quant_method.quant_config.get_name() == "awq"):
            _USE_VLLM_AWQ = True
        invalidInputError(module.skip_bias_add is not True, "Currently, ipex-vllm does not"
                          " support linear layers with skip_bias_add argument")
        if isinstance(module, RowParallelLinear) and tp_size >= 2:
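
A minimal, runnable sketch of the detection logic added above, using hypothetical stand-in objects in place of a real vLLM linear layer and its AWQ quant config (illustrative only, not part of the diff):

class _FakeAWQConfig:
    group_size = 128            # hypothetical value

    def get_name(self):
        return "awq"

class _FakeQuantMethod:
    quant_config = _FakeAWQConfig()

class _FakeVLLMLinear:
    quant_method = _FakeQuantMethod()

_USE_VLLM_AWQ = False
module = _FakeVLLMLinear()
if (not _USE_VLLM_AWQ
        and hasattr(module.quant_method, "quant_config")
        and module.quant_method.quant_config.get_name() == "awq"):
    _USE_VLLM_AWQ = True
print(_USE_VLLM_AWQ)  # prints: True
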
@@ -285,6 +290,65 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
    return new_linear


def convert_vllm_awq(module):
    from ipex_llm.transformers.low_bit_linear import get_block_size
    Q4_1 = get_block_size("asym_int4")

    scales = module.scales
    # Bit offsets for unpacking the eight interleaved 4-bit values in each int32 word
    wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7],
                       dtype=torch.int32) * 4).unsqueeze(0)
    # vLLM only supports loading 4-bit AWQ models, so this has already been checked
    bits = 4
    group_size = module.quant_method.quant_config.group_size

    zeros = torch.bitwise_right_shift(
        torch.unsqueeze(module.qzeros, 2).expand(-1, -1, 32 // bits),
        wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
    zeros = torch.bitwise_and(zeros, (2 ** bits) - 1)

    g_id_map = None

    zeros = zeros.reshape(scales.shape)

    weight = torch.bitwise_right_shift(
        torch.unsqueeze(module.qweight, 2).expand(-1, -1, 32 // bits),
        wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
    weight = torch.bitwise_and(weight, (2 ** bits) - 1)
    weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2])

    # convert weight to ggml format
    weight = weight.reshape(weight.shape[0]//group_size, group_size, weight.shape[1])
    weight = weight.permute(2, 0, 1).reshape(weight.shape[2], -1, 2, Q4_1//2)
    weight = weight.transpose(2, 3)
    weight = torch.bitwise_left_shift(weight,
                                      torch.tensor([0, 4], dtype=torch.int8).reshape(1, 1, 1, 2))
    weight = torch.bitwise_or(weight[:, :, :, 0], weight[:, :, :, 1]).contiguous()

    # convert zeros to ggml format
    zeros = zeros.reshape(-1, 1, zeros.shape[1]).permute(2, 0, 1)\
        .unsqueeze(2)\
        .expand(-1, -1, group_size//Q4_1, -1)\
        .reshape(zeros.shape[1], -1, 1)\
        .contiguous().to(torch.float16)

    # convert scales to ggml format
    scales = scales.reshape(-1, 1, scales.shape[1]).permute(2, 0, 1)\
        .unsqueeze(2)\
        .expand(-1, -1, group_size//Q4_1, -1)\
        .reshape(scales.shape[-1], -1, 1)\
        .contiguous().to(torch.float16)

    # asym_int4 (Q4_1) stores a per-block scale d and minimum m so that w = d * q + m
    m = -(zeros * scales)
    d = scales

    ggml_weight = torch.cat([d.view(torch.uint8),
                             m.view(torch.uint8),
                             weight.view(torch.uint8)], dim=-1)
    ggml_weight = ggml_weight.reshape([-1])

    return ggml_weight, g_id_map

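For intuition, here is a small self-contained sketch (plain Python, not part of the diff) of the nibble interleaving that the wf shift pattern above undoes: eight 4-bit values share one 32-bit word, and shifting by 4 * [0, 4, 1, 5, 2, 6, 3, 7] bits reads them back out in logical order.

values = [1, 2, 3, 4, 5, 6, 7, 0]   # eight 4-bit values in logical order
order = [0, 4, 1, 5, 2, 6, 3, 7]    # nibble slot each value occupies in the packed word
packed = sum(v << (4 * slot) for v, slot in zip(values, order))  # one packed 32-bit word

# Reading the nibbles back in the same slot order recovers the logical order,
# which is what torch.bitwise_right_shift(..., wf) does tensor-wide in convert_vllm_awq.
unpacked = [(packed >> (4 * slot)) & 0xF for slot in order]
print(unpacked)  # [1, 2, 3, 4, 5, 6, 7, 0]
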

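The m = -(zeros * scales) / d = scales step relies on the identity between AWQ dequantization, (q - z) * s, and ggml's asym_int4 (Q4_1) form, d * q + m. A quick check with made-up per-block numbers (illustrative only):

import torch

q = torch.tensor([0., 3., 7., 15.])   # a few unpacked 4-bit quantized values
s = torch.tensor(0.05)                # hypothetical AWQ scale for one block
z = torch.tensor(8.)                  # hypothetical AWQ zero point for one block

awq_dequant = (q - z) * s             # how AWQ reconstructs weights
d, m = s, -(z * s)                    # the substitution made in convert_vllm_awq
q4_1_dequant = d * q + m              # how ggml's asym_int4 reconstructs weights

print(torch.allclose(awq_dequant, q4_1_dequant))  # prints: True
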
def convert_gptq(module, awq=False, llm_awq=False, act_order=False):
from ipex_llm.transformers.low_bit_linear import get_block_size
Q4_1 = get_block_size("asym_int4")
@@ -1917,6 +1981,7 @@ def safe_bmm_fwd(*args, **kwargs):
    minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
    model.generate = MethodType(minicpmv_generate, model)

    if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
        # MiniCPM-V 2
        model.llm.config.model_type = "minicpm"
@@ -1960,3 +2025,6 @@ def safe_bmm_fwd(*args, **kwargs):
        model.chat = MethodType(minicpmv_chat, model)

    return model
