optimize minicpm (#12496)
MeouSker77 authored Dec 4, 2024
1 parent ae9c215 commit a9e3f7f
Showing 2 changed files with 64 additions and 1 deletion.
5 changes: 4 additions & 1 deletion python/llm/src/ipex_llm/transformers/convert.py
@@ -1032,8 +1032,9 @@ def _optimize_pre(model, qtype=None):
         from ipex_llm.transformers.models.mllama import merge_qkv
         model.apply(merge_qkv)
     elif model.config.model_type == "minicpm":
-        from ipex_llm.transformers.models.minicpm import merge_qkv
+        from ipex_llm.transformers.models.minicpm import merge_qkv, apply_residual_scale
         model.apply(merge_qkv)
+        model.apply(apply_residual_scale)
     elif model.config.model_type == "minicpm3":
         from ipex_llm.transformers.models.minicpm3 import pre_compute_inv_freq
         model.apply(pre_compute_inv_freq)
@@ -2101,9 +2102,11 @@ def safe_bmm_fwd(*args, **kwargs):
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.minicpm import minicpm_attention_forward
         from ipex_llm.transformers.models.minicpm import minicpm_model_forward_wrapper
+        from ipex_llm.transformers.models.minicpm import minicpm_decoder_layer_forward
         convert_forward(model, module.MiniCPMAttention, minicpm_attention_forward)
         convert_forward(model, module.MiniCPMMLP, llama_mlp_forward)
         convert_forward(model, module.MiniCPMRMSNorm, llama_rms_norm_forward)
+        convert_forward(model, module.MiniCPMDecoderLayer, minicpm_decoder_layer_forward)
         minicpm_model_forward = minicpm_model_forward_wrapper(module.MiniCPMModel.forward)
         convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
     elif model.config.model_type == "minicpm3":
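These convert.py changes hook the new decoder-layer optimization into ipex-llm's usual two-stage flow: _optimize_pre rewrites module weights up front via model.apply(apply_residual_scale), and convert_forward later swaps the stock MiniCPMDecoderLayer forward for minicpm_decoder_layer_forward. A minimal sketch of what a convert_forward-style helper does, assuming it simply rebinds the forward method on matching submodules (the body below is an illustration, not ipex-llm's actual implementation):

    import torch

    def convert_forward(model: torch.nn.Module, target_cls: type, new_forward) -> None:
        # Walk every submodule and rebind `forward` on instances of target_cls,
        # so e.g. each MiniCPMDecoderLayer instance runs the patched forward.
        for module in model.modules():
            if isinstance(module, target_cls):
                module.forward = new_forward.__get__(module, target_cls)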
60 changes: 60 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/minicpm.py
@@ -56,6 +56,17 @@ def merge_qkv(module: torch.nn.Module):
     return merge_qkv_base(module, "MiniCPMAttention")


+def apply_residual_scale(module: torch.nn.Module):
+    if module.__class__.__name__ == "MiniCPMDecoderLayer":
+        scale = module.scale_depth / math.sqrt(module.num_hidden_layers)
+        module.self_attn.o_proj.weight.data *= scale
+        if module.self_attn.o_proj.bias is not None:
+            module.self_attn.o_proj.bias.data *= scale
+        module.mlp.down_proj.weight.data *= scale
+        if module.mlp.down_proj.bias is not None:
+            module.mlp.down_proj.bias.data *= scale
+
+
 def minicpm_attention_forward(
     self,
     hidden_states: torch.Tensor,
@@ -214,3 +225,52 @@ def minicpm_model_forward(
         )

     return minicpm_model_forward
+
+
+def minicpm_decoder_layer_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+    **kwargs,
+) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    residual = hidden_states
+    hidden_states = self.input_layernorm(hidden_states)
+
+    # Self Attention
+    hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states=hidden_states,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_value=past_key_value,
+        output_attentions=output_attentions,
+        use_cache=use_cache,
+        **kwargs,
+    )
+
+    # ipex-llm changes start
+    hidden_states = residual + hidden_states
+    # ipex-llm changes end
+
+    # Fully Connected
+    residual = hidden_states
+    hidden_states = self.post_attention_layernorm(hidden_states)
+
+    hidden_states = self.mlp(hidden_states)
+
+    # ipex-llm changes start
+    hidden_states = residual + hidden_states
+    # ipex-llm changes end
+
+    outputs = (hidden_states,)
+
+    if output_attentions:
+        outputs += (self_attn_weights,)
+
+    if use_cache:
+        outputs += (present_key_value,)
+
+    return outputs

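Taken together, the two pieces trade MiniCPM's per-layer residual scaling for a one-time weight rewrite: the stock decoder layer multiplies each block output by scale_depth / sqrt(num_hidden_layers) before the residual add, and because o_proj and down_proj are linear, apply_residual_scale folds that constant into their weights (and biases) once at conversion time, letting minicpm_decoder_layer_forward use a plain residual + hidden_states. A minimal equivalence check, using an ordinary nn.Linear as a stand-in for o_proj and made-up values for scale_depth and num_hidden_layers:

    import math
    import torch

    torch.manual_seed(0)
    scale_depth, num_hidden_layers = 1.4, 40            # illustrative values only
    scale = scale_depth / math.sqrt(num_hidden_layers)

    o_proj = torch.nn.Linear(64, 64, bias=True)         # stand-in for self_attn.o_proj
    x, residual = torch.randn(2, 64), torch.randn(2, 64)

    # stock MiniCPM-style update: scale the block output before the residual add
    expected = residual + o_proj(x) * scale

    # folded version: pre-scale the weights/bias once, then add the residual directly
    o_proj.weight.data *= scale
    o_proj.bias.data *= scale
    actual = residual + o_proj(x)

    assert torch.allclose(expected, actual, atol=1e-6)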