diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 9ce54536bca..21cfa7cf22b 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -181,6 +181,11 @@ def is_linear_module(module):
             out_features = module.output_size
             result = True
             mp_group = None
+            # Check whether this vLLM linear layer was quantized with AWQ
+            if (not _USE_VLLM_AWQ
+                    and hasattr(module.quant_method, "quant_config")
+                    and module.quant_method.quant_config.get_name() == "awq"):
+                _USE_VLLM_AWQ = True
             invalidInputError(module.skip_bias_add is not True, "Currently, ipex-vllm does not"
                               " support linear layers with skip_bias_add argument")
             if isinstance(module, RowParallelLinear) and tp_size >= 2:
@@ -285,6 +290,65 @@ def convert_vllm(module, qtype, in_features, out_features, mp_group, cur_qtype,
     return new_linear
 
 
+def convert_vllm_awq(module):
+    from ipex_llm.transformers.low_bit_linear import get_block_size
+    Q4_1 = get_block_size("asym_int4")
+
+    scales = module.scales
+    wf = (torch.tensor([0, 4, 1, 5, 2, 6, 3, 7],
+                       dtype=torch.int32) * 4).unsqueeze(0)
+    # vLLM only supports loading 4-bit AWQ models, so this has already been checked
+    bits = 4
+    group_size = module.quant_method.quant_config.group_size
+
+    zeros = torch.bitwise_right_shift(
+        torch.unsqueeze(module.qzeros, 2).expand(-1, -1, 32 // bits),
+        wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
+    zeros = torch.bitwise_and(zeros, (2 ** bits) - 1)
+
+    g_id_map = None
+
+    zeros = zeros.reshape(scales.shape)
+
+    weight = torch.bitwise_right_shift(
+        torch.unsqueeze(module.qweight, 2).expand(-1, -1, 32 // bits),
+        wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
+    weight = torch.bitwise_and(weight, (2 ** bits) - 1)
+    weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2])
+
+    # convert weight to ggml format
+    weight = weight.reshape(weight.shape[0]//group_size, group_size, weight.shape[1])
+    weight = weight.permute(2, 0, 1).reshape(weight.shape[2], -1, 2, Q4_1//2)
+    weight = weight.transpose(2, 3)
+    weight = torch.bitwise_left_shift(weight,
+                                      torch.tensor([0, 4], dtype=torch.int8).reshape(1, 1, 1, 2))
+    weight = torch.bitwise_or(weight[:, :, :, 0], weight[:, :, :, 1]).contiguous()
+
+    # convert zeros to ggml format
+    zeros = zeros.reshape(-1, 1, zeros.shape[1]).permute(2, 0, 1)\
+        .unsqueeze(2)\
+        .expand(-1, -1, group_size//Q4_1, -1)\
+        .reshape(zeros.shape[1], -1, 1)\
+        .contiguous().to(torch.float16)
+
+    # convert scales to ggml format
+    scales = scales.reshape(-1, 1, scales.shape[1]).permute(2, 0, 1)\
+        .unsqueeze(2)\
+        .expand(-1, -1, group_size//Q4_1, -1)\
+        .reshape(scales.shape[-1], -1, 1)\
+        .contiguous().to(torch.float16)
+
+    m = -(zeros * scales)
+    d = scales
+
+    ggml_weight = torch.cat([d.view(torch.uint8),
+                             m.view(torch.uint8),
+                             weight.view(torch.uint8)], dim=-1)
+    ggml_weight = ggml_weight.reshape([-1])
+
+    return ggml_weight, g_id_map
+
+
 def convert_gptq(module, awq=False, llm_awq=False, act_order=False):
     from ipex_llm.transformers.low_bit_linear import get_block_size
     Q4_1 = get_block_size("asym_int4")
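
The standalone sketch below is not part of the patch; it illustrates, on toy tensors, the two ideas convert_vllm_awq relies on: undoing AWQ's interleaved 4-bit packing with the wf shift table, and mapping AWQ's (scale, zero) pair onto the GGML asym_int4 pair via d = scales, m = -(zeros * scales). All names, shapes, and the group size here are illustrative and not part of ipex-llm's or vLLM's API; the storage order is derived from the wf table in the patch.

# Illustrative sketch only (not part of the patch).
import torch

# (1) AWQ packs eight 4-bit values per int32 with the nibbles interleaved;
# given the shift table wf = [0, 4, 1, 5, 2, 6, 3, 7] * 4 used in the patch,
# the implied storage order is [v0, v2, v4, v6, v1, v3, v5, v7], low nibble first.
vals = torch.arange(8, dtype=torch.int32)            # logical values v0..v7 = 0..7
awq_order = torch.tensor([0, 2, 4, 6, 1, 3, 5, 7])   # nibble position -> logical index
packed = torch.zeros((), dtype=torch.int32)
for pos, idx in enumerate(awq_order):
    packed |= vals[idx] << (4 * pos)                 # pack one nibble at a time
wf = torch.tensor([0, 4, 1, 5, 2, 6, 3, 7], dtype=torch.int32) * 4
unpacked = torch.bitwise_and(torch.bitwise_right_shift(packed, wf), 0xF)
assert torch.equal(unpacked, vals)                   # wf recovers the logical order

# (2) AWQ dequantizes as scale * (q - zero); the GGML asym_int4 block stores
# (d, m) and dequantizes as d * q + m, so d = scale and m = -(zero * scale)
# reproduce the same weights (up to float16 rounding of m).
group_size = 32
q = torch.randint(0, 16, (group_size,), dtype=torch.uint8)   # unpacked 4-bit weights
zero = torch.randint(0, 16, (1,), dtype=torch.uint8)         # asymmetric zero point
scale = (torch.rand(1) * 0.1 + 0.01).to(torch.float16)       # per-group scale

w_awq = scale.float() * (q.float() - zero.float())
d = scale
m = -(zero.to(torch.float16) * scale)
w_ggml = d.float() * q.float() + m.float()
assert torch.allclose(w_awq, w_ggml, atol=1e-2)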