Skip to content

Commit

Permalink
use mlp silu_mul fusion in qwen2 to optimize memory usage (#11574)
Browse files Browse the repository at this point in the history
  • Loading branch information
MeouSker77 authored Jul 13, 2024
1 parent 13a72dc commit 019da6c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
3 changes: 2 additions & 1 deletion python/llm/src/ipex_llm/transformers/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,7 @@ def _optimize_post(model, lightweight_bmm=False):
from ipex_llm.transformers.models.qwen2 import qwen2_model_forward
from ipex_llm.transformers.models.qwen2 import qwen2_attention_forward
from ipex_llm.transformers.models.qwen2 import qwen2_causal_lm_forward
from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward
convert_forward(model,
module.Qwen2Model,
qwen2_model_forward)
Expand All @@ -1334,7 +1335,7 @@ def _optimize_post(model, lightweight_bmm=False):
llama_rms_norm_forward)
convert_forward(model,
module.Qwen2MLP,
llama_mlp_forward)
qwen2_mlp_forward)
convert_forward(model,
module.Qwen2Attention,
qwen2_attention_forward)
Expand Down
24 changes: 24 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from torch.nn import CrossEntropyLoss
from torch.nn.functional import scaled_dot_product_attention as sdpa

from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
from ipex_llm.transformers.models.utils import should_use_fuse_rope
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
Expand Down Expand Up @@ -491,3 +492,26 @@ def qwen2_attention_forward(
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value


def qwen2_mlp_forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
x_2d = x.view(-1, x.shape[-1])
qtype = getattr(self.gate_proj, "qtype", None)
if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
import xe_linear
return self.down_proj(xe_linear.mlp_forward_xpu(
x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
SILU, qtype
))
elif not self.training:
import xe_addons
gate = self.gate_proj(x)
up = self.up_proj(x)
xe_addons.mlp_silu_mul_inplaced(gate, up)
return self.down_proj(gate)
else:
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

0 comments on commit 019da6c

Please sign in to comment.