diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index 39d5888e230..27e4469fd99 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -472,7 +472,11 @@ def layer_norm(self, hidden_states, layernorm_weight): ) eps = self.constant(self.rms_norm_eps) hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps))) - hidden_states = self.convert_to_fp16(hidden_states) + if os.environ.get("IPEX_LLM_NPU_DRIVER_VERSION", None) in ["5716", "5733"]: + # to support special drivers + hidden_states = self.convert_to_fp16(hidden_states) + else: + layernorm_weight = self.convert_to_fp32(layernorm_weight) hidden_states = self.eltwise_mul(layernorm_weight, hidden_states) hidden_states = self.convert_to_fp16(hidden_states) return hidden_states