[NPU] Modify IPEX_LLM_NPU_DISABLE_COMPILE_OPT setting for long input (#…)
plusbang authored Dec 13, 2024
1 parent 7cc01fd commit 6596c18
Showing 2 changed files with 13 additions and 4 deletions.
8 changes: 5 additions & 3 deletions python/llm/src/ipex_llm/transformers/npu_model.py
@@ -290,7 +290,8 @@ def optimize_npu_model(cls, *args, **kwargs):
             model.config.update({"group_size": quantization_group_size})
             model.config.update({"asym": qtype == "asym_int4_rtn"})
             optimize_llm_pre(model, qtype, mixed_precision,
-                             quantization_group_size=quantization_group_size)
+                             quantization_group_size=quantization_group_size,
+                             max_prompt_len=max_prompt_len)
             cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
                              quantization_group_size, imatrix_data,
                              *args, **kwargs)
@@ -580,7 +581,7 @@ def load_low_bit(cls, pretrained_model_name_or_path: str, *model_args, **kwargs)
         with torch.no_grad():
             optimize_llm_pre(model, qtype, mixed_precision,
                              quantization_group_size=quantization_group_size,
-                             load=bigdl_lcmu_enabled)
+                             load=bigdl_lcmu_enabled, max_prompt_len=max_prompt_len)
             cls.load_convert(qtype, model, quant_device, modules_to_not_convert,
                              quantization_group_size, *model_args, **kwargs)
             create_npu_kernels(llm)
@@ -804,7 +805,8 @@ def optimize_npu_model(cls, *args, **kwargs):
 
         with torch.no_grad():
             optimize_llm_pre(model, qtype, mixed_precision,
-                             quantization_group_size=quantization_group_size)
+                             quantization_group_size=quantization_group_size,
+                             max_prompt_len=max_prompt_len)
             cls.load_convert_fp16(qtype, model.encoder, "cpu", modules_to_not_convert,
                                   quantization_group_size)
             create_npu_kernels(model.encoder)
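The hunks above only thread the existing user-facing max_prompt_len option through to optimize_llm_pre. As a hedged sketch of how that value is typically supplied, assuming the usual ipex-llm NPU loading path where max_prompt_len is a from_pretrained keyword argument; the model id and the other keyword values here are illustrative assumptions, not taken from this commit:

    import torch
    from ipex_llm.transformers.npu_model import AutoModelForCausalLM

    # Illustrative call site (values are assumptions, not from this commit).
    # max_prompt_len is the kwarg that eventually reaches optimize_llm_pre above,
    # so a prompt budget of >= 1920 tokens on a Llama-3.2-3B-style config with
    # channel-wise quantization (group size 0) would hit the new default logic.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-3B-Instruct",   # hypothetical model id
        optimize_model=True,
        max_context_len=2048,
        max_prompt_len=1920,                  # long-input case targeted by this commit
        quantization_group_size=0,            # channel-wise quantization
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )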
9 changes: 8 additions & 1 deletion (second changed file)
@@ -31,7 +31,7 @@ def convert_forward(m, target_m, new_forward):
 
 
 def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
-                     quantization_group_size=0, load=False):
+                     quantization_group_size=0, load=False, max_prompt_len=512):
     if model.config.model_type == "baichuan":
         # process NormHead module in Baichuan2 7B
         if hasattr(model, 'lm_head') and model.lm_head is not None:
@@ -48,6 +48,13 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
 
     cpu_lm_head = os.environ.get("IPEX_LLM_CPU_LM_HEAD", "0") != "0"
 
+    # workaround for long input performance of llama3.2-3b and glm-edge-4b CW
+    if os.environ.get("IPEX_LLM_NPU_DISABLE_COMPILE_OPT") is None:
+        disable_compile_opt = model.config.model_type == "llama" and \
+            model.config.hidden_size == 3072 and max_prompt_len >= 1920 and \
+            quantization_group_size == 0
+        os.environ["IPEX_LLM_NPU_DISABLE_COMPILE_OPT"] = "1" if disable_compile_opt else "0"
+
     # workaround for MiniCPM-2B
     if model.config.model_type == "minicpm" and model.config.num_hidden_layers == 40:
         # 73440 is vocab_size of MiniCPM-1B
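For clarity, a minimal standalone sketch of the auto-detection this hunk adds; the helper name below is made up for illustration and is not part of ipex-llm:

    import os

    def should_disable_compile_opt(model_type: str, hidden_size: int,
                                   max_prompt_len: int, quantization_group_size: int) -> bool:
        # Mirrors the condition above: a llama-family config with hidden_size 3072
        # (the Llama-3.2-3B shape), channel-wise quantization (group size 0), and a
        # prompt budget of at least 1920 tokens.
        return (model_type == "llama" and hidden_size == 3072
                and max_prompt_len >= 1920 and quantization_group_size == 0)

    # The library only derives a default when the variable is unset, so exporting
    # IPEX_LLM_NPU_DISABLE_COMPILE_OPT=0 (or =1) before loading always wins.
    if os.environ.get("IPEX_LLM_NPU_DISABLE_COMPILE_OPT") is None:
        flag = should_disable_compile_opt("llama", 3072, 1920, 0)
        os.environ["IPEX_LLM_NPU_DISABLE_COMPILE_OPT"] = "1" if flag else "0"

Because the new code checks os.environ.get("IPEX_LLM_NPU_DISABLE_COMPILE_OPT") is None before writing, an explicitly set value keeps its previous, user-controlled behaviour; only unset environments get the long-input default.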