diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py
new file mode 100644
index 00000000000..0e3da6e62ad
--- /dev/null
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py
@@ -0,0 +1,54 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from openvino.runtime import Core, serialize
+import os
+
+
+def update_names_of_IR_and_export_blob(model, model_name, dir):
+    xml_path = os.path.join(dir, model_name + ".xml")
+    model.save(xml_path)
+    new_ir_path = os.path.join(dir, model_name + "_new.xml")
+    blob_path = os.path.join(dir, model_name + ".blob")
+
+    core = Core()
+    core.set_property("NPU", {"NPU_COMPILATION_MODE_PARAMS":
+                              "compute-layers-with-higher-precision=Sqrt,Power,ReduceMean,Add"})
+    core.set_property("NPU", {"PERFORMANCE_HINT": "LATENCY"})
+    model = core.read_model(xml_path)
+    inputs = model.inputs
+    for idx, input in enumerate(inputs):
+        if len(input.names) == 0:
+            model.inputs[idx].set_names({f"input_{idx}"})
+    outputs = model.outputs
+    for idx, input in enumerate(outputs):
+        if len(input.names) == 0:
+            model.outputs[idx].set_names({f"output_{idx}"})
+    # rewrite this model to a new IR path
+    if new_ir_path is not None:
+        serialize(model, new_ir_path)
+
+    if blob_path is not None:
+        compiledModel = core.compile_model(model, device_name="NPU")
+        model_stream = compiledModel.export_model()
+        with open(blob_path, 'wb') as f:
+            f.write(model_stream)
+
+    os.remove(xml_path)
+    os.remove(new_ir_path)
+
+    return blob_path
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
index 31ff054c434..e5db1bb2ee5 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
@@ -15,7 +15,6 @@
 #
 
 
-from openvino.runtime import Core, serialize
 import os
 import torch
 from ipex_llm.utils.common import invalidInputError
@@ -31,6 +30,7 @@
 import tempfile
 import numpy as np
 from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead
