diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index df18d597394..63487dfaf92 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -174,6 +174,7 @@ def from_pretrained(cls, *args, **kwargs):
                 intra_pp=intra_pp,
                 transpose_value_cache=transpose_value_cache,
             )
+            model.save_low_bit = types.MethodType(save_low_bit, model)
         else:
             from ipex_llm.transformers.npu_models.convert import optimize_llm
             optimize_llm(model)
@@ -209,10 +210,16 @@ def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs)
         ignore_argument(kwargs, "lightweight_bmm")
         ignore_argument(kwargs, "cpu_embedding")
         ignore_argument(kwargs, "embedding_qtype")
-        ignore_argument(kwargs, "optimize_model")
         ignore_argument(kwargs, "modules_to_not_convert")
         ignore_argument(kwargs, "speculative")
         ignore_argument(kwargs, "pipeline_parallel_stages")
+        optimize_model = kwargs.pop("optimize_model", False)
+        max_output_len = kwargs.pop("max_output_len", 1024)
+        max_prompt_len = kwargs.pop("max_prompt_len", 512)
+        inter_pp = kwargs.pop("inter_pp", None)
+        intra_pp = kwargs.pop("intra_pp", None)
+        transpose_value_cache = kwargs.pop("transpose_value_cache", True)
+        modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
 
         from transformers.models.auto.configuration_auto import AutoConfig
         from transformers.modeling_utils import no_init_weights, get_state_dict_dtype
@@ -351,12 +358,34 @@ def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs)
         logger.info(f"Converting model, it may takes up to several minutes ...")
         from intel_npu_acceleration_library.compiler import create_npu_kernels
 
-        with torch.no_grad():
-            optimize_llm(model)
-            cls.load_convert(qtype, model, quant_device, *model_args, **kwargs)
-            create_npu_kernels(model)
+        if optimize_model:
+            invalidInputError(
+                max_prompt_len < max_output_len,
+                (
+                    f"max_prompt_len ({max_prompt_len}) should be less"
+                    " than max_output_len ({max_output_len})"
+                ),
+            )
+            from ipex_llm.transformers.npu_models.convert_mp import optimize_llm_pre
+
+            if hasattr(model, "llm"):
+                llm = model.llm
+            else:
+                llm = model
+
+            with torch.no_grad():
+                optimize_llm_pre(model, qtype)
+                cls.load_convert(qtype, model, quant_device, modules_to_not_convert,
+                                 *model_args, **kwargs)
+                create_npu_kernels(llm)
 
-        model = model.eval()
+        else:
+            from ipex_llm.transformers.npu_models.convert import optimize_llm
+            optimize_llm(model)
+            with torch.no_grad():
+                cls.load_convert(qtype, model, quant_device, modules_to_not_convert,
+                                 *model_args, **kwargs)
+                create_npu_kernels(model)
 
         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
@@ -415,6 +444,17 @@ def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs)
         for param in model.parameters():
             param.requires_grad_(False)
 
+        if optimize_model:
+            from ipex_llm.transformers.npu_models.convert_mp import optimize_llm
+            optimize_llm(
+                llm,
+                max_output_len=max_output_len,
+                max_prompt_len=max_prompt_len,
+                inter_pp=inter_pp,
+                intra_pp=intra_pp,
+                transpose_value_cache=transpose_value_cache,
+            )
+
         return model
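
For context (not part of the patch): a minimal usage sketch of the API surface this diff touches, assuming the ipex_llm.transformers.npu_model.AutoModelForCausalLM wrapper that exposes from_pretrained, save_low_bit, and load_low_bit. The model id, save directory, and low-bit format below are illustrative placeholders, and the keyword values simply mirror the defaults popped in load_low_bit above.

# Usage sketch only; names marked "placeholder" are not taken from the patch.
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

save_dir = "./npu-low-bit-model"  # placeholder output directory

# First run: quantize and NPU-optimize. With optimize_model=True, from_pretrained
# now also attaches save_low_bit to the returned model (first hunk of the patch).
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",      # placeholder model id
    load_in_low_bit="sym_int4",    # placeholder low-bit format
    optimize_model=True,
    max_output_len=1024,
    max_prompt_len=512,
    transpose_value_cache=True,
)
model.save_low_bit(save_dir)

# Later runs: reload the saved low-bit weights without re-quantizing. load_low_bit
# now consumes these kwargs (previously optimize_model was ignored) and re-applies
# the NPU optimization path when optimize_model=True.
model = AutoModelForCausalLM.load_low_bit(
    save_dir,
    optimize_model=True,
    max_output_len=1024,
    max_prompt_len=512,
    inter_pp=None,
    intra_pp=None,
    transpose_value_cache=True,
)

Note that max_prompt_len must stay below max_output_len, since the patched load_low_bit validates this via invalidInputError before converting the model.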