From 0e23bd779f043145710f46b400555a3beff07a04 Mon Sep 17 00:00:00 2001
From: Ruonan Wang
Date: Mon, 25 Nov 2024 17:26:55 -0800
Subject: [PATCH] Add support of llama3.2 for NPU C++ (#12442)

* initial support of llama3.2

* update

* update

* fix style

* fix style

* fix

* small fix
---
 .../LLM/CPP_Examples/README.md                |  7 ++
 .../LLM/CPP_Examples/convert.py               | 13 ++-
 .../LLM/CPP_Examples/llm-npu-cli.cpp          |  2 +-
 .../LLM/Pipeline-Models/qwen.py               |  2 +-
 .../transformers/npu_models/llama_mp.py       |  7 +-
 .../npu_pipeline_model/convert_pipeline.py    | 38 ++++++--
 .../transformers/npu_pipeline_model/llama.py  | 91 ++++++++++++++++---
 7 files changed, 133 insertions(+), 27 deletions(-)

diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/README.md b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/README.md
index 79e8bb94021..c847efebc50 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/README.md
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/README.md
@@ -10,6 +10,7 @@ In this directory, you will find a C++ example on how to run LLM models on Intel
 | Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
 | Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
 | MiniCPM | [openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16), [openbmb/MiniCPM-2B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16) |
+| Llama3.2 | [meta-llama/Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct), [meta-llama/Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) |
 
 ## 0. Requirements
 To run this C++ example with IPEX-LLM on Intel NPUs, make sure to install the newest driver version of Intel NPU.
@@ -55,6 +56,12 @@ python convert.py --repo-id-or-model-path openbmb/MiniCPM-1B-sft-bf16 --save-di
 
 :: to convert MiniCPM-2B-sft-bf16
 python convert.py --repo-id-or-model-path openbmb/MiniCPM-2B-sft-bf16 --save-directory
+
+:: to convert Llama-3.2-1B-Instruct
+python convert.py --repo-id-or-model-path meta-llama/Llama-3.2-1B-Instruct --save-directory
+
+:: to convert Llama-3.2-3B-Instruct
+python convert.py --repo-id-or-model-path meta-llama/Llama-3.2-3B-Instruct --save-directory
 ```
 
 Arguments info:
diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/convert.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/convert.py
index c4781f0a419..867527be888 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/convert.py
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/convert.py
@@ -18,8 +18,12 @@
 import torch
 import argparse
 from ipex_llm.transformers.npu_model import AutoModelForCausalLM
+import transformers
 from transformers import AutoTokenizer
 from transformers.utils import logging
+from packaging import version
+import os
+import shutil
 
 logger = logging.get_logger(__name__)
 
@@ -67,7 +71,14 @@
                                                  save_directory=save_dir)
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-    tokenizer.save_pretrained(save_dir)
+
+    trans_version = transformers.__version__
+    if version.parse(trans_version) >= version.parse("4.45.0"):
+        tokenizer_json = os.path.join(model_path, "tokenizer.json")
+        dst_path = os.path.join(save_dir, "tokenizer.json")
+        shutil.copy(tokenizer_json, dst_path)
+    else:
+        tokenizer.save_pretrained(save_dir)
 
     print("-" * 80)
     print(f"finish save model to {save_dir}")
diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/llm-npu-cli.cpp b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/llm-npu-cli.cpp
index 2c45ba2e55d..06e19af4378 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/llm-npu-cli.cpp
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/llm-npu-cli.cpp
@@ -113,7 +113,7 @@ int main(int argc, char ** argv) {
     for (int i = 1; i < params.n_predict; i++){
         auto logits = run_decode(model, embd[i-1]);
         int32_t token = llm_sample_token(logits, true, model_params);
-        if (token != tok_params.eos_token_id) {
+        if (std::find(tok_params.eos_token_id.begin(), tok_params.eos_token_id.end(), token) == tok_params.eos_token_id.end()){
             embd.push_back(token);
             token_nums ++;
         } else {
diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/qwen.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/qwen.py
index d46ee771c0e..d04961ece87 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/qwen.py
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/qwen.py
@@ -46,7 +46,7 @@
                         help='Prompt to infer')
     parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
     parser.add_argument("--max-context-len", type=int, default=1024)
-    parser.add_argument("--max-prompt-len", type=int, default=960)
+    parser.add_argument("--max-prompt-len", type=int, default=512)
     parser.add_argument("--quantization_group_size", type=int, default=0)
     parser.add_argument('--low_bit', type=str, default="sym_int4",
                         help='Low bit precision to quantize the model')
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
index 1f3ac302e83..69f618c888a 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -71,6 +71,7 @@ def __init__(
         n_splits_down_proj: int = 1,
         group_size: int = 0,
         cos_len: int = 1,
+        keep_position_ids=True,
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
@@ -122,7 +123,7 @@ def __init__(
                                                     self.seq_len), dtype=np.float16)
 
         if self.cached_cos is None:
-            if mode == "prefill":
+            if mode == "prefill" and keep_position_ids:
                 position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
                 cos = self.create_input_op((self.batch_size, self.cos_len, self.head_dim),
                                            dtype=np.float32)
@@ -185,12 +186,12 @@ def __init__(
         hidden_states = input
 
         curr_key_values = []
+        cos_condition = cached_cos is not None or (mode == "prefill" and keep_position_ids)
         for i in range(num_layers):
             hidden_states, new_key_states, new_value_states = self.build_decoder(
                 hidden_states=hidden_states,
                 attention_mask=attention_mask,
-                position_ids=position_ids if (cached_cos is not None
-                                              or mode == "prefill") else None,
+                position_ids=position_ids if cos_condition else None,
                 input_layernorm_weight=input_layernorm_weights[i],
                 post_attention_layernorm_weight=post_attn_layernorm_weights[i],
                 past_key=past_keys[i],
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
index 4537a756acc..84b02363452 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
@@ -456,15 +456,27 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                             group_size, layernorm_const, "prefill")
         # save blob of lmhead and bin of embedding
         convert_lm_head_and_embedding(model, n_splits_linear,
-                                      save_directory, weight_dir, True)
+                                      save_directory, weight_dir,
+                                      convert_model=True)
     elif model.config.model_type == "llama":
         layernorm_const = True
+        embedding_post = False
+        cos_sin_input = False
+        use_prefill_sdp = False
         if model.config.vocab_size == 32000:
             # for Llama2-7B
             fused_layers = 4
+            use_prefill_sdp = True
         else:
-            # for Llama3-8B
-            fused_layers = 2
+            if model.config.intermediate_size == 8192:
+                # llama3.2 1B & llama3.2 3B
+                embedding_post = True
+                cos_sin_input = True
+                fused_layers = 2
+            else:
+                # for Llama3-8B
+                fused_layers = 2
+                use_prefill_sdp = True
         update_dict = {"kv_len": kv_len,
                        "num_head": model.model.layers[0].self_attn.num_heads,
                        "head_dim": model.model.layers[0].self_attn.head_dim,
@@ -474,14 +486,21 @@ def convert_llm_for_deploy(model: torch.nn.Module,
                        "group_size": group_size,
                        "fused_layers": fused_layers,
                        "qkv_bias": False,
-                       "use_prefill_sdp": True,
+                       "use_prefill_sdp": use_prefill_sdp,
                        "weight_num": 7,
-                       "weight_idx": 5}
+                       "weight_idx": 5,
+                       "embedding_post": embedding_post,
+                       "cos_sin_input": cos_sin_input}
         model.config.update(update_dict)
         model.config.save_pretrained(save_directory)
 
         from .llama import convert_llama_layer, convert_fused_llama_layer
         from .llama import convert_lm_head_and_embedding
+        # save blob of lmhead and bin of embedding & (optional) embedding_post
+        convert_lm_head_and_embedding(model, n_splits_linear,
+                                      save_directory, weight_dir,
+                                      convert_model=True,
+                                      max_prompt_len=max_prompt_len)
         # save fused_layers blobs of fused decoder layers
         convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_down_proj,
                                   save_directory, weight_dir, transpose_value_cache, kv_len,
@@ -490,9 +509,6 @@ def convert_llm_for_deploy(model: torch.nn.Module,
         convert_llama_layer(model, 0, n_splits_linear, n_splits_down_proj,
                             save_directory, weight_dir, transpose_value_cache, max_prompt_len,
                             group_size, layernorm_const, "prefill")
-        # save blob of lmhead and bin of embedding
-        convert_lm_head_and_embedding(model, n_splits_linear,
-                                      save_directory, weight_dir, True)
     elif model.config.model_type == "minicpm":
         layernorm_const = True
         fused_layers = 4
@@ -523,6 +539,8 @@ def convert_llm_for_deploy(model: torch.nn.Module,
         convert_minicpm_layer(model, 0, n_splits_linear, n_splits_down_proj,
                               save_directory, weight_dir, transpose_value_cache, max_prompt_len,
                               group_size, layernorm_const, "prefill")
-        # save blob of lmhead and bin of embedding
+        # save blob of lmhead and bin of embedding and embedding_post
         convert_lm_head_and_embedding(model, n_splits_linear,
-                                      save_directory, weight_dir, True, max_prompt_len)
+                                      save_directory, weight_dir,
+                                      convert_model=True,
+                                      max_prompt_len=max_prompt_len)
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
index ecf1083a52b..0899ef4ab17 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
@@ -83,8 +83,46 @@ def __init__(
         self.compile()
 
 
+class Llama32PostEmbedding(NNFactory):
+    def __init__(
+        self,
+        inv_freq,
+        attention_scaling,
+        input_len: int = 1,
+        device: str = "NPU",
+    ):
+        super().__init__(False, device)
+        self.attention_scaling = attention_scaling
+
+        # define input
+        position_ids = self.parameter((1, input_len), dtype=np.int64)
+        inv_freq = self.constant(inv_freq)
+
+        # rotary_emb module
+        inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1))
+        position_ids = self.reshape(position_ids, (1, 1, input_len))
+        freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq),
+                                 self.convert_to_fp32(position_ids))
+        freqs = self.transpose(freqs, [0, 2, 1])
+        emb = self.concat(freqs, freqs, axis=2)
+        cos = self.cos(emb)
+        sin = self.sin(emb)
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
+        if input_len > 1:
+            cos = self.unsqueeze(cos, [1])
+            sin = self.unsqueeze(sin, [1])
+
+        # define outputs
+        cos = self.convert_to_fp32(cos)
+        sin = self.convert_to_fp32(sin)
+
+        print("start compiling")
+        self.compile()
+
+
 def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
-                                  convert_model=False):
+                                  convert_model=False, max_prompt_len=1):
     num_heads = model.model.layers[0].self_attn.num_heads
     num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
     head_dim = model.model.layers[0].self_attn.head_dim
@@ -145,6 +183,13 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
             padding_idx=model.config.pad_token_id,
             dtype=np.float16,
         )
+        if convert_model:
+            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
+            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
+            first_blob_path = None
+        else:
+            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
+                                                                 temp_dir, True, False)
     else:
         # llama-3.2-3B & llama-3.2-1B
         embedding_layer = model.model.embed_tokens
@@ -157,13 +202,27 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
             attention_scaling=model.model.rotary_emb.attention_scaling,
             dtype=np.float16,
         )
-        if convert_model:
-            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
-            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
-            first_blob_path = None
-        else:
-            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
-                                                                 temp_dir)
+        if convert_model:
+            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
+            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
+            first_blob_path = None
+            # save embedding post module
+            inv_freq = model.model.rotary_emb.inv_freq.to(torch.float16)
+            attention_scaling = model.model.rotary_emb.attention_scaling
+            embedding_post = Llama32PostEmbedding(inv_freq=inv_freq,
+                                                  attention_scaling=attention_scaling,
+                                                  input_len=1)
+            update_names_of_IR_and_export_blob(embedding_post, "embedding_post",
+                                               temp_dir, True, False)
+            embedding_post_prefill = Llama32PostEmbedding(inv_freq=inv_freq,
+                                                          attention_scaling=attention_scaling,
+                                                          input_len=max_prompt_len)
+            update_names_of_IR_and_export_blob(embedding_post_prefill,
+                                               "embedding_post_prefill",
+                                               temp_dir, True, False)
+        else:
+            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
+                                                                 temp_dir)
 
     return first_blob_path, last_blob_path
 
@@ -212,10 +271,12 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     if mode == "decode":
         input_len = 1
         decoder_name = f"decoder_layer_{layer_idx}"
+        keep_position_ids = True
     else:
         input_len = kv_len
         decoder_name = "decoder_layer_prefill"
         layernorm_const = False
+        keep_position_ids = False
 
     single_decoder = LowBitLlamaMultiDecoderlayer(
         [1, input_len, num_heads * head_dim],
@@ -234,7 +295,9 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
         dtype=np_dtype,
         n_splits_linear=n_splits_linear,
         n_splits_down_proj=n_splits_down_proj,
-        group_size=group_size
+        group_size=group_size,
+        cos_len=input_len,
+        keep_position_ids=keep_position_ids
     )
 
     rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
@@ -309,8 +372,14 @@ def convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_dow
                 scales.append(l.scale)
             weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
-        cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
-        cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+        if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
+            # llama-2-7B & llama-3-8B
+            cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
+            cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+        else:
+            # llama-3.2-3B & llama-3.2-1B
+            cached_cos = None
+            cached_sin = None
 
         layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
         layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
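
Note on the new `Llama32PostEmbedding` graph: because Llama-3.2 checkpoints (with recent transformers) no longer expose precomputed `cos_cached`/`sin_cached` on the rotary embedding module, this patch builds the rotary cos/sin tables on the NPU from `inv_freq` and the position ids. The snippet below is a minimal PyTorch reference sketch of that computation, included only to clarify what the NNFactory graph produces; the helper `rope_cos_sin`, the sample `inv_freq`, and the 500000.0 rope base are illustrative assumptions, not part of the patch, which instead reads `inv_freq` and `attention_scaling` from `model.model.rotary_emb`.

```python
# Reference-only sketch of the cos/sin tables the Llama32PostEmbedding NPU graph
# is intended to produce. rope_cos_sin and the sample inputs are hypothetical.
import torch


def rope_cos_sin(inv_freq: torch.Tensor, position_ids: torch.Tensor,
                 attention_scaling: float = 1.0):
    # inv_freq: (head_dim // 2,), position_ids: (1, seq_len)
    inv_freq = inv_freq.reshape(1, -1, 1).float()   # (1, head_dim // 2, 1)
    pos = position_ids.reshape(1, 1, -1).float()    # (1, 1, seq_len)
    freqs = (inv_freq * pos).transpose(1, 2)        # (1, seq_len, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)         # (1, seq_len, head_dim)
    # attention_scaling applies the Llama-3.x rope-scaling factor to both tables
    return emb.cos() * attention_scaling, emb.sin() * attention_scaling


# Example with head_dim = 64 and an 8-token prompt (values are illustrative)
inv_freq = 1.0 / (500000.0 ** (torch.arange(0, 64, 2).float() / 64))
cos, sin = rope_cos_sin(inv_freq, torch.arange(8).unsqueeze(0))
print(cos.shape, sin.shape)  # torch.Size([1, 8, 64]) torch.Size([1, 8, 64])
```

The extra `unsqueeze` that the NPU graph applies when `input_len > 1`, which matches the layout expected by the prefill decoder, is omitted from this sketch.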