Skip to content

Commit

Permalink
fix npu lm_head cpu condition (#11976)
Browse files Browse the repository at this point in the history
* fix

* fix

* fix

* fix stype

* fix style

* fix style
  • Loading branch information
rnwang04 authored Aug 30, 2024
1 parent 60aa1a2 commit 573c20b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
intra_pp=args.intra_pp,
inter_pp=args.inter_pp,
transpose_value_cache=not args.disable_transpose_value_cache,
modules_to_not_convert=['vpm', 'resampler']
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

Expand Down
15 changes: 14 additions & 1 deletion python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,31 @@ def optimize_llm_pre(model: torch.nn.Module, qtype):
from ipex_llm.transformers.models.baichuan import pre_compute_inv_freq
model.apply(pre_compute_inv_freq)

# MiniCPM-V 2.6 and minicpm-2b must put lm_head on CPU now
cpu_lm_head = (
(model.config.model_type == "minicpmv" and model.config.hidden_size == 3584 and
model.config.vocab_size == 151666)
or (
model.config.model_type == "minicpm" and model.config.num_hidden_layers == 40
)
or os.environ.get("IPEX_LLM_CPU_LM_HEAD", "0") != "0"
)

if model.config.model_type == "minicpmv" and hasattr(model, "llm"):
# MiniCPM-V
if model.config.hidden_size == 2304 and model.config.vocab_size == 122753:
# MiniCPM-V 2
model.llm.config.model_type = "minicpm"
elif model.config.hidden_size == 3584 and model.config.vocab_size == 151666:
# MiniCPM-V 2.6
model.llm.config.model_type = "qwen2"
elif model.config.hidden_size == 4096 and model.config.vocab_size == 128256:
# MiniCPM-V 2.5
model.llm.config.model_type = "llama"
model = model.llm

# lm_head to cpu optimization
if os.environ.get("IPEX_LLM_CPU_LM_HEAD", "0") != "0":
if cpu_lm_head:
# disable the optimization by default
from ipex_llm.transformers.low_bit_linear import SYM_INT4, SYM_INT8
if qtype == "sym_int4_rtn":
Expand Down

0 comments on commit 573c20b

Please sign in to comment.