+from multiprocessing import Pool
 
 
 def generate(
@@ -188,41 +188,6 @@ def generate(
     return output
 
 
-def update_names_of_IR_and_export_blob(model, model_name, dir):
-    xml_path = os.path.join(dir, model_name + ".xml")
-    model.save(xml_path)
-    new_ir_path = os.path.join(dir, model_name + "_new.xml")
-    blob_path = os.path.join(dir, model_name + ".blob")
-
-    core = Core()
-    core.set_property("NPU", {"NPU_COMPILATION_MODE_PARAMS":
-                              "compute-layers-with-higher-precision=Sqrt,Power,ReduceMean,Add"})
-    core.set_property("NPU", {"PERFORMANCE_HINT": "LATENCY"})
-    model = core.read_model(xml_path)
-    inputs = model.inputs
-    for idx, input in enumerate(inputs):
-        if len(input.names) == 0:
-            model.inputs[idx].set_names({f"input_{idx}"})
-    outputs = model.outputs
-    for idx, input in enumerate(outputs):
-        if len(input.names) == 0:
-            model.outputs[idx].set_names({f"output_{idx}"})
-    # rewrite this model to a new IR path
-    if new_ir_path is not None:
-        serialize(model, new_ir_path)
-
-    if blob_path is not None:
-        compiledModel = core.compile_model(model, device_name="NPU")
-        model_stream = compiledModel.export_model()
-        with open(blob_path, 'wb') as f:
-            f.write(model_stream)
-
-    os.remove(xml_path)
-    os.remove(new_ir_path)
-
-    return blob_path
-
-
 def convert_llm(model: torch.nn.Module,
                 kv_len: int,
                 max_prompt_len: int,
@@ -235,180 +200,41 @@ def convert_llm(model: torch.nn.Module,
     n_splits_linear = model.config.hidden_size // group_size
     n_splits_down_proj = model.config.intermediate_size // group_size
     if model.config.model_type == "llama":
-        from ipex_llm.transformers.npu_models.convert_mp import convert_llama
-        convert_llama(model,
-                      max_output_len=kv_len,
-                      max_prompt_len=max_prompt_len,
-                      decoder=False,
-                      transpose_value_cache=transpose_value_cache)
-        from .llama import LowBitLlamaLMHead, LlamaEmbedding
         with tempfile.TemporaryDirectory() as temp_dir:
-            # generate lm_head blob
             weight_dir = os.path.join(temp_dir, "model_weights")
             os.mkdir(weight_dir)
-            num_heads = model.model.layers[0].self_attn.num_heads
-            num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
-            head_dim = model.model.layers[0].self_attn.head_dim
-            intermediate_size = model.config.intermediate_size
             layer_num = len(model.model.layers)
-            rms_norm_eps = model.config.rms_norm_eps
-            vocab_size = model.config.vocab_size
-            model_norm = model.model.norm
-            lm_head = model.lm_head
-            if n_splits_linear == 1:
-                weights = [(lm_head.weight, lm_head.scale)]
-            else:
-                lm_heads = lm_head.lm_heads
-                lm_head_weights = []
-                scales = []
-                for i in range(n_splits_linear):
-                    lm_head_weights.append(lm_heads[i].weight)
-                    scales.append(lm_heads[i].scale)
-                weights = [(torch.stack(lm_head_weights, axis=0),
-                            torch.stack(scales, axis=0))]
-            if isinstance(weights[0], tuple):
-                np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
-            else:  # FP16 Linear
-                np_dtype = np.float16
-
-            new_lm_head = LowBitLlamaLMHead(
-                [1, 1, num_heads * head_dim],
-                num_heads=num_heads,
-                num_key_value_heads=num_key_value_heads,
-                max_seq_len=kv_len,
-                rms_norm_eps=rms_norm_eps,
-                mode="decode",
-                transpose_value=False,
-                dtype=np_dtype,
-                model_norm_weight=model_norm.weight.to(torch.float16),
-                vocab_size=vocab_size,
-                n_splits=n_splits_linear
-            )
-            last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)
-
-            # save weights bins files
-            if n_splits_linear == 1:
-                weight_numpy = [
-                    lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
-                ]
-            else:
-                weight_numpy = [v.numpy() for v in weights[0]]
+            from .llama import convert_llama_layer, convert_lm_head_and_embedding
+            first_blob_path, last_blob_path = convert_lm_head_and_embedding(model, n_splits_linear,
+                                                                            temp_dir, weight_dir)
 
-            for idx, weight in enumerate(weight_numpy):
-                bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
-                weight.tofile(bin_file)
-
-            embedding_layer = model.model.embed_tokens
-            new_embedding = LlamaEmbedding(
-                vocab_size=model.config.vocab_size,
-                embedding_dim=model.config.hidden_size,
-                embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
-                padding_idx=model.config.pad_token_id,
-                dtype=np.float16,
-            )
-            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
-                                                                 temp_dir)
-
-            # generate decoder layer blob
-            from ipex_llm.transformers.npu_models.llama_mp import LowBitLlamaMultiDecoderlayer
+            param_list = []
             for layer_idx in range(0, layer_num):
-                curr_layer = model.model.layers[layer_idx]
-                attn_layer = curr_layer.self_attn
-                mlp_layer = curr_layer.mlp
-
-                weights = []
-                if n_splits_linear == 1:
-                    for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
-                                                attn_layer.k_proj_dq_list,
-                                                attn_layer.v_proj_dq_list,
-                                                attn_layer.o_proj_dq_list,
-                                                mlp_layer.gate_proj_dq_list,
-                                                mlp_layer.up_proj_dq_list):
-                        weights.append((q.weight, q.scale))
-                        weights.append((k.weight, k.scale))
-                        weights.append((v.weight, v.scale))
-                        weights.append((o.weight, o.scale))
-                        weights.append((g.weight, g.scale))
-                        weights.append((u.weight, u.scale))
-                else:
-                    for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                                       attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                                       mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
-                        l_weights = []
-                        scales = []
-                        for l in layer_list:
-                            l_weights.append(l.weight)
-                            scales.append(l.scale)
-                        weights.append((torch.stack(l_weights, axis=0),
-                                        torch.stack(scales, axis=0)))
-
-                if n_splits_down_proj == 1:
-                    for l in mlp_layer.down_proj_dq_list:
-                        weights.append((l.weight, l.scale))
-                else:
-                    l_weights = []
-                    scales = []
-                    for l in mlp_layer.down_proj_dq_list:
-                        l_weights.append(l.weight)
-                        scales.append(l.scale)
-                    weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
-                cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
-                cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
-                layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
-                layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
-
-                if isinstance(weights[0], tuple):
-                    np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
-                else:  # FP16 Linear
-                    np_dtype = np.float16
-
-                if layer_idx == 0:
-                    single_decoder = LowBitLlamaMultiDecoderlayer(
-                        [1, 1, num_heads * head_dim],
-                        input_layernorm_weights=None,
-                        post_attn_layernorm_weights=None,
-                        cached_cos=cached_cos,
-                        cached_sin=cached_sin,
-                        num_heads=num_heads,
-                        num_key_value_heads=num_key_value_heads,
-                        num_layers=1,
-                        max_seq_len=kv_len,
-                        rms_norm_eps=rms_norm_eps,
-                        intermediate_size=intermediate_size,
-                        mode="decode",
-                        transpose_value=transpose_value_cache,
-                        dtype=np_dtype,
-                        n_splits_linear=n_splits_linear,
-                        n_splits_down_proj=n_splits_down_proj,
-                        group_size=group_size
-                    )
-                    rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
-                                                                        "decoder_layer",
-                                                                        temp_dir)
-
-                input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")
-                post_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_4.bin")
-                layer_norm_0.data.numpy().tofile(input_lm_bin_file)
-                layer_norm_1.data.numpy().tofile(post_lm_bin_file)
-
-                for idx, (weight, scale) in enumerate(weights):
-                    bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{7+idx*2}.bin")
-                    weight.numpy().tofile(bin_file)
-                    bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{7+idx*2+1}.bin")
-                    scale.numpy().tofile(bin_file)
+                param_list.append((model, layer_idx, n_splits_linear, n_splits_down_proj,
+                                  temp_dir, weight_dir, transpose_value_cache, kv_len, group_size))
+            with Pool() as pool:
+                result = pool.starmap(convert_llama_layer, param_list)
+
+            # Prefill Runner
+            from ipex_llm.transformers.npu_models.convert_mp import convert_llama
+            convert_llama(model,
+                          max_output_len=kv_len,
+                          max_prompt_len=max_prompt_len,
+                          decoder=False,
+                          transpose_value_cache=transpose_value_cache)
 
             # patch attrs for generate
             model.kv_len = kv_len
