diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
index 73666487333..0b008729880 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
@@ -247,7 +247,8 @@ def attention(self,
         attn_weight = self.matmul(query_states, key_states, False, True) / (
             math.sqrt(head_dim)
         )
-        attention_mask = self.convert_to_fp16(attention_mask)
+        if mode != "prefill":
+            attention_mask = self.convert_to_fp16(attention_mask)
         attn_weight = self.eltwise_add(attn_weight, attention_mask)
         attn_weight = self.convert_to_fp32(attn_weight)
         attn_weight = self.softmax(attn_weight, -1)
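
For reference, a minimal NumPy sketch of the control flow this patch introduces: the fp16 cast of the attention mask is now applied only outside the prefill path, while the mask addition and softmax stay unchanged. The function and argument names below are illustrative stand-ins, not the actual NPU graph API from `mp_models_base.py`.

```python
import math
import numpy as np


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def masked_attention_scores(query, key, attention_mask, mode, head_dim):
    """Scaled dot-product scores with mode-dependent mask casting (illustrative only)."""
    # scaled QK^T, mirroring the matmul / sqrt(head_dim) lines in the diff
    scores = (query @ key.transpose(0, 1, 3, 2)) / math.sqrt(head_dim)
    if mode != "prefill":
        # decode path: cast the mask to fp16 before adding,
        # mirroring self.convert_to_fp16(attention_mask) in the patched code
        attention_mask = attention_mask.astype(np.float16)
    scores = scores + attention_mask          # eltwise_add
    scores = scores.astype(np.float32)        # convert_to_fp32
    return softmax(scores, axis=-1)


# Example shapes: (batch, heads, seq_len, head_dim) for Q/K, broadcastable mask.
q = np.random.rand(1, 2, 4, 8).astype(np.float16)
k = np.random.rand(1, 2, 4, 8).astype(np.float16)
mask = np.zeros((1, 1, 4, 4), dtype=np.float32)
print(masked_attention_scores(q, k, mask, mode="decode", head_dim=8).shape)
```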