Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix 1383 Llama model on transformers=4.41[WIP] #11280

Merged
merged 10 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions python/llm/src/ipex_llm/transformers/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,19 +947,32 @@ def _optimize_post(model, lightweight_bmm=False):
convert_forward(model,
transformers.models.llama.modeling_llama.LlamaDecoderLayer,
llama_decoder_forward)

if version.parse(trans_version) >= version.parse("4.36.0"):
# transformers version >= 4.36.0
from ipex_llm.transformers.models.llama import llama_attention_forward_4_38
if version.parse(trans_version) >= version.parse("4.38.0"):
from ipex_llm.transformers.models.llama import llama_model_forward_4_38
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaModel,
llama_model_forward_4_38)
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaAttention,
llama_attention_forward_4_38)
if version.parse(trans_version) >= version.parse("4.41.0"):
from ipex_llm.transformers.models.llama import llama_model_forward_4_41
from ipex_llm.transformers.models.llama import llama_attention_forward_4_41
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaModel,
llama_model_forward_4_41)
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaAttention,
llama_attention_forward_4_41)
else:
from ipex_llm.transformers.models.llama import llama_model_forward_4_38
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaModel,
llama_model_forward_4_38)
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaAttention,
llama_attention_forward_4_38)
else:
from ipex_llm.transformers.models.llama import llama_model_forward_4_36
convert_forward(
Expand Down
Loading
Loading