diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index 8f2d25070d7..39d5888e230 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -29,6 +29,7 @@ import numpy as np from typing import Optional, Any, List import numpy.typing as npt +import os logger = logging.get_logger(__name__) @@ -492,7 +493,11 @@ def rotate_half(self, x, *, num_heads, seq_len, head_dim): def apply_rotary_pos_emb(self, *, q, k, cos, sin, position_ids, num_heads, seq_len, head_dim): if position_ids is not None: - position_ids = self.squeeze(position_ids) + if os.environ.get("IPEX_LLM_NPU_MTL", "0") == "1" or\ + os.environ.get("IPEX_LLM_NPU_ARL", "0") == "1": + position_ids = self.reshape(position_ids, [-1]) + else: + position_ids = self.squeeze(position_ids) cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0) sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0) cos = self.unsqueeze(cos, [1])