Support minicpm for NPU C++ #12434

Merged · 4 commits · Nov 25, 2024
@@ -9,6 +9,7 @@ In this directory, you will find a C++ example on how to run LLM models on Intel
| Qwen2.5 | [Qwen/Qwen2.5-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) |
| Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
| Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
| MiniCPM | [openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16), [openbmb/MiniCPM-2B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16) |

## 0. Requirements
To run this C++ example with IPEX-LLM on Intel NPUs, make sure to install the latest Intel NPU driver.
@@ -493,3 +493,36 @@ def convert_llm_for_deploy(model: torch.nn.Module,
# save blob of lmhead and bin of embedding
convert_lm_head_and_embedding(model, n_splits_linear,
save_directory, weight_dir, True)
elif model.config.model_type == "minicpm":
layernorm_const = True
fused_layers = 4
update_dict = {"kv_len": kv_len,
"num_head": model.model.layers[0].self_attn.num_heads,
"head_dim": model.model.layers[0].self_attn.head_dim,
"transpose_value_cache": transpose_value_cache,
"max_prompt_len": max_prompt_len,
"layernorm_const": layernorm_const,
"group_size": group_size,
"fused_layers": fused_layers,
"qkv_bias": False,
"use_prefill_sdp": False,
"weight_num": 7,
"weight_idx": 5,
"model_type": "minicpm",
"embedding_post": True}
model.config.update(update_dict)
model.config.save_pretrained(save_directory)

from .minicpm import convert_minicpm_layer, convert_fused_minicpm_layer
from .minicpm import convert_lm_head_and_embedding
# save fused_layers blobs of fused decoder layers
convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_down_proj,
save_directory, weight_dir, transpose_value_cache, kv_len,
group_size, layernorm_const, "decode")
# save blob of single prefill layer
convert_minicpm_layer(model, 0, n_splits_linear, n_splits_down_proj,
save_directory, weight_dir, transpose_value_cache, max_prompt_len,
group_size, layernorm_const, "prefill")
# save blob of lmhead and bin of embedding
convert_lm_head_and_embedding(model, n_splits_linear,
save_directory, weight_dir, True, max_prompt_len)
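
(Note, not part of the diff: the `fused_layers = 4` written into the config above is what later drives how `convert_fused_minicpm_layer` groups the decoder layers into blobs. A minimal sketch of that grouping, assuming a hypothetical 40-layer model; the real count comes from `model.config.num_hidden_layers`.)

```python
# Sketch only: how fused_layers partitions the decoder layers into fused blobs,
# mirroring the loop bounds used in convert_fused_minicpm_layer below.
# num_layers = 40 is an illustrative assumption, not read from a real checkpoint.
num_layers = 40
fused_layers = 4                                  # value stored in the config above
fused_layer_num = num_layers // fused_layers      # decoder layers per fused blob

for i in range(fused_layers):
    layer_start = i * fused_layer_num
    layer_end = min((i + 1) * fused_layer_num, num_layers)
    print(f"decoder_layer_{i}: layers {layer_start}..{layer_end - 1}")
# decoder_layer_0: layers 0..9, decoder_layer_1: layers 10..19, and so on
```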
@@ -147,10 +147,11 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
)
else:
# llama-3.2-3B & llama-3.2-1B
embedding_layer = model.model.embed_tokens
new_embedding = Llama32Embedding(
vocab_size=model.config.vocab_size,
embedding_dim=model.config.hidden_size,
embedding_weight=model.model.embed_tokens.weight.to(torch.float16).detach().numpy(),
embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
padding_idx=model.config.pad_token_id,
inv_freq=model.model.rotary_emb.inv_freq.to(torch.float16),
attention_scaling=model.model.rotary_emb.attention_scaling,
195 changes: 172 additions & 23 deletions python/llm/src/ipex_llm/transformers/npu_pipeline_model/minicpm.py
@@ -67,6 +67,29 @@ def __init__(
self.compile()


class MiniCPMPostEmbedding(NNFactory):
def __init__(
self,
input_size,
embedding_dim,
dtype, # fp16
scale_emb,
device: str = "NPU",
):
super().__init__(False, device)
self.embedding_dim = embedding_dim
self.dtype = dtype

input = self.parameter((1, input_size, embedding_dim), dtype=dtype)
res = input * scale_emb

# define outputs
res = self.convert_to_fp16(res)

print("start compiling")
self.compile()


class MiniCPMLMHead(LLMBaseNNFactory):
def __init__(
self,
@@ -134,7 +157,8 @@ def __init__(
self.compile()


def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
convert_model=False, max_prompt_len=1):
num_heads = model.model.layers[0].self_attn.num_heads
num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
head_dim = model.model.layers[0].self_attn.head_dim
@@ -180,7 +204,8 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
vocab_size=vocab_size,
n_splits=n_splits_linear
)
last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)
last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir,
True, True)

# save weights bins files
if n_splits_linear == 1:
@@ -209,14 +234,31 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
dtype=np.float16,
scale_emb=model.config.scale_emb,
)
first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
temp_dir)
if convert_model:
bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
first_blob_path = None
# save embedding post module
embedding_post = MiniCPMPostEmbedding(1, model.config.hidden_size,
dtype=np.float16,
scale_emb=model.config.scale_emb)
update_names_of_IR_and_export_blob(embedding_post, "embedding_post",
temp_dir, True, False)
embedding_post_prefill = MiniCPMPostEmbedding(max_prompt_len, model.config.hidden_size,
dtype=np.float16,
scale_emb=model.config.scale_emb)
update_names_of_IR_and_export_blob(embedding_post_prefill,
"embedding_post_prefill",
temp_dir, True, False)
else:
first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
temp_dir, True, False)
return first_blob_path, last_blob_path
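
(Note, not part of the diff: when `convert_model=True`, the embedding weight is only dumped to `model_embedding_input_0.bin`, while the compiled `embedding_post` / `embedding_post_prefill` blobs merely apply the `scale_emb` multiplication; the table lookup itself is presumably done elsewhere in the C++ pipeline. A rough NumPy sketch of the combined math, with illustrative shapes and values.)

```python
import numpy as np

# Illustrative sizes only; real values come from the MiniCPM config.
vocab_size, hidden_size, scale_emb = 32000, 1536, 12.0

weight = np.zeros((vocab_size, hidden_size), dtype=np.float16)  # contents of model_embedding_input_0.bin
input_ids = np.array([[1, 5, 9]])                               # token ids for one prompt

hidden = weight[input_ids]                                      # plain lookup, outside the NPU blob
hidden = (hidden * scale_emb).astype(np.float16)                # what MiniCPMPostEmbedding's graph computes
print(hidden.shape)                                             # (1, 3, 1536)
```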


def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
temp_dir, weight_dir, transpose_value_cache, kv_len, group_size,
layernorm_const):
layernorm_const, mode="decode"):
num_heads = model.model.layers[0].self_attn.num_heads
num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
head_dim = model.model.layers[0].self_attn.head_dim
@@ -252,8 +294,16 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
else: # FP16 Linear
np_dtype = np.float16

if mode == "decode":
input_len = 1
decoder_name = f"decoder_layer_{layer_idx}"
else:
input_len = kv_len
decoder_name = "decoder_layer_prefill"
layernorm_const = False

single_decoder = LowBitMinicpmMultiDecoderlayer(
[1, 1, num_heads * head_dim],
[1, input_len, num_heads * head_dim],
input_layernorm_weights=[layer_norm_0] if layernorm_const else None,
post_attn_layernorm_weights=[layer_norm_1] if layernorm_const else None,
cached_cos=cached_cos,
@@ -266,28 +316,127 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
intermediate_size=intermediate_size,
scale_depth=scale_depth,
num_hidden_layers=num_hidden_layers,
mode="decode",
mode=mode,
transpose_value=transpose_value_cache,
dtype=np_dtype,
n_splits_linear=n_splits_linear,
n_splits_down_proj=n_splits_down_proj,
group_size=group_size
)
rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
f"decoder_layer_{layer_idx}",
temp_dir)
decoder_name,
temp_dir,
True, True)

if layernorm_const:
st_idx = 5
else:
input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
layer_norm_0.data.numpy().tofile(input_lm_bin_file)
layer_norm_1.data.numpy().tofile(post_lm_bin_file)
st_idx = 7
for idx, (weight, scale) in enumerate(weights):
bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
weight.numpy().tofile(bin_file)
bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
scale.numpy().tofile(bin_file)
del single_decoder
if mode == "decode":
if layernorm_const:
st_idx = 5
else:
input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
layer_norm_0.data.numpy().tofile(input_lm_bin_file)
layer_norm_1.data.numpy().tofile(post_lm_bin_file)
st_idx = 7
for idx, (weight, scale) in enumerate(weights):
bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
weight.numpy().tofile(bin_file)
bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
scale.numpy().tofile(bin_file)
del single_decoder
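
(Note, not part of the diff: for decode mode the weight bins written above follow a fixed index layout — inputs 3 and 4 hold the two RMSNorm weights when `layernorm_const` is False, and the quantized (weight, scale) pairs start at `st_idx`, one pair per projection group, which appears to be where the `"weight_num": 7` and `"weight_idx": 5` entries in the saved config come from. A tiny sketch that just prints the expected file names, with illustrative values.)

```python
import os

# Sketch: expected bin-file names for one decode-mode decoder layer.
# layer_idx / weight_dir are illustrative; the 7 (weight, scale) pairs match
# the q/k/v/o/gate/up/down projection groups collected into `weights` above.
layer_idx, layernorm_const, weight_dir = 0, True, "model_weights"
st_idx = 5 if layernorm_const else 7        # 3 and 4 hold the layer norms when not const

names = []
for idx in range(7):
    names.append(f"model_{layer_idx}_input_{st_idx + idx * 2}.bin")      # quantized weight
    names.append(f"model_{layer_idx}_input_{st_idx + idx * 2 + 1}.bin")  # its scale
print([os.path.join(weight_dir, n) for n in names])
```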


def convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_down_proj,
save_dir, weight_dir, transpose_value_cache, kv_len, group_size,
layernorm_const, mode="decode"):
num_heads = model.model.layers[0].self_attn.num_heads
num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
head_dim = model.model.layers[0].self_attn.head_dim
intermediate_size = model.config.intermediate_size
rms_norm_eps = model.config.rms_norm_eps
num_hidden_layers = model.config.num_hidden_layers
scale_depth = model.model.config.scale_depth
layer_num = len(model.model.layers)
fused_layer_num = layer_num // fused_layers

from ipex_llm.transformers.npu_models.minicpm_mp import LowBitMinicpmMultiDecoderlayer
for i in range(fused_layers):
layer_start = i * fused_layer_num
layer_end = min((i + 1) * fused_layer_num, layer_num)
layer_weights = []
input_layer_norm_weights = []
post_attn_layernorm_weights = []
layer_indexs = range(layer_start, layer_end)
n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
for layer_idx in layer_indexs:
curr_layer = model.model.layers[layer_idx]
attn_layer = curr_layer.self_attn
mlp_layer = curr_layer.mlp

weights = []
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
mlp_layer.down_proj_dq_list]:
l_weights = []
scales = []
for l in layer_list:
l_weights.append(l.weight)
scales.append(l.scale)
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)

layer_weights.extend(weights)
input_layer_norm_weights.append(layer_norm_0)
post_attn_layernorm_weights.append(layer_norm_1)

# save weight
input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
layer_norm_0.data.numpy().tofile(input_lm_bin_file)
layer_norm_1.data.numpy().tofile(post_lm_bin_file)
st_idx = 5
# 6, 7 are past k/v
for idx, (weight, scale) in enumerate(weights):
bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
weight.numpy().tofile(bin_file)
bin_file = os.path.join(weight_dir,
f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
scale.numpy().tofile(bin_file)

if isinstance(weights[0], tuple):
np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
else: # FP16 Linear
np_dtype = np.float16

fused_decoder = LowBitMinicpmMultiDecoderlayer(
[1, 1, num_heads * head_dim],
input_layernorm_weights=input_layer_norm_weights,
post_attn_layernorm_weights=post_attn_layernorm_weights,
cached_cos=cached_cos,
cached_sin=cached_sin,
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
num_layers=fused_layer_num,
max_seq_len=kv_len,
rms_norm_eps=rms_norm_eps,
intermediate_size=intermediate_size,
scale_depth=scale_depth,
num_hidden_layers=num_hidden_layers,
mode=mode,
transpose_value=transpose_value_cache,
dtype=np_dtype,
n_splits_linear=n_splits_linear,
n_splits_down_proj=n_splits_down_proj,
group_size=group_size
)
update_names_of_IR_and_export_blob(fused_decoder,
f"decoder_layer_{i}",
save_dir,
compile_blob=True,
keep_ir=False)
return 0