From cb109021c42534b543a3ec189e69014f56828deb Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Fri, 12 Jul 2024 16:52:40 +0800
Subject: [PATCH 1/2] use mlp silu mul fusion in qwen2

---
 .../llm/src/ipex_llm/transformers/convert.py  |  3 +-
 .../src/ipex_llm/transformers/models/qwen2.py | 39 +++++++++++++++----
 2 files changed, 34 insertions(+), 8 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index c0d94b6e030..8f0048ed4b9 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1323,6 +1323,7 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.qwen2 import qwen2_model_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_causal_lm_forward
+        from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward
         convert_forward(model,
                         module.Qwen2Model,
                         qwen2_model_forward)
@@ -1334,7 +1335,7 @@ def _optimize_post(model, lightweight_bmm=False):
                         llama_rms_norm_forward)
         convert_forward(model,
                         module.Qwen2MLP,
-                        llama_mlp_forward)
+                        qwen2_mlp_forward)
         convert_forward(model,
                         module.Qwen2Attention,
                         qwen2_attention_forward)
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index 2bd9626e8ab..db48debecb4 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -45,6 +45,7 @@
 from torch.nn import CrossEntropyLoss
 from torch.nn.functional import scaled_dot_product_attention as sdpa
 
+from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
@@ -183,6 +184,14 @@ def qwen2_model_forward(
 
     hidden_states = inputs_embeds
 
+    # ipex-llm changes
+    curr_device = decoder_layer.input_layernorm.weight.device
+    if attention_mask is not None:
+        attention_mask = attention_mask.to(curr_device)
+    if position_ids is not None:
+        position_ids = position_ids.to(curr_device)
+    # ipex-llm changes end
+
     # decoder layers
     all_hidden_states = () if output_hidden_states else None
     all_self_attns = () if output_attentions else None
@@ -203,13 +212,6 @@ def qwen2_model_forward(
                 use_cache,
             )
         else:
-            # ipex-llm changes
-            curr_device = decoder_layer.input_layernorm.weight.device
-            if attention_mask is not None:
-                attention_mask = attention_mask.to(curr_device)
-            if position_ids is not None:
-                position_ids = position_ids.to(curr_device)
-            # ipex-llm changes end
             layer_outputs = decoder_layer(
                 hidden_states,
                 attention_mask=attention_mask,
                 position_ids=position_ids,
@@ -491,3 +493,26 @@ def qwen2_attention_forward(
     if not output_attentions:
         attn_weights = None
     return attn_output, attn_weights, past_key_value
+
+
+def qwen2_mlp_forward(
+    self,
+    x: torch.Tensor,
+) -> torch.Tensor:
+    x_2d = x.view(-1, x.shape[-1])
+    qtype = getattr(self.gate_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
+        import xe_linear
+        return self.down_proj(xe_linear.mlp_forward_xpu(
+            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
+            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
+            SILU, qtype
+        ))
+    elif not self.training:
+        import xe_addons
+        gate = self.gate_proj(x)
+        up = self.up_proj(x)
+        xe_addons.mlp_silu_mul_inplaced(gate, up)
+        return self.down_proj(gate)
+    else:
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

From def438fc5961e942df6899be56f135e51acf44a0 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Fri, 12 Jul 2024 16:54:03 +0800
Subject: [PATCH 2/2] fix

---
 .../llm/src/ipex_llm/transformers/models/qwen2.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index db48debecb4..90de62ab6ca 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -184,14 +184,6 @@ def qwen2_model_forward(
 
     hidden_states = inputs_embeds
 
-    # ipex-llm changes
-    curr_device = decoder_layer.input_layernorm.weight.device
-    if attention_mask is not None:
-        attention_mask = attention_mask.to(curr_device)
-    if position_ids is not None:
-        position_ids = position_ids.to(curr_device)
-    # ipex-llm changes end
-
     # decoder layers
     all_hidden_states = () if output_hidden_states else None
     all_self_attns = () if output_attentions else None
@@ -212,6 +204,13 @@ def qwen2_model_forward(
                 use_cache,
             )
         else:
+            # ipex-llm changes
+            curr_device = decoder_layer.input_layernorm.weight.device
+            if attention_mask is not None:
+                attention_mask = attention_mask.to(curr_device)
+            if position_ids is not None:
+                position_ids = position_ids.to(curr_device)
+            # ipex-llm changes end
             layer_outputs = decoder_layer(
                 hidden_states,
                 attention_mask=attention_mask,
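
For context on what the series changes: the new `qwen2_mlp_forward` keeps the standard Qwen2 MLP semantics, `down_proj(silu(gate_proj(x)) * up_proj(x))`, and only changes how the SiLU-multiply step is computed (a single fused XPU kernel via `xe_linear.mlp_forward_xpu`, or an in-place `xe_addons.mlp_silu_mul_inplaced(gate, up)` that overwrites `gate` with `silu(gate) * up`), with the patch's `else` branch as the unfused fallback. The snippet below is a minimal pure-PyTorch sketch of that reference computation, not part of the patches; the layer sizes are illustrative and the in-place kernel is emulated eagerly on CPU.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ReferenceQwen2MLP(nn.Module):
    """Unfused reference for the computation qwen2_mlp_forward accelerates."""

    def __init__(self, hidden_size=8, intermediate_size=16):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        # Same math as the patch's training/fallback branch:
        # down_proj(act_fn(gate_proj(x)) * up_proj(x)) with SiLU activation.
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


mlp = ReferenceQwen2MLP()
x = torch.randn(2, 3, 8)

# xe_addons.mlp_silu_mul_inplaced(gate, up) in the patch overwrites `gate`
# with silu(gate) * up; emulate that step here to show the two paths agree.
gate = mlp.gate_proj(x)
up = mlp.up_proj(x)
gate = F.silu(gate) * up          # result of the fused kernel, emulated eagerly
fused_like = mlp.down_proj(gate)

assert torch.allclose(fused_like, mlp(x), atol=1e-6)
```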