-            model.num_head = num_heads
-            model.head_dim = head_dim
+            model.num_head = model.model.layers[0].self_attn.num_heads
+            model.head_dim = model.model.layers[0].self_attn.head_dim
             model.num_layers = layer_num
             model.transpose_value_cache = transpose_value_cache
 
             try:
-                res = InitLLMPipeline(kv_len, num_heads, head_dim, layer_num,
+                res = InitLLMPipeline(kv_len, model.num_head, model.head_dim, layer_num,
                                       model.vocab_size, weight_dir, "model",
-                                      first_blob_path, last_blob_path, rest_blob_path)
+                                      first_blob_path, last_blob_path,
+                                      os.path.join(temp_dir, "decoder_layer"))
             except:
                 invalidInputError(False,
                                   "False to InitLLMPipeline.")
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
index ba88ffef12a..9392c8470fd 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
@@ -15,10 +15,13 @@
 #
 
 
+import torch
 import numpy as np
 from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
 from typing import Sequence
 from intel_npu_acceleration_library.backend.factory import NNFactory
+import os
+from .common import update_names_of_IR_and_export_blob
 
 
 class LowBitLlamaLMHead(LLMBaseNNFactory):
@@ -120,3 +123,158 @@ def __init__(
 
         print("start compiling")
         self.compile()
+
+
+def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
+    num_heads = model.model.layers[0].self_attn.num_heads
+    num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
+    head_dim = model.model.layers[0].self_attn.head_dim
+    rms_norm_eps = model.config.rms_norm_eps
+    vocab_size = model.config.vocab_size
+    model_norm = model.model.norm
+    lm_head = model.lm_head
+    if n_splits_linear == 1:
+        weights = [(lm_head.weight, lm_head.scale)]
+    else:
+        lm_heads = lm_head.lm_heads
+        lm_head_weights = []
+        scales = []
+        for i in range(n_splits_linear):
+            lm_head_weights.append(lm_heads[i].weight)
+            scales.append(lm_heads[i].scale)
+        weights = [(torch.stack(lm_head_weights, axis=0),
+                    torch.stack(scales, axis=0))]
+    if isinstance(weights[0], tuple):
+        np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
+    else:  # FP16 Linear
+        np_dtype = np.float16
+
+    new_lm_head = LowBitLlamaLMHead(
+        [1, 1, num_heads * head_dim],
+        num_heads=num_heads,
+        num_key_value_heads=num_key_value_heads,
+        max_seq_len=1,
+        rms_norm_eps=rms_norm_eps,
+        mode="decode",
+        transpose_value=False,
+        dtype=np_dtype,
+        model_norm_weight=model_norm.weight.to(torch.float16),
+        vocab_size=vocab_size,
+        n_splits=n_splits_linear
+    )
+    last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)
+
+    # save weights bins files
+    if n_splits_linear == 1:
+        weight_numpy = [
+            lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
+        ]
+    else:
+        weight_numpy = [v.numpy() for v in weights[0]]
+
+    for idx, weight in enumerate(weight_numpy):
+        bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
+        weight.tofile(bin_file)
+
+    embedding_layer = model.model.embed_tokens
+    new_embedding = LlamaEmbedding(
+        vocab_size=model.config.vocab_size,
+        embedding_dim=model.config.hidden_size,
+        embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
+        padding_idx=model.config.pad_token_id,
+        dtype=np.float16,
+    )
+    first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
+                                                         temp_dir)
+    return first_blob_path, last_blob_path
+
+
+def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
+                        temp_dir, weight_dir, transpose_value_cache, kv_len, group_size):
+    num_heads = model.model.layers[0].self_attn.num_heads
+    num_key_value_heads = model.model.layers[0].self_attn.num_key_value_heads
+    head_dim = model.model.layers[0].self_attn.head_dim
+    intermediate_size = model.config.intermediate_size
+    rms_norm_eps = model.config.rms_norm_eps
+
+    from ipex_llm.transformers.npu_models.llama_mp import LowBitLlamaMultiDecoderlayer
+    curr_layer = model.model.layers[layer_idx]
+    attn_layer = curr_layer.self_attn
+    mlp_layer = curr_layer.mlp
+
+    weights = []
+    if n_splits_linear == 1:
+        for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                    attn_layer.k_proj_dq_list,
+                                    attn_layer.v_proj_dq_list,
+                                    attn_layer.o_proj_dq_list,
+                                    mlp_layer.gate_proj_dq_list,
+                                    mlp_layer.up_proj_dq_list):
+            weights.append((q.weight, q.scale))
+            weights.append((k.weight, k.scale))
+            weights.append((v.weight, v.scale))
+            weights.append((o.weight, o.scale))
+            weights.append((g.weight, g.scale))
+            weights.append((u.weight, u.scale))
+    else:
+        for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                           attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+            l_weights = []
+            scales = []
+            for l in layer_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0),
+                            torch.stack(scales, axis=0)))
+
+    if n_splits_down_proj == 1:
+        for l in mlp_layer.down_proj_dq_list:
+            weights.append((l.weight, l.scale))
+    else:
+        l_weights = []
+        scales = []
+        for l in mlp_layer.down_proj_dq_list:
+            l_weights.append(l.weight)
+            scales.append(l.scale)
+        weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+
+    cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
+    cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
+    layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
+    layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16)
+
+    if isinstance(weights[0], tuple):
+        np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
+    else:  # FP16 Linear
+        np_dtype = np.float16
+
+    single_decoder = LowBitLlamaMultiDecoderlayer(
+        [1, 1, num_heads * head_dim],
+        input_layernorm_weights=[layer_norm_0],
+        post_attn_layernorm_weights=[layer_norm_1],
+        cached_cos=cached_cos,
+        cached_sin=cached_sin,
+        num_heads=num_heads,
+        num_key_value_heads=num_key_value_heads,
+        num_layers=1,
+        max_seq_len=kv_len,
+        rms_norm_eps=rms_norm_eps,
+        intermediate_size=intermediate_size,
+        mode="decode",
+        transpose_value=transpose_value_cache,
+        dtype=np_dtype,
+        n_splits_linear=n_splits_linear,
+        n_splits_down_proj=n_splits_down_proj,
+        group_size=group_size
+    )
+    rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
+                                                        f"decoder_layer_{layer_idx}",
+                                                        temp_dir)
+
+    for idx, (weight, scale) in enumerate(weights):
+        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{5+idx*2}.bin")
+        weight.numpy().tofile(bin_file)
+        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{5+idx*2+1}.bin")
+        scale.numpy().tofile(bin_file)
+    del single_decoder