Add vllm awq loading logic #11987

Merged
134 changes: 132 additions & 2 deletions python/llm/src/ipex_llm/transformers/convert.py
@@ -55,6 +55,7 @@

_IS_VLLM_AVAILABLE = None
_USE_VLLM = False
_USE_VLLM_AWQ = False
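# set to True in is_linear_module() when a vLLM AWQ quantization config is detected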
_VLLM_VERSION = None


@@ -143,7 +144,7 @@ def is_linear_module(module):
is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
if is_vllm_available():
# Only convert vllm modules
global _VLLM_VERSION
global _VLLM_VERSION, _USE_VLLM_AWQ
if _VLLM_VERSION is None:
_VLLM_VERSION = get_package_version('vllm')
from vllm.model_executor.layers.linear import (
@@ -180,6 +181,11 @@ def is_linear_module(module):
out_features = module.output_size
result = True
mp_group = None
# Detect a vLLM AWQ-quantized checkpoint from the module's quant_config
if (not _USE_VLLM_AWQ
and hasattr(module.quant_method, "quant_config")
and module.quant_method.quant_config.get_name() == "awq"):
_USE_VLLM_AWQ = True
invalidInputError(module.skip_bias_add is not True, "Currently, ipex-vllm does not"
" support linear layers with skip_bias_add argument")
if isinstance(module, RowParallelLinear) and tp_size >= 2:
@@ -284,6 +290,65 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
return new_linear


def convert_vllm_awq(module):
from ipex_llm.transformers.low_bit_linear import get_block_size
Q4_1 = get_block_size("asym_int4")
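# number of weight elements per block in the GGML asym_int4 (Q4_1) layout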

scales = module.scales
wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7],
dtype=torch.int32) * 4).unsqueeze(0)
# vLLM only supports loading 4-bit AWQ models, so this has already been checked
bits = 4
group_size = module.quant_method.quant_config.group_size

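# Each packed int32 holds eight 4-bit values; shifting by the offsets in wf
# and masking with (2 ** bits) - 1 recovers them in AWQ's interleaved order.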
zeros = torch.bitwise_right_shift(
torch.unsqueeze(module.qzeros, 2).expand(-1, -1, 32 // bits),
wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
zeros = torch.bitwise_and(zeros, (2 ** bits) - 1)

g_id_map = None

zeros = zeros.reshape(scales.shape)

weight = torch.bitwise_right_shift(
torch.unsqueeze(module.qweight, 2).expand(-1, -1, 32 // bits),
wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
weight = torch.bitwise_and(weight, (2 ** bits) - 1)
weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2])

# convert weight to ggml format
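# (each Q4_1 block stores element j in the low nibble and element j + Q4_1 // 2
# in the high nibble of byte j)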
weight = weight.reshape(weight.shape[0]//group_size, group_size, weight.shape[1])
weight = weight.permute(2, 0, 1).reshape(weight.shape[2], -1, 2, Q4_1//2)
weight = weight.transpose(2, 3)
weight = torch.bitwise_left_shift(weight,
torch.tensor([0, 4], dtype=torch.int8).reshape(1, 1, 1, 2))
weight = torch.bitwise_or(weight[:, :, :, 0], weight[:, :, :, 1]).contiguous()

# convert zeros to ggml format
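# (each AWQ group of group_size weights covers group_size // Q4_1 blocks, so
# every zero-point, and every scale below, is broadcast across those blocks)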
zeros = zeros.reshape(-1, 1, zeros.shape[1]).permute(2, 0, 1)\
.unsqueeze(2)\
.expand(-1, -1, group_size//Q4_1, -1)\
.reshape(zeros.shape[1], -1, 1)\
.contiguous().to(torch.float16)

# convert scales to ggml format
scales = scales.reshape(-1, 1, scales.shape[1]).permute(2, 0, 1)\
.unsqueeze(2)\
.expand(-1, -1, group_size//Q4_1, -1)\
.reshape(scales.shape[-1], -1, 1)\
.contiguous().to(torch.float16)

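# GGML asym_int4 dequantizes as w = d * q + m while AWQ dequantizes as
# w = (q - z) * s, so the block parameters are d = s and m = -z * s.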
m = -(zeros * scales)
d = scales

ggml_weight = torch.cat([d.view(torch.uint8),
m.view(torch.uint8),
weight.view(torch.uint8)], dim=-1)
ggml_weight = ggml_weight.reshape([-1])

return ggml_weight, g_id_map


def convert_gptq(module, awq=False, llm_awq=False, act_order=False):
from ipex_llm.transformers.low_bit_linear import get_block_size
Q4_1 = get_block_size("asym_int4")
@@ -389,6 +454,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
FP16Linear, BF16Linear
from ipex_llm.transformers.embedding import CPUEmbedding, DiskEmbedding, LowBitEmbedding
has_been_replaced = False
global _USE_VLLM_AWQ

for name, module in model.named_children():
is_linear, linear_args = is_linear_module(module)
@@ -453,6 +519,70 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
if has_bias:
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
.to(device)
elif _USE_VLLM_AWQ:
# The user loaded an AWQ-quantized model through vLLM
from ipex_llm.transformers.low_bit_linear import vLLMLowBitLinear
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
has_bias = module.bias is not None and module.bias.abs().sum() != 0
if isinstance(module, ParallelLMHead):
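# ParallelLMHead still carries a full-precision weight (it is not AWQ-packed),
# so quantize it through the regular FP4Params path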
new_linear = LowBitLinear(
in_features,
out_features,
qtype=qtype,
bias=has_bias,
mp_group=mp_group,
enable_xetla=enable_xetla,
optimize_lm_head=False,
act_order=act_order,
enable_scale_search=enable_scale_search,
)
device = module.weight.data.device
cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
full_module_name,
imatrix_data,
model_config)
# Copy the weights
paramsLowBit = FP4Params(data=module.weight.data,
requires_grad=False,
quantized=False,
_shape=None,
convert_shape_only=convert_shape_only,
qtype=cur_qtype,
imatrix=cur_imatrix,
in_features=in_features,
enable_xetla=enable_xetla,
enable_scale_search=enable_scale_search).to(device)
else:
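# other quantized linears carry packed qweight/qzeros/scales; repack them
# into the GGML asym_int4 layout via convert_vllm_awq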
new_linear = vLLMLowBitLinear(
in_features,
out_features,
qtype=qtype,
bias=has_bias,
mp_group=mp_group,
enable_xetla=enable_xetla,
optimize_lm_head=False,
act_order=act_order,
enable_scale_search=enable_scale_search,
)
device = module.qweight.data.device
invalidInputError(device.type != "meta",
"converting from meta device is not supported")
weight, g_idx_map = convert_vllm_awq(module)
if act_order:
new_linear.g_idx_map = g_idx_map
# Copy the weights
paramsLowBit = FP4Params(data=weight,
requires_grad=False,
quantized=True,
_shape=(out_features, in_features),
convert_shape_only=convert_shape_only,
qtype=qtype,
enable_xetla=enable_xetla,
enable_scale_search=enable_scale_search).to(device)
new_linear._parameters['weight'] = paramsLowBit
if has_bias:
new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
.to(device)
elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
if in_features % 64 != 0:
# now our kernel requires in_features is a multiple of 64
@@ -1871,4 +2001,4 @@ def safe_bmm_fwd(*args, **kwargs):
minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
model.generate = MethodType(minicpmv_generate, model)

return model
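
For reference, a minimal standalone sketch (not part of the PR) verifying the sign convention that convert_vllm_awq relies on: GGML's asym_int4 dequantization d * q + m matches AWQ's (q - z) * s when d = s and m = -z * s.

import torch

q = torch.randint(0, 16, (32,), dtype=torch.int8)  # 4-bit quantized values of one block
s = torch.rand(()) * 0.1                            # AWQ per-group scale
z = torch.randint(0, 16, (), dtype=torch.int8)      # AWQ zero-point

awq_dequant = (q - z).float() * s        # AWQ: w = (q - z) * s
d, m = s, -(z.float() * s)               # convention used in convert_vllm_awq
ggml_dequant = q.float() * d + m         # GGML asym_int4 (Q4_1): w = d * q + m

assert torch.allclose(awq_dequant, ggml_dequant)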