diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 0b3babd9c42..65d394b5ab9 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -331,6 +331,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None, if any(key in full_module_name for key in modules_to_not_convert): continue + if is_linear and getattr(model_config, "model_type", None) == "chatglm" and name == "lm_head": + # Now we re-reference it to output_layer + model._modules[name] = model._modules["transformer"]._modules["output_layer"] + continue + if is_linear and not isinstance(module, LowBitLinear): in_features, out_features, mp_group = linear_args optimize_lm_head = False