Add chatglm2&3 fuse mlp (#12328)
* add chatglm fuse mlp
leonardozcm authored Nov 4, 2024
1 parent 94c4ce3 commit 1b637e4
Showing 2 changed files with 53 additions and 1 deletion.
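
For context: the fused MLP added in this commit is only reached when a ChatGLM2/3 checkpoint is loaded through ipex-llm's low-bit loading path, which runs the _optimize_pre and _optimize_post passes patched below. A minimal usage sketch, not part of this commit, assuming a "THUDM/chatglm3-6b" checkpoint (padded_vocab_size == 65024), an Intel XPU device, and an ipex-llm[xpu] install:

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModel

model_path = "THUDM/chatglm3-6b"   # padded_vocab_size == 65024, so split_mlp applies

# load_in_4bit routes the model through ipex-llm's optimization passes,
# including the _optimize_pre / _optimize_post hooks changed in this commit
model = AutoModel.from_pretrained(model_path,
                                  load_in_4bit=True,
                                  trust_remote_code=True,
                                  use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

model = model.to("xpu")            # the fused xe_linear path only runs on XPU
with torch.inference_mode():
    inputs = tokenizer("What is AI?", return_tensors="pt").to("xpu")
    out = model.generate(inputs.input_ids, max_new_tokens=32)
    print(tokenizer.decode(out[0], skip_special_tokens=True))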
8 changes: 8 additions & 0 deletions python/llm/src/ipex_llm/transformers/convert.py
@@ -1049,6 +1049,12 @@ def _optimize_pre(model, qtype=None):
         model.llm.config.model_type = "llama"
         _optimize_pre(model.llm, qtype=qtype)
         model.llm.config.model_type = "minicpmv"
+    if model.config.architectures is not None \
+            and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
+        from ipex_llm.transformers.models.chatglm2 import split_mlp
+        if hasattr(model.config, 'padded_vocab_size') and \
+                model.config.padded_vocab_size == 65024:
+            model.apply(split_mlp)

     return model

@@ -1372,6 +1378,7 @@ def _optimize_post(model, lightweight_bmm=False):
             from ipex_llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
             from ipex_llm.transformers.models.chatglm2 import chatglm2_encoder_forward
             from ipex_llm.transformers.models.chatglm2 import chatglm2_model_forward
+            from ipex_llm.transformers.models.chatglm2 import mlp_forward
             convert_forward(model,
                             module.SelfAttention,
                             chatglm2_attention_forward)
@@ -1384,6 +1391,7 @@
             convert_forward(model,
                             module.RMSNorm,
                             chatglm_rms_norm_forward)
+            convert_forward(model, module.MLP, mlp_forward)
         elif hasattr(model.config, 'padded_vocab_size') and \
                 model.config.padded_vocab_size == 64896:
             # codegeex-nano
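
An illustrative sketch (not from the repo) of the mechanism behind model.apply(split_mlp): torch.nn.Module.apply visits every submodule recursively, so a single call rewrites each transformer block's MLP before the weights are quantized; split_mlp itself matches modules by class name. Toy example with a stand-in MLP class:

import torch

class MLP(torch.nn.Module):              # stand-in for ChatGLM's MLP block
    def __init__(self):
        super().__init__()
        self.dense_h_to_4h = torch.nn.Linear(8, 32, bias=False)

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = MLP()

model = torch.nn.Sequential(Block(), Block())

visited = []

def visit(module):
    # split_mlp selects modules the same way: by class name
    if module.__class__.__name__ == "MLP":
        visited.append(module)

model.apply(visit)                       # apply() recurses into every submodule
print(len(visited))                      # -> 2, one per block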
46 changes: 45 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -26,6 +26,7 @@
 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
+from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
     use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache
@@ -91,7 +92,7 @@ def chatglm2_model_forward(

     if use_cache:
         use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[1])
-        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
+        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
                                                 input_ids)
         if use_compress_kv and not isinstance(past_key_values,
                                               DynamicCompressCache):
@@ -570,3 +571,46 @@ def codegeex_attention_forward(
     output = self.dense(context_layer)

     return output, past_key_value
+
+import torch.nn.functional as F
+
+
+def split_mlp(module: torch.nn.Module):
+    if module.__class__.__name__ == "MLP":
+        gate_weight, up_weight = module.dense_h_to_4h.weight.data.chunk(2, dim=0)
+
+        gate_proj = torch.nn.Linear(0, 0, bias=False)
+        gate_proj.weight = torch.nn.Parameter(gate_weight, requires_grad=False)
+        gate_proj.in_features = gate_weight.size(1)
+        gate_proj.out_features = gate_weight.size(0)
+
+        up_proj = torch.nn.Linear(0, 0, bias=False)
+        up_proj.weight = torch.nn.Parameter(up_weight, requires_grad=False)
+        up_proj.in_features = up_weight.size(1)
+        up_proj.out_features = up_weight.size(0)
+
+        module.gate_proj = gate_proj
+        module.up_proj = up_proj
+
+        module.activation_fn = F.silu
+
+        del module.dense_h_to_4h
+
+
+def mlp_forward(
+    self,
+    hidden_states: torch.FloatTensor
+) -> torch.FloatTensor:
+    x_2d = hidden_states.view(-1, hidden_states.shape[-1])
+    qtype = getattr(self.gate_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training):
+        x_2d = x_2d.contiguous()
+        import xe_linear
+        return self.dense_4h_to_h(xe_linear.mlp_forward_xpu(
+            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
+            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_features,
+            SILU, qtype
+        ))
+    return self.dense_4h_to_h(
+        self.activation_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)
+    )
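
Why the weight split is safe (a sketch, not part of the commit): ChatGLM2/3's original MLP projects to 4h with the fused dense_h_to_4h and then applies a SwiGLU-style activation, silu on the first half of that projection times the second half (assumed here from the upstream modeling code). Chunking the fused weight along dim 0, as split_mlp does, yields gate/up projections that reproduce the same result, which is exactly what the non-fused branch of mlp_forward computes:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, inner = 8, 16
x = torch.randn(2, 3, hidden)

dense_h_to_4h = torch.nn.Linear(hidden, 2 * inner, bias=False)   # fused gate+up
dense_4h_to_h = torch.nn.Linear(inner, hidden, bias=False)

# original fused path: project to 4h, then silu(first half) * second half
y = dense_h_to_4h(x)
g, u = y.chunk(2, dim=-1)
ref = dense_4h_to_h(F.silu(g) * u)

# split path, mirroring split_mlp + the non-fused branch of mlp_forward
gate_w, up_w = dense_h_to_4h.weight.data.chunk(2, dim=0)
out = dense_4h_to_h(F.silu(x @ gate_w.t()) * (x @ up_w.t()))

print(torch.allclose(ref, out, atol=1e-6))   # True: splitting the weight is exact

On XPU, when mlp_fusion_check passes, the code above instead goes through xe_linear.mlp_forward_xpu, which computes the gate and up projections and the SiLU in a single fused call over the flattened 2-D input before dense_4h_to_h.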
