From a9e3f7f14c6e176d8721ea0b51193c947dfe9906 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Wed, 4 Dec 2024 17:14:16 +0800
Subject: [PATCH] optimize minicpm (#12496)

---
 .../llm/src/ipex_llm/transformers/convert.py  |  5 +-
 .../ipex_llm/transformers/models/minicpm.py   | 60 +++++++++++++++++++
 2 files changed, 64 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index f6b159c32d1..e3674c6c09f 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1032,8 +1032,9 @@ def _optimize_pre(model, qtype=None):
         from ipex_llm.transformers.models.mllama import merge_qkv
         model.apply(merge_qkv)
     elif model.config.model_type == "minicpm":
-        from ipex_llm.transformers.models.minicpm import merge_qkv
+        from ipex_llm.transformers.models.minicpm import merge_qkv, apply_residual_scale
         model.apply(merge_qkv)
+        model.apply(apply_residual_scale)
     elif model.config.model_type == "minicpm3":
         from ipex_llm.transformers.models.minicpm3 import pre_compute_inv_freq
         model.apply(pre_compute_inv_freq)
@@ -2101,9 +2102,11 @@ def safe_bmm_fwd(*args, **kwargs):
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.minicpm import minicpm_attention_forward
         from ipex_llm.transformers.models.minicpm import minicpm_model_forward_wrapper
+        from ipex_llm.transformers.models.minicpm import minicpm_decoder_layer_forward
         convert_forward(model, module.MiniCPMAttention, minicpm_attention_forward)
         convert_forward(model, module.MiniCPMMLP, llama_mlp_forward)
         convert_forward(model, module.MiniCPMRMSNorm, llama_rms_norm_forward)
+        convert_forward(model, module.MiniCPMDecoderLayer, minicpm_decoder_layer_forward)
         minicpm_model_forward = minicpm_model_forward_wrapper(module.MiniCPMModel.forward)
         convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
     elif model.config.model_type == "minicpm3":
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpm.py b/python/llm/src/ipex_llm/transformers/models/minicpm.py
index d248c507773..6e2cab0f741 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpm.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpm.py
@@ -56,6 +56,17 @@ def merge_qkv(module: torch.nn.Module):
     return merge_qkv_base(module, "MiniCPMAttention")
 
 
+def apply_residual_scale(module: torch.nn.Module):
+    if module.__class__.__name__ == "MiniCPMDecoderLayer":
+        scale = module.scale_depth / math.sqrt(module.num_hidden_layers)
+        module.self_attn.o_proj.weight.data *= scale
+        if module.self_attn.o_proj.bias is not None:
+            module.self_attn.o_proj.bias.data *= scale
+        module.mlp.down_proj.weight.data *= scale
+        if module.mlp.down_proj.bias is not None:
+            module.mlp.down_proj.bias.data *= scale
+
+
 def minicpm_attention_forward(
     self,
     hidden_states: torch.Tensor,
@@ -214,3 +225,52 @@ def minicpm_model_forward(
         )
 
     return minicpm_model_forward
+
+
+def minicpm_decoder_layer_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+    **kwargs,
+) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    residual = hidden_states
+    hidden_states = self.input_layernorm(hidden_states)
+
+    # Self Attention
+    hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states=hidden_states,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_value=past_key_value,
+        output_attentions=output_attentions,
+        use_cache=use_cache,
+        **kwargs,
+    )
+
+    # ipex-llm changes start
+    hidden_states = residual + hidden_states
+    # ipex-llm changes end
+
+    # Fully Connected
+    residual = hidden_states
+    hidden_states = self.post_attention_layernorm(hidden_states)
+
+    hidden_states = self.mlp(hidden_states)
+
+    # ipex-llm changes start
+    hidden_states = residual + hidden_states
+    # ipex-llm changes end
+
+    outputs = (hidden_states,)
+
+    if output_attentions:
+        outputs += (self_attn_weights,)
+
+    if use_cache:
+        outputs += (present_key_value,)
+
+    return outputs
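
Note (editor's addition, not part of the patch): the idea behind apply_residual_scale is that the stock MiniCPMDecoderLayer multiplies each residual branch by scale_depth / sqrt(num_hidden_layers) before the add; since o_proj and down_proj are linear projections, that constant can be folded into their weights once at load time, which is why the replacement minicpm_decoder_layer_forward above uses a plain residual add. The sketch below is a minimal check of that equivalence on a toy projection; the sizes and the scale_depth value are made up, and o_proj here is a stand-alone torch.nn.Linear rather than the real MiniCPM attention module.

# Illustrative equivalence check only; toy sizes, not MiniCPM weights.
import copy
import math

import torch

hidden_size, num_hidden_layers, scale_depth = 64, 40, 1.4   # assumed values
scale = scale_depth / math.sqrt(num_hidden_layers)

torch.manual_seed(0)
o_proj = torch.nn.Linear(hidden_size, hidden_size, bias=True)

# What apply_residual_scale does: fold the residual scale into the projection.
scaled_o_proj = copy.deepcopy(o_proj)
scaled_o_proj.weight.data *= scale
if scaled_o_proj.bias is not None:
    scaled_o_proj.bias.data *= scale   # bias is a Parameter, so scale bias.data

x = torch.randn(2, 8, hidden_size)     # (batch, seq, hidden)
residual = torch.randn_like(x)

stock = residual + o_proj(x) * scale   # original path: scale at every forward
folded = residual + scaled_o_proj(x)   # patched path: plain residual add

assert torch.allclose(stock, folded, atol=1e-6)

Folding the scale once at load time removes a per-token multiply from every decoder layer, which is the optimization this patch targets.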