From 1c144ee0ee1dbb83ce2d3bf1d0951e5e7ed281ee Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Fri, 30 Aug 2024 05:32:01 -0700 Subject: [PATCH] Support Qwen2-7b mlp in int4 --- .../transformers/npu_models/convert_mp.py | 7 ++- .../transformers/npu_models/mp_models_base.py | 2 +- .../transformers/npu_models/qwen2_mp.py | 51 +++++++++++++++---- 3 files changed, 48 insertions(+), 12 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py index a1b07a8cf5ae..02c5edd3dd9d 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py @@ -42,6 +42,11 @@ def optimize_llm_pre(model: torch.nn.Module, qtype): from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq model.apply(pre_compute_inv_freq) + if model.config.model_type == "qwen2": + from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_down_proj + from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_forward + model.apply(split_mlp_down_proj) + # lm_head to cpu optimization if os.environ.get("IPEX_LLM_CPU_LM_HEAD", "1") != "0": from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8 @@ -110,8 +115,6 @@ def optimize_llm( intra_pp = 2 if inter_pp is None: inter_pp = 4 if model.config.intermediate_size == 18944 else 1 - if model.config.intermediate_size == 18944: - transpose_value_cache = False from ipex_llm.transformers.npu_models.qwen2_mp import gen_qwen2_fused_model_forward from ipex_llm.transformers.npu_models.qwen2_mp import DecodeRunner, PrefillRunner diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index 611270f8b129..784f6d450052 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -396,7 +396,7 @@ def set_weights_async(self, op_id, weights): (f"weights size does not match graph, " f"with weights size: {len(weights)} and " f" graph linear size: {len(self.linear_ops)}")) - self.setWeights(offset, op_id, *weights) + self.setWeights(offset, op_id, *weights, verify_size=True) @staticmethod def run_decoders(inputs, decoders): diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py index 60f8e2baeec8..2022f0e38226 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py @@ -42,6 +42,30 @@ from ipex_llm.transformers.npu_models.common import reshape_lm_head_input from transformers.modeling_outputs import CausalLMOutputWithPast from torch.nn import CrossEntropyLoss +from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP + + +def split_mlp_down_proj(module: torch.nn.Module): + if isinstance(module, Qwen2MLP) and module.down_proj.in_features == 18944: + new_linear_0 = torch.nn.Linear(0, 0, bias=False) + new_weight_0 = torch.nn.Parameter(module.down_proj.weight[:, :9472], requires_grad=False) + new_linear_0.weight = new_weight_0 + new_linear_0.in_features = new_weight_0.size(1) + new_linear_0.out_features = new_weight_0.size(0) + module.down_proj_0 = new_linear_0 + new_linear_1 = torch.nn.Linear(0, 0, bias=False) + new_weight_1 = torch.nn.Parameter(module.down_proj.weight[:, 9472:], requires_grad=False) + new_linear_1.weight = new_weight_1 + new_linear_1.in_features = new_weight_1.size(1) + new_linear_1.out_features = new_weight_1.size(0) + module.down_proj_1 = new_linear_1 + + del module.down_proj + + +def split_mlp_forward(self, x): + h = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + return self.down_proj_0(h[:, :, :9472]) + self.down_proj_1(h[:, :, 9472:]) class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory): @@ -199,7 +223,7 @@ def __init__( self.compile() - def mlp(self, hidden_states): + def mlp(self, hidden_states, seq_len): mm1 = self.linear( hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype ) @@ -209,9 +233,11 @@ def mlp(self, hidden_states): mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined] if self.intermediate_size == 18944: # for qwen2-7b - hidden_states = self.linear( - mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=np.int8 - ) + mm1_0 = self.slice(mm1, begin=[0, 0, 0], end=[1, seq_len, 9472]) + mm1_1 = self.slice(mm1, begin=[0, 0, 9472], end=[1, seq_len, 18944]) + hidden_states_0 = self.linear(mm1_0, self.hidden_size, 9472, bias=False, wt_dtype=self.dtype) + hidden_states_1 = self.linear(mm1_1, self.hidden_size, 9472, bias=False, wt_dtype=self.dtype) + hidden_states = hidden_states_0 + hidden_states_1 else: hidden_states = self.linear( mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype @@ -255,7 +281,7 @@ def build_decoder( hidden_states = self.eltwise_add(residual, attn_output) residual = hidden_states hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight) - hidden_states = self.mlp(hidden_states) + hidden_states = self.mlp(hidden_states, self.seq_len) hidden_states = self.eltwise_add(residual, hidden_states) hidden_states = self.convert_to_fp16(hidden_states) @@ -341,9 +367,13 @@ def __init__( ) self.backend_decoders.append(decoder) + offset = 0 for i in range(intra_stages): start, end = self.layer_ranges[i] - self.backend_decoders[i].set_weights(self.op_id, op_parameters[start * 7:end * 7]) + curr_linear_ops = len(self.backend_decoders[i].linear_ops) + curr_parameters = self.op_parameters[offset:offset + curr_linear_ops] + self.backend_decoders[i].set_weights(self.op_id, curr_parameters) + offset = offset + curr_linear_ops def forward( self, @@ -541,7 +571,8 @@ def run_decode( (attn_layer.o_proj.weight, attn_layer.o_proj.scale), (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale), (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale), - (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale), + (mlp_layer.down_proj_0.weight, mlp_layer.down_proj_0.scale), + (mlp_layer.down_proj_1.weight, mlp_layer.down_proj_1.scale) ] cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) @@ -812,6 +843,8 @@ def run_prefill( transpose_value=transpose_value_cache ) convert_forward(model, Qwen2Attention, qwen2_attention_forward) + from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP + convert_forward(model, Qwen2MLP, split_mlp_forward) deocderlayers = model.model.layers while True: @@ -834,7 +867,6 @@ def run_prefill( hidden_states = layer_outputs[0] next_decoder_cache = layer_outputs[1] - result_queue.put((hidden_states, next_decoder_cache)) @@ -1120,10 +1152,11 @@ def qwen2_attention_forward( cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - cache_kwargs = {"max_seq_len": max_seq_len, "transpose": transpose_value, } if past_key_value is not None: + if transpose_value: + value_states = value_states.transpose(-1, -2) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)