diff --git a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py index 94494da2e15..7979bfbc62a 100644 --- a/python/llm/src/ipex_llm/vllm/xpu/model_convert.py +++ b/python/llm/src/ipex_llm/vllm/xpu/model_convert.py @@ -248,7 +248,15 @@ def _ipex_llm_load_model(self) -> None: parallel_config=self.parallel_config, scheduler_config=self.scheduler_config) from ipex_llm import optimize_model - optimize_model(self.model, low_bit=low_bit, torch_dtype=self.model_config.dtype) + import os + not_convert_last_mlp = os.getenv("IPEX_LLM_NOT_CONVERT_LAST_MLP", None) + if not_convert_last_mlp is not None: + # workaround: skip quantizing the last MLP layers to avoid NaN values when running glm4-9b-chat + modules = ["35.mlp", "36.mlp", "37.mlp", "38.mlp", "39.mlp"] + else: + modules = None + optimize_model(self.model, low_bit=low_bit, torch_dtype=self.model_config.dtype, + modules_to_not_convert=modules) self.model = self.model.to(device=self.device_config.device, dtype=self.model_config.dtype)