diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 97184b29d66..9877fed0dad 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -1912,6 +1912,7 @@ def _optimize_post(model, lightweight_bmm=False): convert_forward(model, module.MistralModel, mistral_model_forward) convert_forward(model, module.MistralAttention, mistral_attention_forward) + convert_forward(model, module.MistralSdpaAttention, mistral_attention_forward) convert_forward(model, module.MistralRMSNorm, rms_norm_forward) convert_forward(model, module.MistralMLP, mlp_silu_forward) elif model.config.model_type == "gemma":