diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py index 5dac6c5a871..0c70bf635b0 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py @@ -65,6 +65,11 @@ def optimize_llm_pre(model: torch.nn.Module, qtype): model.llm.config.model_type = "llama" model = model.llm + if model.config.model_type == "qwen2": + from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_down_proj + from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_forward + model.apply(split_mlp_down_proj) + # lm_head to cpu optimization if cpu_lm_head: # disable the optimization by default @@ -134,8 +139,6 @@ def optimize_llm( intra_pp = 2 if inter_pp is None: inter_pp = 4 if model.config.intermediate_size == 18944 else 1 - if model.config.intermediate_size == 18944: - transpose_value_cache = False from ipex_llm.transformers.npu_models.qwen2_mp import gen_qwen2_fused_model_forward from ipex_llm.transformers.npu_models.qwen2_mp import DecodeRunner, PrefillRunner diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py index 61bff6e76a4..30e9054d8e4 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py @@ -42,6 +42,30 @@ from ipex_llm.transformers.npu_models.common import reshape_lm_head_input from transformers.modeling_outputs import CausalLMOutputWithPast from torch.nn import CrossEntropyLoss +from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP + + +def split_mlp_down_proj(module: torch.nn.Module): + if isinstance(module, Qwen2MLP) and module.down_proj.in_features == 18944: + new_linear_0 = torch.nn.Linear(0, 0, bias=False) + new_weight_0 = torch.nn.Parameter(module.down_proj.weight[:, :9472], requires_grad=False) + new_linear_0.weight = new_weight_0 + new_linear_0.in_features = new_weight_0.size(1) + new_linear_0.out_features = new_weight_0.size(0) + module.down_proj_0 = new_linear_0 + new_linear_1 = torch.nn.Linear(0, 0, bias=False) + new_weight_1 = torch.nn.Parameter(module.down_proj.weight[:, 9472:], requires_grad=False) + new_linear_1.weight = new_weight_1 + new_linear_1.in_features = new_weight_1.size(1) + new_linear_1.out_features = new_weight_1.size(0) + module.down_proj_1 = new_linear_1 + + del module.down_proj + + +def split_mlp_forward(self, x): + h = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + return self.down_proj_0(h[:, :, :9472]) + self.down_proj_1(h[:, :, 9472:]) class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory): @@ -201,7 +225,7 @@ def __init__( self.compile() print("end compiling") - def mlp(self, hidden_states): + def mlp(self, hidden_states, seq_len): mm1 = self.linear( hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype ) @@ -211,9 +235,13 @@ def mlp(self, hidden_states): mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined] if self.intermediate_size == 18944: # for qwen2-7b - hidden_states = self.linear( - mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=np.int8 - ) + mm1_0 = self.slice(mm1, begin=[0, 0, 0], end=[1, seq_len, 9472]) + mm1_1 = self.slice(mm1, begin=[0, 0, 9472], end=[1, seq_len, 18944]) + hidden_states_0 = self.linear(mm1_0, self.hidden_size, 9472, + bias=False, wt_dtype=self.dtype) + hidden_states_1 = self.linear(mm1_1, self.hidden_size, 9472, + bias=False, wt_dtype=self.dtype) + hidden_states = hidden_states_0 + hidden_states_1 else: hidden_states = self.linear( mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype @@ -257,7 +285,7 @@ def build_decoder( hidden_states = self.eltwise_add(residual, attn_output) residual = hidden_states hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight) - hidden_states = self.mlp(hidden_states) + hidden_states = self.mlp(hidden_states, self.seq_len) hidden_states = self.eltwise_add(residual, hidden_states) hidden_states = self.convert_to_fp16(hidden_states) @@ -343,9 +371,13 @@ def __init__( ) self.backend_decoders.append(decoder) + offset = 0 for i in range(intra_stages): start, end = self.layer_ranges[i] - self.backend_decoders[i].set_weights(self.op_id, op_parameters[start * 7:end * 7]) + curr_linear_ops = len(self.backend_decoders[i].linear_ops) + curr_parameters = self.op_parameters[offset:offset + curr_linear_ops] + self.backend_decoders[i].set_weights(self.op_id, curr_parameters) + offset = offset + curr_linear_ops def forward( self, @@ -543,7 +575,8 @@ def run_decode( (attn_layer.o_proj.weight, attn_layer.o_proj.scale), (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale), (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale), - (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale), + (mlp_layer.down_proj_0.weight, mlp_layer.down_proj_0.scale), + (mlp_layer.down_proj_1.weight, mlp_layer.down_proj_1.scale) ] cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) @@ -814,6 +847,8 @@ def run_prefill( transpose_value=transpose_value_cache ) convert_forward(model, Qwen2Attention, qwen2_attention_forward) + from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP + convert_forward(model, Qwen2MLP, split_mlp_forward) deocderlayers = model.model.layers while True: @@ -836,7 +871,6 @@ def run_prefill( hidden_states = layer_outputs[0] next_decoder_cache = layer_outputs[1] - result_queue.put((hidden_states, next_decoder_cache)) @@ -1124,10 +1158,11 @@ def qwen2_attention_forward( cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - cache_kwargs = {"max_seq_len": max_seq_len, "transpose": transpose_value, } if past_key_value is not None: + if transpose_value: + value_states = value_states.transpose(-1, -2) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)