From b899ac0010650215d55b7ebf8fab07bbb3b80cc9 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Wed, 25 Dec 2024 15:30:40 +0800
Subject: [PATCH] fix cpu

---
 python/llm/src/ipex_llm/transformers/models/llama.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 446dbe62974..63d1c90716b 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -139,6 +139,9 @@ def llama_attention_forward(
                                                         self.num_key_value_heads,
                                                         self.num_key_value_heads], dim=1)
 
+    kv_seq_len = key_states.shape[-2]
+    kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
     if query_states.device.type == "xpu":
         import xe_addons
         if position_embeddings is None:
@@ -152,7 +155,12 @@ def llama_attention_forward(
             xe_addons.rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)
     else:
         if position_embeddings is None:
-            cos, sin = self.rotary_emb(value_states, position_ids)
+            if isinstance(getattr(self.rotary_emb, "cos_cached", None), torch.Tensor):
+                # transformers < 4.38
+                cos, sin = self.rotary_emb(value_states, kv_seq_len)
+            else:
+                # transformers >= 4.38
+                cos, sin = self.rotary_emb(value_states, position_ids)
         else:
             cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
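
Note (not part of the patch): the CPU branch above dispatches on the rotary-embedding API, because rotary embeddings in transformers < 4.38 keep a cos_cached buffer and are called with a sequence length, while those in >= 4.38 are called with position_ids. A minimal standalone sketch of that dispatch is shown below; the helper name compute_rope_cos_sin is hypothetical and used only for illustration.

import torch

def compute_rope_cos_sin(rotary_emb, value_states, position_ids, kv_seq_len):
    # Older rotary embeddings (transformers < 4.38) expose a `cos_cached` tensor
    # and expect a sequence length; newer ones expect position_ids directly.
    if isinstance(getattr(rotary_emb, "cos_cached", None), torch.Tensor):
        return rotary_emb(value_states, kv_seq_len)    # transformers < 4.38
    return rotary_emb(value_states, position_ids)      # transformers >= 4.38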