refactor mistral and phi3 (#12605)
MeouSker77 authored Dec 24, 2024
1 parent 45f8f72 commit 073f936
Showing 5 changed files with 96 additions and 1,364 deletions.
55 changes: 16 additions & 39 deletions python/llm/src/ipex_llm/transformers/convert.py
@@ -1031,6 +1031,9 @@ def _optimize_pre(model, qtype=None):
     elif model.config.model_type == "mllama":
         from ipex_llm.transformers.models.mllama import merge_qkv
         model.apply(merge_qkv)
+    elif model.config.model_type == "mistral":
+        from ipex_llm.transformers.models.mistral import merge_qkv
+        model.apply(merge_qkv)
     elif model.config.model_type == "minicpm":
         from ipex_llm.transformers.models.minicpm import merge_qkv, apply_residual_scale
         model.apply(merge_qkv)
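
For context on the new "mistral" branch above: a merge_qkv-style pre-optimization typically fuses the separate q/k/v projection layers of each attention block into one linear layer so a single GEMM produces all three tensors. The sketch below illustrates that pattern only; it is not the actual ipex_llm helper, and the fused attribute name qkv_proj and the class-name check are assumptions.

import torch

def merge_qkv_sketch(module: torch.nn.Module):
    # Only rewrite Mistral attention blocks; model.apply() visits every submodule.
    if module.__class__.__name__ == "MistralAttention":
        q, k, v = module.q_proj, module.k_proj, module.v_proj
        fused = torch.nn.Linear(
            q.in_features,
            q.out_features + k.out_features + v.out_features,
            bias=q.bias is not None,
        )
        # Stack the three weight matrices row-wise so one matmul yields [q | k | v].
        fused.weight.data = torch.cat([q.weight.data, k.weight.data, v.weight.data], dim=0)
        if q.bias is not None:
            fused.bias.data = torch.cat([q.bias.data, k.bias.data, v.bias.data], dim=0)
        module.qkv_proj = fused  # hypothetical name for the fused projection
        del module.q_proj, module.k_proj, module.v_proj

# Usage mirrors the diff: model.apply(merge_qkv_sketch)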
@@ -1901,43 +1904,17 @@ def _optimize_post(model, lightweight_bmm=False):
         else:
             modeling_module_name = model.__class__.__module__
             module = importlib.import_module(modeling_module_name)
-            if version.parse(trans_version) >= version.parse("4.36.0"):
-                from ipex_llm.transformers.models.mistral import mistral_model_forward_4_36
-                if version.parse(trans_version) >= version.parse("4.39.0"):
-                    from ipex_llm.transformers.models.mistral import \
-                        mistral_attention_forward_4_39
-                    convert_forward(model,
-                                    module.MistralAttention,
-                                    mistral_attention_forward_4_39
-                                    )
-                else:
-                    from ipex_llm.transformers.models.mistral import mistral_attention_forward_4_36
-                    convert_forward(model,
-                                    module.MistralAttention,
-                                    mistral_attention_forward_4_36
-                                    )
-                convert_forward(model,
-                                module.MistralModel,
-                                mistral_model_forward_4_36
-                                )
-                convert_forward(model,
-                                module.MistralRMSNorm,
-                                llama_rms_norm_forward)
-                convert_forward(model,
-                                module.MistralMLP,
-                                llama_mlp_forward)
-            else:
-                from ipex_llm.transformers.models.mistral import mistral_attention_forward
-                convert_forward(model,
-                                module.MistralAttention,
-                                mistral_attention_forward
-                                )
-                convert_forward(model,
-                                module.MistralRMSNorm,
-                                llama_rms_norm_forward)
-                convert_forward(model,
-                                module.MistralMLP,
-                                llama_mlp_forward)
+
+            from ipex_llm.transformers.models.mistral import mistral_model_forward
+            from ipex_llm.transformers.models.mistral import mistral_attention_forward
+            from ipex_llm.transformers.models.common import rms_norm_forward
+            from ipex_llm.transformers.models.common import mlp_silu_forward
+
+            convert_forward(model, module.MistralModel, mistral_model_forward)
+            convert_forward(model, module.MistralAttention, mistral_attention_forward)
+            convert_forward(model, module.MistralSdpaAttention, mistral_attention_forward)
+            convert_forward(model, module.MistralRMSNorm, rms_norm_forward)
+            convert_forward(model, module.MistralMLP, mlp_silu_forward)
     elif model.config.model_type == "gemma":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
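
The refactor above drops the transformers-version branching and registers one set of forwards through convert_forward. Conceptually, convert_forward binds a replacement forward function onto every instance of a target module class. The sketch below is an assumption-level illustration of that pattern rather than the exact ipex_llm implementation; it also shows why a single mistral_attention_forward can be registered for both MistralAttention and MistralSdpaAttention.

import types
import torch

def convert_forward_sketch(model: torch.nn.Module, target_cls, new_forward):
    # Walk every submodule and rebind `forward` on instances of the target class.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)

# e.g. convert_forward_sketch(model, module.MistralRMSNorm, rms_norm_forward)
# Registering the same function for MistralAttention and MistralSdpaAttention works
# because the replacement forward does not care which HF attention subclass it is bound to.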
@@ -2078,8 +2055,8 @@ def safe_bmm_fwd(*args, **kwargs):
         convert_forward(model, module.Phi3Attention, attention_forward)
         from ipex_llm.transformers.models.phi3 import mlp_forward
         convert_forward(model, module.Phi3MLP, mlp_forward)
-        from ipex_llm.transformers.models.phi3 import phi3_rms_norm_forward
-        convert_forward(model, module.Phi3RMSNorm, phi3_rms_norm_forward)
+        from ipex_llm.transformers.models.common import rms_norm_forward
+        convert_forward(model, module.Phi3RMSNorm, rms_norm_forward)
         if model.config.model_type == "phi3":
             from ipex_llm.transformers.models.phi3 import phi3_model_forward_wrapper
             model_forward = phi3_model_forward_wrapper(module.Phi3Model.forward)
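Both the Mistral and Phi-3 branches now share the common rms_norm_forward instead of per-model copies. A generic RMSNorm forward of this kind usually looks like the sketch below; the attribute names (weight, variance_epsilon) follow the usual Hugging Face convention, and the real ipex_llm version may add device- or dtype-specific fast paths.

import torch

def rms_norm_forward_sketch(self, hidden_states: torch.Tensor) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    # Compute the statistics in float32 for numerical stability.
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    # Scale by the learned weight and cast back to the input dtype.
    return self.weight * hidden_states.to(input_dtype)
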
11 changes: 8 additions & 3 deletions python/llm/src/ipex_llm/transformers/models/common.py
@@ -281,8 +281,13 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
         key = repeat_kv(key, n_heads // n_kv_heads)
         value = repeat_kv(value, n_heads // n_kv_heads)
 
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query, key, value, mask, is_causal=is_causal, scale=scale
-        )
+        if is_causal and mask is None:
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, is_causal=is_causal, scale=scale
+            )
+        else:
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, mask, scale=scale
+            )
         attn_output = attn_output.to(dtype)  # workaround ipex 2.1's bug
         return attn_output
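
The new branching exists because PyTorch's scaled_dot_product_attention rejects an explicit attn_mask combined with is_causal=True, so the causal fast path is only taken when no mask was supplied. Below is a small, self-contained usage sketch of the two call paths; the tensor shapes and mask values are made up for illustration.

import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim); 32 query heads over 8 kv heads (GQA),
# with kv heads repeated up front as the wrapper above does via repeat_kv.
q = torch.randn(1, 32, 128, 64)
k = torch.randn(1, 8, 128, 64).repeat_interleave(4, dim=1)
v = torch.randn(1, 8, 128, 64).repeat_interleave(4, dim=1)

# Prefill without padding: no mask, let SDPA build the causal mask internally.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Padded batch: pass an additive mask instead, and leave is_causal unset.
mask = torch.zeros(1, 1, 128, 128)
mask[..., 120:] = float("-inf")  # hide the last 8 (padded) key positions
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)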