From 56b851402bc361ebc821c307e0fd4aa8c5c2d7d8 Mon Sep 17 00:00:00 2001
From: Jun Wang
Date: Fri, 6 Sep 2024 15:06:56 +0800
Subject: [PATCH] Add vllm awq loading logic (#11987)

* [ADD] Add vllm awq loading logic

* [FIX] fix the module.linear_method path

* [FIX] fix quant_config path error
---
 .../llm/src/ipex_llm/transformers/convert.py | 134 +++++++++++++++++-
 1 file changed, 132 insertions(+), 2 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 1ab44f91b54..f94eb3f6c36 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -55,6 +55,7 @@
 
 _IS_VLLM_AVAILABLE = None
 _USE_VLLM = False
+_USE_VLLM_AWQ = False
 _VLLM_VERSION = None
 
 
@@ -143,7 +144,7 @@ def is_linear_module(module):
     is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
     if is_vllm_available():
         # Only convert vllm modules
-        global _VLLM_VERSION
+        global _VLLM_VERSION, _USE_VLLM_AWQ
         if _VLLM_VERSION is None:
             _VLLM_VERSION = get_package_version('vllm')
         from vllm.model_executor.layers.linear import (
@@ -180,6 +181,11 @@
             out_features = module.output_size
             result = True
             mp_group = None
+            # Detect vLLM AWQ quantization via the module's quant_config
+            if (not _USE_VLLM_AWQ
+                    and hasattr(module.quant_method, "quant_config")
+                    and module.quant_method.quant_config.get_name() == "awq"):
+                _USE_VLLM_AWQ = True
             invalidInputError(module.skip_bias_add is not True, "Currently, ipex-vllm does not"
                               " support linear layers with skip_bias_add argument")
             if isinstance(module, RowParallelLinear) and tp_size >= 2:
@@ -284,6 +290,65 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
     return new_linear
 
 
+def convert_vllm_awq(module):
+    from ipex_llm.transformers.low_bit_linear import get_block_size
+    Q4_1 = get_block_size("asym_int4")
+
+    scales = module.scales
+    wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7],
+                       dtype=torch.int32) * 4).unsqueeze(0)
+    # vLLM only supports loading 4-bit models, so this has already been checked
+    bits = 4
+    group_size = module.quant_method.quant_config.group_size
+
+    zeros = torch.bitwise_right_shift(
+        torch.unsqueeze(module.qzeros, 2).expand(-1, -1, 32 // bits),
+        wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
+    zeros = torch.bitwise_and(zeros, (2 ** bits) - 1)
+
+    g_id_map = None
+
+    zeros = zeros.reshape(scales.shape)
+
+    weight = torch.bitwise_right_shift(
+        torch.unsqueeze(module.qweight, 2).expand(-1, -1, 32 // bits),
+        wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
+    weight = torch.bitwise_and(weight, (2 ** bits) - 1)
+    weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2])
+
+    # convert weight to ggml format
+    weight = weight.reshape(weight.shape[0]//group_size, group_size, weight.shape[1])
+    weight = weight.permute(2, 0, 1).reshape(weight.shape[2], -1, 2, Q4_1//2)
+    weight = weight.transpose(2, 3)
+    weight = torch.bitwise_left_shift(weight,
+                                      torch.tensor([0, 4], dtype=torch.int8).reshape(1, 1, 1, 2))
+    weight = torch.bitwise_or(weight[:, :, :, 0], weight[:, :, :, 1]).contiguous()
+
+    # convert zeros to ggml format
+    zeros = zeros.reshape(-1, 1, zeros.shape[1]).permute(2, 0, 1)\
+        .unsqueeze(2)\
+        .expand(-1, -1, group_size//Q4_1, -1)\
+        .reshape(zeros.shape[1], -1, 1)\
+        .contiguous().to(torch.float16)
+
+    # convert scales to ggml format
+    scales = scales.reshape(-1, 1, scales.shape[1]).permute(2, 0, 1)\
+        .unsqueeze(2)\
+        .expand(-1, -1, group_size//Q4_1, -1)\
+        .reshape(scales.shape[-1], -1, 1)\
+        .contiguous().to(torch.float16)
+
+    m = -(zeros * scales)
+    d = scales
+
+    ggml_weight = torch.cat([d.view(torch.uint8),
+                             m.view(torch.uint8),
+                             weight.view(torch.uint8)], dim=-1)
+    ggml_weight = ggml_weight.reshape([-1])
+
+    return ggml_weight, g_id_map
+
+
 def convert_gptq(module, awq=False, llm_awq=False, act_order=False):
     from ipex_llm.transformers.low_bit_linear import get_block_size
     Q4_1 = get_block_size("asym_int4")
@@ -389,6 +454,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         FP16Linear, BF16Linear
     from ipex_llm.transformers.embedding import CPUEmbedding, DiskEmbedding, LowBitEmbedding
     has_been_replaced = False
+    global _USE_VLLM_AWQ
 
     for name, module in model.named_children():
         is_linear, linear_args = is_linear_module(module)
@@ -453,6 +519,70 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                     if has_bias:
                         new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
                             .to(device)
+                elif _USE_VLLM_AWQ:
+                    # The user is loading an AWQ-quantized model through vLLM
+                    from ipex_llm.transformers.low_bit_linear import vLLMLowBitLinear
+                    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+                    has_bias = module.bias is not None and module.bias.abs().sum() != 0
+                    if isinstance(module, ParallelLMHead):
+                        new_linear = LowBitLinear(
+                            in_features,
+                            out_features,
+                            qtype=qtype,
+                            bias=has_bias,
+                            mp_group=mp_group,
+                            enable_xetla=enable_xetla,
+                            optimize_lm_head=False,
+                            act_order=act_order,
+                            enable_scale_search=enable_scale_search,
+                        )
+                        device = module.weight.data.device
+                        cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
+                                                                           full_module_name,
+                                                                           imatrix_data,
+                                                                           model_config)
+                        # Copy the weights
+                        paramsLowBit = FP4Params(data=module.weight.data,
+                                                 requires_grad=False,
+                                                 quantized=False,
+                                                 _shape=None,
+                                                 convert_shape_only=convert_shape_only,
+                                                 qtype=cur_qtype,
+                                                 imatrix=cur_imatrix,
+                                                 in_features=in_features,
+                                                 enable_xetla=enable_xetla,
+                                                 enable_scale_search=enable_scale_search).to(device)
+                    else:
+                        new_linear = vLLMLowBitLinear(
+                            in_features,
+                            out_features,
+                            qtype=qtype,
+                            bias=has_bias,
+                            mp_group=mp_group,
+                            enable_xetla=enable_xetla,
+                            optimize_lm_head=False,
+                            act_order=act_order,
+                            enable_scale_search=enable_scale_search,
+                        )
+                        device = module.qweight.data.device
+                        invalidInputError(device.type != "meta",
+                                          "converting from meta device is not supported")
+                        weight, g_idx_map = convert_vllm_awq(module)
+                        if act_order:
+                            new_linear.g_idx_map = g_idx_map
+                        # Copy the weights
+                        paramsLowBit = FP4Params(data=weight,
+                                                 requires_grad=False,
+                                                 quantized=True,
+                                                 _shape=(out_features, in_features),
+                                                 convert_shape_only=convert_shape_only,
+                                                 qtype=qtype,
+                                                 enable_xetla=enable_xetla,
+                                                 enable_scale_search=enable_scale_search).to(device)
+                    new_linear._parameters['weight'] = paramsLowBit
+                    if has_bias:
+                        new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
+                            .to(device)
                 elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
                     if in_features % 64 != 0:
                         # now our kernel requires in_features is a multiple of 64
@@ -1871,4 +2001,4 @@ def safe_bmm_fwd(*args, **kwargs):
         minicpmv_generate = minicpmv_generate_wrapper(module.MiniCPMV.generate)
         model.generate = MethodType(minicpmv_generate, model)
 
-    return model
+    return model
\ No newline at end of file
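
For readers unfamiliar with the AWQ layout handled above, the sketch below isolates the interleaved 4-bit unpacking that convert_vllm_awq() applies to module.qweight and module.qzeros. It is an illustrative, standalone example rather than part of the patch: the helper name unpack_int4 and the toy tensor shapes are assumptions made for the demo; only the [0, 4, 1, 5, 2, 6, 3, 7] shift table is taken from the code above.

import torch

# Shift table used by convert_vllm_awq above: logical value k of each packed
# int32 is recovered by shifting right by ([0, 4, 1, 5, 2, 6, 3, 7][k] * 4).
AWQ_SHIFTS = torch.tensor([0, 4, 1, 5, 2, 6, 3, 7], dtype=torch.int32) * 4


def unpack_int4(packed):
    # packed: (rows, cols) int32 with eight 4-bit values per element.
    # returns: (rows, cols * 8) int8 with values in [0, 15].
    expanded = packed.unsqueeze(2).expand(-1, -1, 8)
    # Shift each copy so the wanted nibble lands in the low 4 bits ...
    shifted = torch.bitwise_right_shift(expanded, AWQ_SHIFTS.reshape(1, 1, 8)).to(torch.int8)
    # ... then mask away everything above those 4 bits.
    unpacked = torch.bitwise_and(shifted, 0xF)
    return unpacked.reshape(packed.shape[0], packed.shape[1] * 8)


if __name__ == "__main__":
    # Round-trip check: pack eight known nibbles with the same shift table
    # and confirm unpack_int4 recovers them in their logical order.
    values = torch.tensor([[1, 15, 0, 7, 8, 3, 12, 5]], dtype=torch.int32)
    packed = torch.zeros(1, 1, dtype=torch.int32)
    for k in range(8):
        packed |= values[:, k:k + 1] << int(AWQ_SHIFTS[k])
    assert torch.equal(unpack_int4(packed), values.to(torch.int8))
    print(unpack_int4(packed))

The round-trip check packs eight known nibbles with the same shift table and confirms they come back in logical order, which is the property the dequantization path in convert_vllm_awq relies on before repacking the values into the ggml asym_int4 layout.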