From aabf2242e959cc9be836e178034a070a31ad26ab Mon Sep 17 00:00:00 2001 From: plusbang Date: Tue, 24 Dec 2024 10:57:13 +0800 Subject: [PATCH] Fix NPU apply_rotary_pos_emb: flatten position_ids with reshape([-1]) instead of squeeze --- .../llm/src/ipex_llm/transformers/npu_models/mp_models_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index 8f2d25070d7..d406b3ef920 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -492,7 +492,7 @@ def rotate_half(self, x, *, num_heads, seq_len, head_dim): def apply_rotary_pos_emb(self, *, q, k, cos, sin, position_ids, num_heads, seq_len, head_dim): if position_ids is not None: - position_ids = self.squeeze(position_ids) + position_ids = self.reshape(position_ids, [-1]) cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0) sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0) cos = self.unsqueeze(cos, [1])