diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index c0d94b6e030..8f0048ed4b9 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1323,6 +1323,7 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.qwen2 import qwen2_model_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_causal_lm_forward
+        from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward
         convert_forward(model,
                         module.Qwen2Model,
                         qwen2_model_forward)
@@ -1334,7 +1335,7 @@ def _optimize_post(model, lightweight_bmm=False):
                         llama_rms_norm_forward)
         convert_forward(model,
                         module.Qwen2MLP,
-                        llama_mlp_forward)
+                        qwen2_mlp_forward)
         convert_forward(model,
                         module.Qwen2Attention,
                         qwen2_attention_forward)
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index 2bd9626e8ab..90de62ab6ca 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -45,6 +45,7 @@
 from torch.nn import CrossEntropyLoss
 from torch.nn.functional import scaled_dot_product_attention as sdpa
 
+from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
@@ -491,3 +492,26 @@ def qwen2_attention_forward(
     if not output_attentions:
         attn_weights = None
     return attn_output, attn_weights, past_key_value
+
+
+def qwen2_mlp_forward(
+    self,
+    x: torch.Tensor,
+) -> torch.Tensor:
+    x_2d = x.view(-1, x.shape[-1])
+    qtype = getattr(self.gate_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
+        import xe_linear
+        return self.down_proj(xe_linear.mlp_forward_xpu(
+            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
+            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
+            SILU, qtype
+        ))
+    elif not self.training:
+        import xe_addons
+        gate = self.gate_proj(x)
+        up = self.up_proj(x)
+        xe_addons.mlp_silu_mul_inplaced(gate, up)
+        return self.down_proj(gate)
+    else:
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
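
Not part of the diff above, just a reference sketch for reviewers: the two fast paths in the new `qwen2_mlp_forward` (the fused `xe_linear.mlp_forward_xpu` call and the in-place `xe_addons.mlp_silu_mul_inplaced` path) are expected to reproduce the math of the plain `else` branch, `down_proj(SiLU(gate_proj(x)) * up_proj(x))`. The sketch below spells that reference computation out in stock PyTorch; the `ReferenceQwen2MLP` name and the toy sizes are made up for illustration, and it runs without the Intel XPU extensions.

```python
import torch
import torch.nn as nn


class ReferenceQwen2MLP(nn.Module):
    """Plain-PyTorch stand-in for the gated MLP optimized in this PR."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same math as the `else` branch of qwen2_mlp_forward:
        # down_proj(SiLU(gate_proj(x)) * up_proj(x))
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


# Quick shape check with toy sizes (not the real Qwen2 dimensions).
mlp = ReferenceQwen2MLP(hidden_size=64, intermediate_size=128)
out = mlp(torch.randn(2, 5, 64))
print(out.shape)  # torch.Size([2, 5, 64])
```

With real module weights copied in, this could also serve as a quick numerical check against the optimized path on XPU.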