Commit b899ac0

fix cpu
MeouSker77 committed Dec 25, 2024
1 parent 711957a commit b899ac0
Showing 1 changed file with 9 additions and 1 deletion.
python/llm/src/ipex_llm/transformers/models/llama.py (9 additions, 1 deletion)

@@ -139,6 +139,9 @@ def llama_attention_forward(
                                                         self.num_key_value_heads,
                                                         self.num_key_value_heads], dim=1)
 
+    kv_seq_len = key_states.shape[-2]
+    kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
     if query_states.device.type == "xpu":
         import xe_addons
         if position_embeddings is None:
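
For reference, the two added lines compute the total key/value length for this step: the length of the new tokens plus whatever the cache already holds, via the Hugging Face transformers Cache API (get_usable_length(new_seq_length, layer_idx), available since roughly v4.36). A minimal standalone sketch; the tensor shape below is illustrative and not taken from the commit:

    import torch
    from transformers.cache_utils import DynamicCache

    past_key_value = DynamicCache()        # empty cache, as on the first forward pass
    key_states = torch.zeros(1, 8, 5, 64)  # (batch, kv_heads, new_tokens, head_dim), illustrative
    kv_seq_len = key_states.shape[-2]      # 5 new tokens in this step
    kv_seq_len += past_key_value.get_usable_length(kv_seq_len, 0)  # + cached tokens (0 here)
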
@@ -152,7 +155,12 @@ def llama_attention_forward(
             xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
     else:
         if position_embeddings is None:
-            cos, sin = self.rotary_emb(value_states, position_ids)
+            if isinstance(getattr(self.rotary_emb, "cos_cached", None), torch.Tensor):
+                # transformers < 4.38
+                cos, sin = self.rotary_emb(value_states, kv_seq_len)
+            else:
+                # transformers >= 4.38
+                cos, sin = self.rotary_emb(value_states, position_ids)
         else:
             cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
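
On the CPU (non-XPU) path, the change restores compatibility with older transformers releases: before v4.38, LlamaRotaryEmbedding.forward took a sequence length and indexed precomputed cos_cached/sin_cached buffers, while from v4.38 on it takes explicit position_ids; the presence of a cos_cached tensor serves as the version probe. A minimal sketch of the same dispatch factored into a helper; the name get_cos_sin is hypothetical and not part of the commit:

    import torch

    def get_cos_sin(rotary_emb, value_states, position_ids, kv_seq_len):
        # transformers < 4.38: cos/sin buffers are precomputed, forward expects a length
        if isinstance(getattr(rotary_emb, "cos_cached", None), torch.Tensor):
            return rotary_emb(value_states, kv_seq_len)
        # transformers >= 4.38: forward computes cos/sin from explicit positions
        return rotary_emb(value_states, position_ids)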
