Support minicpm for NPU C++ (#12434)
* support minicpm-1b

* update

* tune fused_layers

* update readme.md
rnwang04 authored Nov 25, 2024
1 parent 0819fad commit f414053
Showing 4 changed files with 208 additions and 24 deletions.
@@ -9,6 +9,7 @@ In this directory, you will find a C++ example on how to run LLM models on Intel
| Qwen2.5 | [Qwen/Qwen2.5-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) |
| Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
| Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
| MiniCPM | [openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16), [openbmb/MiniCPM-2B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16) |

## 0. Requirements
To run this C++ example with IPEX-LLM on Intel NPUs, make sure to install the latest Intel NPU driver.
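
For context, the models in the table above are first converted offline into NPU blobs before the C++ runtime loads them. A minimal sketch of that convert step, following the usual ipex-llm NPU convert-script pattern — the model path, context lengths, and output directory here are illustrative assumptions, not taken from this commit:

```python
# Illustrative conversion sketch (hypothetical arguments), not part of this diff.
import torch
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "openbmb/MiniCPM-1B-sft-bf16",
    torch_dtype=torch.float16,
    optimize_model=True,
    pipeline=True,                      # prepare artifacts for the C++ pipeline
    load_in_low_bit="sym_int4",         # quantize linear layers to 4-bit
    max_context_len=1024,               # kv_len budget
    max_prompt_len=512,                 # longest prompt the prefill graph handles
    save_directory="./minicpm-1b-npu",  # blobs, weight bins, and config.json land here
    trust_remote_code=True,
)
```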
@@ -493,3 +493,36 @@ def convert_llm_for_deploy(model: torch.nn.Module,
# save blob of lmhead and bin of embedding
convert_lm_head_and_embedding(model, n_splits_linear,
save_directory, weight_dir, True)
elif model.config.model_type == "minicpm":
layernorm_const = True
fused_layers = 4
update_dict = {"kv_len": kv_len,
"num_head": model.model.layers[0].self_attn.num_heads,
"head_dim": model.model.layers[0].self_attn.head_dim,
"transpose_value_cache": transpose_value_cache,
"max_prompt_len": max_prompt_len,
"layernorm_const": layernorm_const,
"group_size": group_size,
"fused_layers": fused_layers,
"qkv_bias": False,
"use_prefill_sdp": False,
"weight_num": 7,
"weight_idx": 5,
"model_type": "minicpm",
"embedding_post": True}
model.config.update(update_dict)
model.config.save_pretrained(save_directory)

from .minicpm import convert_minicpm_layer, convert_fused_minicpm_layer
from .minicpm import convert_lm_head_and_embedding
# save fused_layers blobs of fused decoder layers
convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_down_proj,
save_directory, weight_dir, transpose_value_cache, kv_len,
group_size, layernorm_const, "decode")
# save blob of single prefill layer
convert_minicpm_layer(model, 0, n_splits_linear, n_splits_down_proj,
save_directory, weight_dir, transpose_value_cache, max_prompt_len,
group_size, layernorm_const, "prefill")
# save blob of lmhead and bin of embedding
convert_lm_head_and_embedding(model, n_splits_linear,
save_directory, weight_dir, True, max_prompt_len)
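
The `update_dict` written above is persisted into the model's `config.json` via `save_pretrained`, so the C++ runtime can recover every deployment parameter without re-running Python model code. A minimal read-back sketch (the `save_directory` value is a hypothetical path from the convert step):

```python
import json
import os

save_directory = "./minicpm-1b-npu"  # hypothetical convert-step output dir
with open(os.path.join(save_directory, "config.json")) as f:
    cfg = json.load(f)

# Keys written by convert_llm_for_deploy for MiniCPM:
assert cfg["model_type"] == "minicpm"
assert cfg["fused_layers"] == 4       # number of fused decode blobs
assert cfg["embedding_post"] is True  # scale_emb applied in a separate blob
print(cfg["kv_len"], cfg["max_prompt_len"], cfg["group_size"])
```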
@@ -147,10 +147,11 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
)
else:
# llama-3.2-3B & llama-3.2-1B
embedding_layer = model.model.embed_tokens
new_embedding = Llama32Embedding(
vocab_size=model.config.vocab_size,
embedding_dim=model.config.hidden_size,
embedding_weight=model.model.embed_tokens.weight.to(torch.float16).detach().numpy(),
embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
padding_idx=model.config.pad_token_id,
inv_freq=model.model.rotary_emb.inv_freq.to(torch.float16),
attention_scaling=model.model.rotary_emb.attention_scaling,
195 changes: 172 additions & 23 deletions python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py
@@ -67,6 +67,29 @@ def __init__(
self.compile()


class MiniCPMPostEmbedding(NNFactory):
def __init__(
self,
input_size,
embedding_dim,
dtype, # fp16
scale_emb,
device: str = "NPU",
):
super().__init__(False, device)
self.embedding_dim = embedding_dim
self.dtype = dtype

input = self.parameter((1, input_size, embedding_dim), dtype=dtype)
res = input * scale_emb

# define outputs
res = self.convert_to_fp16(res)

print("start compiling")
self.compile()
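

# For reference (illustrative, not part of this commit): the graph above only
# rescales the fp16 embedding output by config.scale_emb (e.g. 12 in the
# released MiniCPM configs) before the first decoder layer. An equivalent
# eager-mode computation would be:
def _post_embedding_reference(embed_out: torch.Tensor, scale_emb: float) -> torch.Tensor:
    # embed_out: [1, seq_len, hidden_size] fp16 output of the embedding lookup
    return (embed_out * scale_emb).to(torch.float16)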


class MiniCPMLMHead(LLMBaseNNFactory):
def __init__(
self,
@@ -134,7 +157,8 @@ def __init__(
self.compile()


def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
convert_model=False, max_prompt_len=1):
num_heads = model.model.layers[0].self_attn.num_heads
num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
head_dim = model.model.layers[0].self_attn.head_dim
@@ -180,7 +204,8 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
vocab_size=vocab_size,
n_splits=n_splits_linear
)
last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)
last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir,
True, True)

# save weights bins files
if n_splits_linear == 1:
@@ -209,14 +234,31 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
dtype=np.float16,
scale_emb=model.config.scale_emb,
)
first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
temp_dir)
if convert_model:
bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
first_blob_path = None
# save embedding post module
embedding_post = MiniCPMPostEmbedding(1, model.config.hidden_size,
dtype=np.float16,
scale_emb=model.config.scale_emb)
update_names_of_IR_and_export_blob(embedding_post, "embedding_post",
temp_dir, True, False)
embedding_post_prefill = MiniCPMPostEmbedding(max_prompt_len, model.config.hidden_size,
dtype=np.float16,
scale_emb=model.config.scale_emb)
update_names_of_IR_and_export_blob(embedding_post_prefill,
"embedding_post_prefill",
temp_dir, True, False)
else:
first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
temp_dir, True, False)
return first_blob_path, last_blob_path
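

# Note on the convert_model branch above (illustrative helper, not part of
# this commit): with convert_model=True the embedding table is dumped as a raw
# fp16 .bin and two post-embedding blobs are exported (input length 1 for
# decode, max_prompt_len for prefill); otherwise a single fused embedding blob
# is built. The raw table can be reloaded like so -- the default shapes are
# assumptions (MiniCPM-1B's vocab/hidden sizes), not read from the file:
def _load_embedding_table(weight_dir, vocab_size=73440, hidden_size=1536):
    import os
    import numpy as np
    raw = np.fromfile(os.path.join(weight_dir, "model_embedding_input_0.bin"),
                      dtype=np.float16)
    return raw.reshape(vocab_size, hidden_size)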


def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
layernorm_const):
layernorm_const, mode="decode"):
num_heads = model.model.layers[0].self_attn.num_heads
num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
head_dim = model.model.layers[0].self_attn.head_dim
@@ -252,8 +294,16 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
else: # FP16 Linear
np_dtype = np.float16

if mode == "decode":
input_len = 1
decoder_name = f"decoder_layer_{layer_idx}"
else:
input_len = kv_len
decoder_name = "decoder_layer_prefill"
layernorm_const = False

single_decoder = LowBitMinicpmMultiDecoderlayer(
[1, 1, num_heads * head_dim],
[1, input_len, num_heads * head_dim],
input_layernorm_weights=[layer_norm_0] if layernorm_const else None,
post_attn_layernorm_weights=[layer_norm_1] if layernorm_const else None,
cached_cos=cached_cos,
@@ -266,28 +316,127 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
intermediate_size=intermediate_size,
scale_depth=scale_depth,
num_hidden_layers=num_hidden_layers,
mode="decode",
mode=mode,
transpose_value=transpose_value_cache,
dtype=np_dtype,
n_splits_linear=n_splits_linear,
n_splits_down_proj=n_splits_down_proj,
group_size=group_size
)
rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
f"decoder_layer_{layer_idx}",
temp_dir)
decoder_name,
temp_dir,
True, True)

if layernorm_const:
st_idx = 5
else:
input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
layer_norm_0.data.numpy().tofile(input_lm_bin_file)
layer_norm_1.data.numpy().tofile(post_lm_bin_file)
st_idx = 7
for idx, (weight, scale) in enumerate(weights):
bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
weight.numpy().tofile(bin_file)
bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
scale.numpy().tofile(bin_file)
del single_decoder
if mode == "decode":
if layernorm_const:
st_idx = 5
else:
input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
layer_norm_0.data.numpy().tofile(input_lm_bin_file)
layer_norm_1.data.numpy().tofile(post_lm_bin_file)
st_idx = 7
for idx, (weight, scale) in enumerate(weights):
bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
weight.numpy().tofile(bin_file)
bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
scale.numpy().tofile(bin_file)
del single_decoder
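

# Illustrative summary of the decode-mode bin layout above (not part of this
# commit). In prefill mode nothing is written here: the fused decode pass in
# convert_fused_minicpm_layer has already dumped the shared weights. When
# layernorm_const is False, inputs 3/4 hold the two RMSNorm vectors; the seven
# (weight, scale) pairs -- q/k/v/o plus gate/up/down, matching weight_num=7 in
# the deploy config -- then follow from st_idx onward:
def _decode_weight_bin_names(layer_idx, n_weights=7, layernorm_const=True):
    st_idx = 5 if layernorm_const else 7
    names = [] if layernorm_const else [f"model_{layer_idx}_input_3.bin",
                                        f"model_{layer_idx}_input_4.bin"]
    for idx in range(n_weights):
        names.append(f"model_{layer_idx}_input_{st_idx + idx * 2}.bin")      # quantized weight
        names.append(f"model_{layer_idx}_input_{st_idx + idx * 2 + 1}.bin")  # per-group scales
    return names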


def convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_down_proj,
save_dir, weight_dir, transpose_value_cache, kv_len, group_size,
layernorm_const, mode="decode"):
num_heads = model.model.layers[0].self_attn.num_heads
num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
head_dim = model.model.layers[0].self_attn.head_dim
intermediate_size = model.config.intermediate_size
rms_norm_eps = model.config.rms_norm_eps
num_hidden_layers = model.config.num_hidden_layers
scale_depth = model.model.config.scale_depth
layer_num = len(model.model.layers)
fused_layer_num = layer_num // fused_layers

from ipex_llm.transformers.npu_models.minicpm_mp import LowBitMinicpmMultiDecoderlayer
for i in range(fused_layers):
layer_start = i * fused_layer_num
layer_end = min((i + 1) * fused_layer_num, layer_num)
layer_weights = []
input_layer_norm_weights = []
post_attn_layernorm_weights = []
layer_indexs = range(layer_start, layer_end)
n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
for layer_idx in layer_indexs:
curr_layer = model.model.layers[layer_idx]
attn_layer = curr_layer.self_attn
mlp_layer = curr_layer.mlp

weights = []
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
mlp_layer.down_proj_dq_list]:
l_weights = []
scales = []
for l in layer_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)

layer_weights.extend(weights)
input_layer_norm_weights.append(layer_norm_0)
post_attn_layernorm_weights.append(layer_norm_1)

# save weight
input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
layer_norm_0.data.numpy().tofile(input_lm_bin_file)
layer_norm_1.data.numpy().tofile(post_lm_bin_file)
st_idx = 5
# 6, 7 are past k/v
for idx, (weight, scale) in enumerate(weights):
bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
weight.numpy().tofile(bin_file)
bin_file = os.path.join(weight_dir,
f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
scale.numpy().tofile(bin_file)

if isinstance(weights[0], tuple):
np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
else: # FP16 Linear
np_dtype = np.float16

fused_decoder = LowBitMinicpmMultiDecoderlayer(
[1, 1, num_heads * head_dim],
input_layernorm_weights=input_layer_norm_weights,
post_attn_layernorm_weights=post_attn_layernorm_weights,
cached_cos=cached_cos,
cached_sin=cached_sin,
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
num_layers=fused_layer_num,
max_seq_len=kv_len,
rms_norm_eps=rms_norm_eps,
intermediate_size=intermediate_size,
scale_depth=scale_depth,
num_hidden_layers=num_hidden_layers,
mode=mode,
transpose_value=transpose_value_cache,
dtype=np_dtype,
n_splits_linear=n_splits_linear,
n_splits_down_proj=n_splits_down_proj,
group_size=group_size
)
update_names_of_IR_and_export_blob(fused_decoder,
f"decoder_layer_{i}",
save_dir,
compile_blob=True,
keep_ir=False)
return 0
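

For a sense of how `convert_fused_minicpm_layer` partitions the stack: with `fused_layers = 4`, as set for MiniCPM in `convert_llm_for_deploy`, the decoder layers are split evenly and each group is exported as one `decoder_layer_{i}` blob. A small sketch of the grouping arithmetic (the layer count of 52 is an assumption based on MiniCPM-1B's depth, not read from this diff):

```python
# Grouping arithmetic used by convert_fused_minicpm_layer (illustrative).
layer_num, fused_layers = 52, 4              # 52 = assumed MiniCPM-1B depth
fused_layer_num = layer_num // fused_layers  # 13 decoder layers per fused blob
for i in range(fused_layers):
    layer_start = i * fused_layer_num
    layer_end = min((i + 1) * fused_layer_num, layer_num)
    print(f"decoder_layer_{i}.blob covers layers [{layer_start}, {layer_end})")
```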
