add save & load support for NPU optimized model (#11999)
* add save & load support

* fix style
rnwang04 authored Sep 3, 2024
1 parent 6eb5565 commit 9eaff5e
Showing 1 changed file with 46 additions and 6 deletions.
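For context, a minimal usage sketch of the save & load flow this commit enables, assuming the NPU `AutoModelForCausalLM` class from `ipex_llm.transformers.npu_model`. The model id, save directory, and low-bit format are placeholders; the keyword arguments mirror the parameters added in this diff:

# Sketch only: paths and the model id are placeholders; keyword defaults follow the diff below.
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

# First run: load, quantize, and NPU-optimize the model, then persist it.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",      # placeholder model id
    load_in_low_bit="sym_int4",           # assumed low-bit format
    optimize_model=True,
    max_output_len=1024,
    max_prompt_len=512,
    transpose_value_cache=True,
)
model.save_low_bit("./npu-low-bit-model")  # save_low_bit is attached by this commit

# Later run: restore the optimized low-bit model directly from the saved folder.
model = AutoModelForCausalLM.load_low_bit(
    "./npu-low-bit-model",
    optimize_model=True,
    max_output_len=1024,
    max_prompt_len=512,
)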
python/llm/src/ipex_llm/transformers/npu_model.py (52 changes: 46 additions & 6 deletions)
@@ -174,6 +174,7 @@ def from_pretrained(cls, *args, **kwargs):
                 intra_pp=intra_pp,
                 transpose_value_cache=transpose_value_cache,
             )
+            model.save_low_bit = types.MethodType(save_low_bit, model)
         else:
             from ipex_llm.transformers.npu_models.convert import optimize_llm
             optimize_llm(model)
@@ -209,10 +210,16 @@ def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs)
         ignore_argument(kwargs, "lightweight_bmm")
         ignore_argument(kwargs, "cpu_embedding")
         ignore_argument(kwargs, "embedding_qtype")
-        ignore_argument(kwargs, "optimize_model")
-        ignore_argument(kwargs, "modules_to_not_convert")
         ignore_argument(kwargs, "speculative")
         ignore_argument(kwargs, "pipeline_parallel_stages")
+        optimize_model = kwargs.pop("optimize_model", False)
+        max_output_len = kwargs.pop("max_output_len", 1024)
+        max_prompt_len = kwargs.pop("max_prompt_len", 512)
+        inter_pp = kwargs.pop("inter_pp", None)
+        intra_pp = kwargs.pop("intra_pp", None)
+        transpose_value_cache = kwargs.pop("transpose_value_cache", True)
+        modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])

         from transformers.models.auto.configuration_auto import AutoConfig
         from transformers.modeling_utils import no_init_weights, get_state_dict_dtype
@@ -351,12 +358,34 @@ def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs)
         logger.info(f"Converting model, it may take up to several minutes ...")
         from intel_npu_acceleration_library.compiler import create_npu_kernels

-        with torch.no_grad():
-            optimize_llm(model)
-            cls.load_convert(qtype, model, quant_device, *model_args, **kwargs)
-            create_npu_kernels(model)
+        if optimize_model:
+            invalidInputError(
+                max_prompt_len < max_output_len,
+                (
+                    f"max_prompt_len ({max_prompt_len}) should be less"
+                    f" than max_output_len ({max_output_len})"
+                ),
+            )
+            from ipex_llm.transformers.npu_models.convert_mp import optimize_llm_pre
+
+            if hasattr(model, "llm"):
+                llm = model.llm
+            else:
+                llm = model
+
+            with torch.no_grad():
+                optimize_llm_pre(model, qtype)
+                cls.load_convert(qtype, model, quant_device, modules_to_not_convert,
+                                 *model_args, **kwargs)
+                create_npu_kernels(llm)
+
+            model = model.eval()
+        else:
+            from ipex_llm.transformers.npu_models.convert import optimize_llm
+            optimize_llm(model)
+            with torch.no_grad():
+                cls.load_convert(qtype, model, quant_device, modules_to_not_convert,
+                                 *model_args, **kwargs)
+                create_npu_kernels(model)

         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
@@ -415,6 +444,17 @@ def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs)
         for param in model.parameters():
             param.requires_grad_(False)

+        if optimize_model:
+            from ipex_llm.transformers.npu_models.convert_mp import optimize_llm
+            optimize_llm(
+                llm,
+                max_output_len=max_output_len,
+                max_prompt_len=max_prompt_len,
+                inter_pp=inter_pp,
+                intra_pp=intra_pp,
+                transpose_value_cache=transpose_value_cache,
+            )
+
         return model


