Commit

use mlp silu mul fusion in qwen2
MeouSker77 committed Jul 12, 2024
1 parent a945500 commit cb10902
Showing 2 changed files with 34 additions and 8 deletions.
3 changes: 2 additions & 1 deletion python/llm/src/ipex_llm/transformers/convert.py
@@ -1323,6 +1323,7 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.qwen2 import qwen2_model_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_causal_lm_forward
+        from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward
         convert_forward(model,
                         module.Qwen2Model,
                         qwen2_model_forward)
@@ -1334,7 +1335,7 @@ def _optimize_post(model, lightweight_bmm=False):
                         llama_rms_norm_forward)
         convert_forward(model,
                         module.Qwen2MLP,
-                        llama_mlp_forward)
+                        qwen2_mlp_forward)
         convert_forward(model,
                         module.Qwen2Attention,
                         qwen2_attention_forward)
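
With this commit, _optimize_post registers the new qwen2_mlp_forward for Qwen2MLP modules instead of reusing the generic llama_mlp_forward. As a rough illustration only (not the actual ipex-llm implementation), a convert_forward-style helper can be thought of as rebinding the forward method of every matching submodule:

import types

def convert_forward_sketch(model, target_class, new_forward):
    # Hypothetical sketch of a convert_forward-style helper; the real
    # ipex_llm.transformers.convert.convert_forward may differ in details.
    for module in model.modules():
        if isinstance(module, target_class):
            # Rebind the replacement function as this module's forward method,
            # so subsequent calls dispatch to the optimized implementation.
            module.forward = types.MethodType(new_forward, module)
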
39 changes: 32 additions & 7 deletions python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -45,6 +45,7 @@
 from torch.nn import CrossEntropyLoss
 from torch.nn.functional import scaled_dot_product_attention as sdpa
 
+from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
@@ -183,6 +184,14 @@ def qwen2_model_forward(
 
     hidden_states = inputs_embeds
 
+    # ipex-llm changes
+    curr_device = decoder_layer.input_layernorm.weight.device
+    if attention_mask is not None:
+        attention_mask = attention_mask.to(curr_device)
+    if position_ids is not None:
+        position_ids = position_ids.to(curr_device)
+    # ipex-llm changes end
+
     # decoder layers
     all_hidden_states = () if output_hidden_states else None
     all_self_attns = () if output_attentions else None
@@ -203,13 +212,6 @@ def qwen2_model_forward(
                     use_cache,
                 )
             else:
-                # ipex-llm changes
-                curr_device = decoder_layer.input_layernorm.weight.device
-                if attention_mask is not None:
-                    attention_mask = attention_mask.to(curr_device)
-                if position_ids is not None:
-                    position_ids = position_ids.to(curr_device)
-                # ipex-llm changes end
                 layer_outputs = decoder_layer(
                     hidden_states,
                     attention_mask=attention_mask,
@@ -491,3 +493,26 @@ def qwen2_attention_forward(
     if not output_attentions:
         attn_weights = None
     return attn_output, attn_weights, past_key_value
+
+
+def qwen2_mlp_forward(
+    self,
+    x: torch.Tensor,
+) -> torch.Tensor:
+    x_2d = x.view(-1, x.shape[-1])
+    qtype = getattr(self.gate_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
+        import xe_linear
+        return self.down_proj(xe_linear.mlp_forward_xpu(
+            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
+            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
+            SILU, qtype
+        ))
+    elif not self.training:
+        import xe_addons
+        gate = self.gate_proj(x)
+        up = self.up_proj(x)
+        xe_addons.mlp_silu_mul_inplaced(gate, up)
+        return self.down_proj(gate)
+    else:
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
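
For reference, all three branches of qwen2_mlp_forward compute the same SwiGLU-style MLP, down_proj(SiLU(gate_proj(x)) * up_proj(x)). The XPU-only xe_linear.mlp_forward_xpu and xe_addons.mlp_silu_mul_inplaced paths fuse the SiLU activation with the elementwise multiply; the in-place variant writes the result back into the gate tensor, which is why the code then returns self.down_proj(gate). This avoids an extra intermediate tensor and kernel launch per MLP call. Below is a plain PyTorch sketch of the unfused math, for illustration only; the fused kernels are Intel XPU extensions whose behavior is inferred here from the eager fallback branch above.

import torch.nn.functional as F

def swiglu_mlp_reference(x, gate_proj, up_proj, down_proj):
    # Unfused reference computation that the fused kernels are expected to match.
    gate = gate_proj(x)   # (..., intermediate_size)
    up = up_proj(x)       # (..., intermediate_size)
    # SiLU(gate) * up, then project back to hidden_size
    return down_proj(F.silu(gate) * up)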
