From d6b4c0b4c24e835c467aaebf38911b15f9cd6098 Mon Sep 17 00:00:00 2001 From: jenniew Date: Thu, 22 Aug 2024 15:54:30 -0700 Subject: [PATCH 1/8] update --- .../transformers/npu_models/convert_mp.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py index 8d4755d5010..ffaf9156541 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py @@ -14,6 +14,7 @@ # limitations under the License. import torch +import importlib def convert_forward(m, target_m, new_forward): @@ -54,3 +55,26 @@ def optimize_llm( prefill_runner=prefill_runner, decode_runner=decode_runner ) convert_forward(model, LlamaModel, llama_model_forward) + if model.config.model_type == "baichuan": + from ipex_llm.transformers.npu_models.baichuan_mp import gen_baichuan_fused_model_forward + from ipex_llm.transformers.npu_models.baichuan_mp import DecodeRunner, PrefillRunner + + decode_runner = DecodeRunner( + model, + max_seq_len=max_output_len, + inter_pp=inter_pp, + intra_pp=intra_pp, + transpose_value_cache=transpose_value_cache, + ) + prefill_runner = PrefillRunner( + model, + max_output_len=max_output_len, + max_prompt_len=max_prompt_len, + transpose_value_cache=transpose_value_cache, + ) + baichuan_model_forward = gen_baichuan_fused_model_forward( + prefill_runner=prefill_runner, decode_runner=decode_runner + ) + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + convert_forward(model, module.BaichuanModel, baichuan_model_forward) From 5c0b9a2a3d39ae0fb05eade31fcbdcf3c290437e Mon Sep 17 00:00:00 2001 From: jenniew Date: Mon, 26 Aug 2024 12:04:54 -0700 Subject: [PATCH 2/8] add baichuan mp --- .../LLM/baichuan2.py | 107 ++ .../transformers/npu_models/baichuan_mp.py | 1305 +++++++++++++++++ 2 files changed, 1412 insertions(+) create mode 100644 python/llm/example/NPU/HF-Transformers-AutoModels/LLM/baichuan2.py create mode 100644 python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/baichuan2.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/baichuan2.py new file mode 100644 index 00000000000..a53de2ab06e --- /dev/null +++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/baichuan2.py @@ -0,0 +1,107 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import torch +import time +import argparse + +from ipex_llm.transformers.npu_model import AutoModelForCausalLM +from transformers import AutoTokenizer + +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +def get_prompt(message: str, chat_history: list[tuple[str, str]], + system_prompt: str) -> str: + texts = [f'[INST] <>\n{system_prompt}\n<>\n\n'] + # The first user input is _not_ stripped + do_strip = False + for user_input, response in chat_history: + user_input = user_input.strip() if do_strip else user_input + do_strip = True + texts.append(f'{user_input} [/INST] {response.strip()} [INST] ') + message = message.strip() if do_strip else message + texts.append(f'{message} [/INST]') + return ''.join(texts) + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Predict Tokens using `generate()` API for npu model" + ) + parser.add_argument( + "--repo-id-or-model-path", + type=str, + default="meta-llama/Llama-2-7b-chat-hf", + help="The huggingface repo id for the Llama2 model to be downloaded" + ", or the path to the huggingface checkpoint folder", + ) + parser.add_argument('--prompt', type=str, default="What is AI?", + help='Prompt to infer') + parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict") + parser.add_argument("--max-output-len", type=int, default=1024) + parser.add_argument("--max-prompt-len", type=int, default=768) + parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False) + parser.add_argument("--intra-pp", type=int, default=2) + parser.add_argument("--inter-pp", type=int, default=2) + + args = parser.parse_args() + model_path = args.repo_id_or_model_path + + model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + trust_remote_code=True, + attn_implementation="eager", + load_in_low_bit="sym_int4", + enable_mp=True, + max_output_len=args.max_output_len, + max_prompt_len=args.max_prompt_len, + intra_pp=args.intra_pp, + inter_pp=args.inter_pp, + transpose_value_cache=not args.disable_transpose_value_cache, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + DEFAULT_SYSTEM_PROMPT = """\ + """ + + print("-" * 80) + print("done") + with torch.inference_mode(): + print("finish to load") + for i in range(5): + prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT) + _input_ids = tokenizer.encode(prompt, return_tensors="pt") + print("input length:", len(_input_ids[0])) + st = time.time() + output = model.generate( + _input_ids, num_beams=1, do_sample=False, max_new_tokens=args.n_predict + ) + end = time.time() + print(f"Inference time: {end-st} s") + input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False) + print("-" * 20, "Input", "-" * 20) + print(input_str) + output_str = tokenizer.decode(output[0], skip_special_tokens=False) + print("-" * 20, "Output", "-" * 20) + print(output_str) + + print("-" * 80) + print("done") + print("success shut down") diff --git a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py new file mode 100644 index 00000000000..f94f8715e7a --- /dev/null +++ b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py @@ -0,0 +1,1305 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import torch +import time +import argparse + +from ipex_llm.transformers.npu_model import AutoModelForCausalLM +from transformers import AutoTokenizer +from intel_npu_acceleration_library.backend.factory import NNFactory +from typing import Optional, Sequence, List, Union, Any, Tuple +import numpy as np +import math +from intel_npu_acceleration_library.backend.runtime import set_contiguous, record_function +from intel_npu_acceleration_library.backend.runtime import adapt_output_tensor, _model_cache +from collections import deque +from transformers.cache_utils import Cache +from intel_npu_acceleration_library.backend.bindings import lib as backend_lib +import ctypes +from ipex_llm.utils.common import invalidInputError +from typing import Optional, List, Generator +import uuid +from functools import partial +import torch.nn.functional as F +import torch.nn.parallel +import torch.distributed as dist +from filelock import FileLock + +from transformers.utils import logging + +logger = logging.get_logger(__name__) +import gc +from colorama import Fore, Back, Style +import torch.multiprocessing as mp +from transformers.cache_utils import Cache +from transformers.modeling_outputs import BaseModelOutputWithPast + + +@torch.no_grad() +def run_model( + x: Union[torch.Tensor, List[torch.Tensor]], + weights: List[torch.Tensor], + backend_cls: Any, + op_id: str, + replica: int = 1, +) -> torch.Tensor: + global _model_cache + import time + + t0 = time.perf_counter() + + # Use or not op_id depending on the class used + op_kwargs = {"op_id": op_id} if op_id else {} + + if not isinstance(x, (list, tuple)): + x = [x] + + # Reshape input + input_dtype = x[0].dtype + x_np = [set_contiguous(elem).to(torch.float16).numpy() for elem in x] + op_args = [] + op_args_flatten = [] + for w in weights: + if isinstance(w, tuple): # from QuantizedLinear + op_args.append((set_contiguous(w[0]).numpy(), set_contiguous(w[1]).numpy())) + op_args_flatten.append(op_args[-1][0]) + op_args_flatten.append(op_args[-1][1]) + else: + op_args.append(set_contiguous(w).to(torch.float16).numpy()) + op_args_flatten.append(op_args[-1]) + + shape_dtype_signature = "_".join( + ["_".join(str(dim) for dim in t.shape) + f"_{t.dtype}" for t in x_np + op_args_flatten] + ) + key = f"{backend_cls.func.__name__}_{shape_dtype_signature}" + models = _model_cache.get(key, None) + + input_shapes = [elem.shape for elem in x_np] + if models is None: + _model_cache[key] = deque([backend_cls(*input_shapes) for i in range(replica)]) + elif len(models) < 1: + _model_cache[key].append(backend_cls(*input_shapes)) + else: + _model_cache[key].rotate(1) + + # Get the model + model = _model_cache[key][0] + + with record_function(f"npu_factory_mul_{key}"): + ret = model.run(x_np, *op_args, **op_kwargs) + + if isinstance(ret, list): + results = [adapt_output_tensor(r, r.shape, input_dtype) for r in ret] + else: + results = adapt_output_tensor(ret, ret.shape, input_dtype) + + return results + + +class LowBitLlamaMultiDecoderlayer(NNFactory): + def __init__( + self, + # batch_size: int, + # seq_len: int, + # hidden_size: int, + hidden_shape: 
Sequence[int], + *shapes, + num_heads: int, + # num_key_value_heads: int, + num_layers: int, + cached_cos, + cached_sin, + input_layernorm_weights=None, + post_attn_layernorm_weights=None, + mode: str = "prefill", + dtype: np.dtype = np.int8, + max_seq_len: int = 1024, + transpose_value: bool = False, + profile: bool = False, + device: str = "NPU", + rms_norm_eps, + intermediate_size, + ): + super().__init__(profile, device) + self.max_seq_len = max_seq_len + self.intermediate_size = intermediate_size + self.dtype = dtype + self.cached_cos = cached_cos + self.cached_sin = cached_sin + self.batch_size, self.seq_len, self.hidden_size = hidden_shape + self.mode = mode + self.rms_norm_eps = rms_norm_eps + self.transpose_value = transpose_value + + cos = self.constant(self.cached_cos) + self.cos = self.unsqueeze(cos, axis=0) + + sin = self.constant(self.cached_sin) + self.sin = self.unsqueeze(sin, axis=0) + + if mode == "decode": + self.kv_seq_len = self.max_seq_len + 1 + else: + self.kv_seq_len = self.seq_len + + self.num_heads = num_heads + # self.num_key_value_heads = num_key_value_heads + + self.head_dim = self.hidden_size // self.num_heads + # self.num_key_value_groups = self.num_heads // self.num_key_value_heads + + # define input, the order self.parameter matters + input = self.parameter((self.batch_size, self.seq_len, self.hidden_size)) + + # Self Attention + if mode == "decode": + attention_mask = self.parameter((self.batch_size, 1, 1, self.max_seq_len + 1)) + else: + attention_mask = self.parameter((self.batch_size, 1, self.seq_len, self.seq_len)) + + position_ids = self.parameter((self.batch_size, self.seq_len)) + past_keys = [] + past_values = [] + if mode == "decode": + for i in range(num_layers): + past_key = self.parameter( + (self.batch_size, self.num_heads, self.max_seq_len, self.head_dim) + ) + if transpose_value: + past_value = self.parameter( + (self.batch_size, self.num_heads, self.head_dim, self.max_seq_len) + ) + else: + past_value = self.parameter( + (self.batch_size, self.num_heads, self.max_seq_len, self.head_dim) + ) + past_keys.append(past_key) + past_values.append(past_value) + else: + past_keys = [None] * num_layers + past_values = [None] * num_layers + + if input_layernorm_weights is None: + input_layernorm_weights = [] + post_attn_layernorm_weights = [] + for i in range(num_layers): + input_layernorm_weights.append( + self.parameter( + ( + 1, + self.hidden_size, + ) + ) + ) + post_attn_layernorm_weights.append( + self.parameter( + ( + 1, + self.hidden_size, + ) + ) + ) + else: + input_layernorm_weights = [self.constant(w) for w in input_layernorm_weights] + post_attn_layernorm_weights = [self.constant(w) for w in post_attn_layernorm_weights] + + hidden_states = input + + curr_key_values = [] + for i in range(num_layers): + hidden_states, new_key_states, new_value_states = self.build_decoder( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + input_layernorm_weight=input_layernorm_weights[i], + post_attention_layernorm_weight=post_attn_layernorm_weights[i], + past_key=past_keys[i], + past_value=past_values[i], + ) + curr_key_values.append((new_key_states, new_value_states)) + + # define outputs + hidden_states = self.convert_to_fp16(hidden_states) + + for i in range(num_layers): + new_key_states = self.convert_to_fp16(curr_key_values[i][0]) + new_value_states = self.convert_to_fp16(curr_key_values[i][1]) + + print("start compiling") + self.compile() + + # def repeat_kv(self, hidden_states, n_rep, transpose=False): + # if 
n_rep == 1: + # return hidden_states + # if not transpose: + # hidden_states = self.reshape( + # hidden_states, + # [self.batch_size, self.num_key_value_heads, 1, self.kv_seq_len, self.head_dim], + # ) + # hidden_states = self.broadcast( + # hidden_states, + # [self.batch_size, self.num_key_value_heads, n_rep, self.kv_seq_len, self.head_dim], + # ) + # hidden_states = self.reshape( + # hidden_states, + # [self.batch_size, n_rep * self.num_key_value_heads, self.kv_seq_len, self.head_dim], + # ) + # else: + # hidden_states = self.reshape( + # hidden_states, + # [self.batch_size, self.num_key_value_heads, 1, self.head_dim, self.kv_seq_len], + # ) + # hidden_states = self.broadcast( + # hidden_states, + # [self.batch_size, self.num_key_value_heads, n_rep, self.head_dim, self.kv_seq_len], + # ) + # hidden_states = self.reshape( + # hidden_states, + # [self.batch_size, n_rep * self.num_key_value_heads, self.head_dim, self.kv_seq_len], + # ) + # return hidden_states + + def build_decoder( + self, + hidden_states, + attention_mask, + position_ids, + input_layernorm_weight, + post_attention_layernorm_weight, + past_key=None, + past_value=None, + ): + + residual = hidden_states + + input_2d = self.reshape(hidden_states, (self.batch_size * self.seq_len, self.hidden_size)) + + # input layernorm + input_2d = self.convert_to_fp32(input_2d) + variance = self.reduce_mean( + self.power(input_2d, self.constant(np.array([[2]], dtype=np.float32))), + -1, + keep_dims=True, + ) + eps = self.constant(self.rms_norm_eps) + input_2d = self.eltwise_div(input_2d, self.sqrt(self.eltwise_add(variance, eps))) + input_layernorm_weight = self.convert_to_fp32(input_layernorm_weight) + input_2d = self.eltwise_mul(input_layernorm_weight, input_2d) + input_2d = self.convert_to_fp16(input_2d) + + # attention + proj = self.linear(input_2d, 3 * self.hidden_size, self.hidden_size, bias=False, wt_dtype=self.dtype) + # proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2) + proj = self.reshape(proj, [-1, 3, self.hidden_size]) # b*s, 3, h + proj = self.unsqueeze(proj, [0]) # b, s, 3, h + proj = self.transpose(proj, [2, 1, 0, 3]) # 3, s, b, h + proj = self.squeeze(proj) # 3, b*s, h + print("proj shape: ", proj.shape) + proj = self.unsqueeze(proj, [1]) + print("proj shape after unsqueeze", proj.shape) + # query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = self.reshape(proj[0, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim]) + query_states = self.transpose(query_states, [0, 2, 1, 3]) + # key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.reshape(proj[1, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim]) + key_states = self.transpose(key_states, [0, 2, 1, 3]) + # value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.reshape(proj[2, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim]) + if self.transpose_value: + value_states = self.transpose(value_states, [0, 2, 3, 1]) + else: + value_states = self.transpose(value_states, [0, 2, 1, 3]) + + # query_states = self.linear( + # input_2d, + # self.num_heads * self.head_dim, + # self.hidden_size, + # bias=False, + # wt_dtype=self.dtype, + # ) + # key_states = self.linear( + # input_2d, + # self.num_key_value_heads * self.head_dim, + # self.hidden_size, + # bias=False, + # wt_dtype=self.dtype, + # ) + # value_states = self.linear( + # 
input_2d, + # self.num_key_value_heads * self.head_dim, + # self.hidden_size, + # bias=False, + # wt_dtype=self.dtype, + # ) + + # query_states = self.reshape( + # query_states, [self.batch_size, self.seq_len, self.num_heads, self.head_dim] + # ) + # key_states = self.reshape( + # key_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim] + # ) + # value_states = self.reshape( + # value_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim] + # ) + # + # query_states = self.transpose(query_states, [0, 2, 1, 3]) + # key_states = self.transpose(key_states, [0, 2, 1, 3]) + # if self.transpose_value: + # value_states = self.transpose(value_states, [0, 2, 3, 1]) + # else: + # value_states = self.transpose(value_states, [0, 2, 1, 3]) + cos = self.unsqueeze(self.squeeze(self.cos), [0]) + sin = self.unsqueeze(self.squeeze(self.sin), [0]) + query_states, key_states = self.apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + new_key_states = key_states + new_value_states = value_states + + if self.mode == "decode": + key_states = self.concat(past_key, key_states, axis=-2) + if self.transpose_value: + value_states = self.concat(past_value, value_states, axis=-1) + else: + value_states = self.concat(past_value, value_states, axis=-2) + + # key_states = self.repeat_kv(key_states, self.num_key_value_groups) + # value_states = self.repeat_kv(value_states, self.num_key_value_groups, self.transpose_value) + + # if query_states.size(2) == key_states.size(2): + # # first token + # from intel_npu_acceleration_library.functional import scaled_dot_product_attention + # attn_output = scaled_dot_product_attention( + # query_states, + # key_states, + # value_states, + # attn_mask=attention_mask + # ) + # attn_weights = None + # else: + attn_weight = self.matmul(query_states, key_states, False, True) / (math.sqrt(self.head_dim)) + attn_weight = self.eltwise_add(attn_weight, attention_mask) + attn_weight = self.convert_to_fp32(attn_weight) + attn_weight = self.softmax(attn_weight, -1) + attn_weight = self.convert_to_fp16(attn_weight) + attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value) + + # attn_weight = self.matmul(query_states, key_states, False, True) / ( + # math.sqrt(self.head_dim) + # ) + # attn_weight = self.eltwise_add(attn_weight, attention_mask) + # attn_weight = self.convert_to_fp32(attn_weight) + # attn_weight = self.softmax(attn_weight, -1) + # attn_weight = self.convert_to_fp16(attn_weight) + # attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value) + + attn_output = self.transpose(attn_output, [0, 2, 1, 3]) + attn_output = self.reshape(attn_output, [self.batch_size, self.seq_len, self.hidden_size]) + + attn_output = self.linear( + attn_output, self.hidden_size, self.hidden_size, bias=False, wt_dtype=self.dtype + ) + + hidden_states = self.eltwise_add(residual, attn_output) + + # Fully Connected + residual = hidden_states + # post_attention_layernorm forward + + hidden_states = self.convert_to_fp32(hidden_states) + variance = self.reduce_mean( + self.power(hidden_states, self.constant(np.array([[[2]]], dtype=np.float32))), + -1, + keep_dims=True, + ) + hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps))) + post_attention_layernorm_weight = self.convert_to_fp32(post_attention_layernorm_weight) + hidden_states = self.eltwise_mul(post_attention_layernorm_weight, hidden_states) + hidden_states = self.convert_to_fp16(hidden_states) + + 
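+        # The block above is RMSNorm written out with factory ops
+        # (x / sqrt(mean(x^2) + eps) * weight); the MLP below is the SwiGLU
+        # feed-forward: gate and up projections to intermediate_size, a SiLU
+        # (swish) gate, then a down projection back to hidden_size, all run
+        # as low-bit linears on the NPU.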
# mlp + # gate proj + mm1 = self.linear(hidden_states, self.intermediate_size, self.hidden_size, + bias=False, wt_dtype=self.dtype) + # up proj + mm2 = self.linear(hidden_states, self.intermediate_size, self.hidden_size, + bias=False, wt_dtype=self.dtype) # type: ignore[attr-defined] + mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined] + # down proj + hidden_states = self.linear(mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype) + + hidden_states = self.eltwise_add(residual, hidden_states) + hidden_states = self.convert_to_fp16(hidden_states) + + return hidden_states, new_key_states, new_value_states + + def rotate_half(self, x): + x1 = self.slice( + x, + [0, 0, 0, 0], + [self.batch_size, self.num_heads, self.seq_len, self.head_dim // 2], + ) + x2 = self.slice( + x, + [0, 0, 0, self.head_dim // 2], + [self.batch_size, self.num_heads, self.seq_len, self.head_dim], + ) + return self.concat(self.negative(x2), x1, axis=-1) + + def apply_rotary_pos_emb2(self, q, k, cos, sin, position_ids): + + cos = self.squeeze(cos) # [seq_len, dim] + sin = self.squeeze(sin) # [seq_len, dim] + # cos = cos[position_ids] + cos = self.unsqueeze(cos, [0, 1]) # [bs, 1, seq_len, dim] + # sin = sin[position_ids] + sin = self.unsqueeze(sin, [0, 1]) # [bs, 1, seq_len, dim] + + q_embed = self.eltwise_add( + self.eltwise_mul(q, cos), self.eltwise_mul(self.rotate_half(q), sin) + ) + k_embed = self.eltwise_add( + self.eltwise_mul(k, cos), self.eltwise_mul(self.rotate_half(k), sin) + ) + + return q_embed, k_embed + + def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids): + position_ids = self.squeeze(position_ids) + cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0) + sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0) + cos = self.unsqueeze(cos, [1]) + sin = self.unsqueeze(sin, [1]) + + q_embed = self.eltwise_add( + self.eltwise_mul(q, cos), self.eltwise_mul(self.rotate_half(q), sin) + ) + k_embed = self.eltwise_add( + self.eltwise_mul(k, cos), self.eltwise_mul(self.rotate_half(k), sin) + ) + + return q_embed, k_embed + + +class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module): + + def __init__( + self, + parameters: List[Tuple[torch.Tensor]], + input_laynorm_weights: List[torch.Tensor], + post_attn_layernorm_weights: List[torch.Tensor], + layer_indexes: List[int], + intra_stages: int, + cached_cos: torch.Tensor, + cached_sin: torch.Tensor, + num_heads: int, + head_dim: int, + # num_key_value_heads: int, + rms_norm_eps, + intermediate_size, + max_seq_len: int = 1024, + transpose_value: bool = False, + do_print: bool = False, + ): + super().__init__() + + self.do_print = do_print + + op_parameters = [] + for w in parameters: + if isinstance(w, tuple): # from QuantizedLinear + op_parameters.append((w[0].numpy(), w[1].numpy())) + else: + op_parameters.append(w.to(torch.float16).numpy()) + self.op_parameters = op_parameters + self.op_id = str(uuid.uuid4()) + self.max_seq_len = max_seq_len + self.transpose_value = transpose_value + if isinstance(parameters[0], tuple): + np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8 + else: # FP16 Linear + np_dtype = np.float16 + + self.intra_stages = intra_stages + self.layer_indexes = layer_indexes + self.num_layers_1 = len(self.layer_indexes) // 2 + self.num_layers_0 = len(self.layer_indexes) - self.num_layers_1 + num_layers = len(self.layer_indexes) // intra_stages + self.layer_ranges = [] + for i in range(intra_stages): + if i == intra_stages - 
1: + self.layer_ranges.append((i * num_layers, len(self.layer_indexes))) + else: + self.layer_ranges.append((i * num_layers, (i + 1) * num_layers)) + + self.backend_decoders = [] + + for i in range(intra_stages): + start, end = self.layer_ranges[i] + lm_0 = input_laynorm_weights[start:end] + lm_1 = post_attn_layernorm_weights[start:end] + decoder = LowBitLlamaMultiDecoderlayer( + [1, 1, num_heads * head_dim], + input_layernorm_weights=lm_0, + post_attn_layernorm_weights=lm_1, + cached_cos=cached_cos, + cached_sin=cached_sin, + num_heads=num_heads, + # num_key_value_heads=num_key_value_heads, + num_layers=end - start, + max_seq_len=max_seq_len, + rms_norm_eps=rms_norm_eps, + intermediate_size=intermediate_size, + mode="decode", + transpose_value=self.transpose_value, + dtype=np_dtype, + ) + self.backend_decoders.append(decoder) + + for i in range(intra_stages): + start, end = self.layer_ranges[i] + num_intra_layers = end - start + self.backend_decoders[i].setWeights( + 3 + (num_intra_layers) * 2, self.op_id, *op_parameters[start * 5:end * 5] + ) + with FileLock(f"decoder_run.lock"): + backend_lib.run(self.backend_decoders[i]._mm) + + self.kv_cache_c_parameter_handel = [] + self.kv_cache_parameters = [] + self.kv_cache_prefetched = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> torch.Tensor: + + inputs = ( + hidden_states.to(torch.float16), + attention_mask, + position_ids, + ) + + if len(self.kv_cache_parameters) > 0: + # the case kv cache changed + cached_prt = self.kv_cache_parameters[0].storage().data_ptr() + current_ptr = past_key_value.key_cache[self.layer_indexes[0]].storage().data_ptr() + if cached_prt != current_ptr: + self.kv_cache_parameters = [] + self.kv_cache_c_parameter_handel = [] + self.kv_cache_prefetched = False + + if len(self.kv_cache_parameters) == 0: + for idx in self.layer_indexes: + past_key = past_key_value.key_cache[idx] + past_value = past_key_value.value_cache[idx] + + invalidInputError( + past_key.dtype == torch.float16, f"past_key dtype is {past_key.dtype}" + ) + + new_size = (past_key.size(0), past_key.size(1), self.max_seq_len, past_key.size(3)) + past_key = past_key.as_strided(new_size, past_key.stride(), storage_offset=0) + invalidInputError(past_key.is_contiguous(), "past_key is not contiguous") + past_value = past_value.as_strided(new_size, past_value.stride(), storage_offset=0) + if self.transpose_value: + past_value = past_value.transpose(-1, -2) + invalidInputError(past_value.is_contiguous(), "past_value is not contiguous") + + self.kv_cache_parameters.append(past_key) + self.kv_cache_parameters.append(past_value) + + for i in range(self.intra_stages): + start, end = self.layer_ranges[i] + layer_kv_cache = self.kv_cache_parameters[start * 2:end * 2] + layer_kv_cache = [p.numpy() for p in layer_kv_cache] + handle = self.backend_decoders[i].create_parameters(layer_kv_cache) + self.kv_cache_c_parameter_handel.append(handle) + + x_np = [elem.to(torch.float16).numpy() for elem in inputs] + + with record_function(f"npu_factory"): + if not self.kv_cache_prefetched: + for i in range(self.intra_stages): + self.backend_decoders[i].load_wt_fn( + len(inputs), + self.backend_decoders[i]._mm, + self.kv_cache_c_parameter_handel[i], + ) + + array_type = ctypes.POINTER(ctypes.c_char) * self.intra_stages + 
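+            # Pack the compiled decoder graph handles and the three host inputs
+            # (hidden_states, attention_mask, position_ids) into ctypes arrays so
+            # a single backend_lib.run_decoders call can drive all intra-stage
+            # graphs for this token in one native call.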
models_ptr = array_type( + *[self.backend_decoders[i]._mm for i in range(self.intra_stages)] + ) + inputs_ptr = (ctypes.c_void_p * 3)( + x_np[0].ctypes.data_as(ctypes.c_void_p), + x_np[1].ctypes.data_as(ctypes.c_void_p), + x_np[2].ctypes.data_as(ctypes.c_void_p), + ) + t0 = time.perf_counter() + backend_lib.run_decoders(models_ptr, inputs_ptr, self.intra_stages, 3) + t1 = time.perf_counter() + + hidden_states = self.backend_decoders[-1].torch_out[0] + + if self.do_print: + print("outputs:", hidden_states) + + outputs = (hidden_states,) + outputs += (past_key_value,) + return outputs, t1 - t0 + + def post_forward(self, past_key_value): + key_value_states = [] + for i in range(self.intra_stages): + for j in range(1, len(self.backend_decoders[i].torch_out)): + key_value_states.append(self.backend_decoders[i].torch_out[j]) + + cache_kwargs = { + # "cache_position": cache_position, + "max_seq_len": self.max_seq_len, + "transpose": self.transpose_value, + } + for i in range(len(self.layer_indexes)): + key_states, value_states = past_key_value.update( + key_value_states[2 * i], + key_value_states[2 * i + 1], + self.layer_indexes[i], + cache_kwargs, + ) + + for i in range(self.intra_stages): + self.backend_decoders[i].load_wt_fn( + 3, self.backend_decoders[i]._mm, self.kv_cache_c_parameter_handel[i] + ) + self.kv_cache_prefetched = True + + +class FusedLlamaLowBitDecoderlayer(torch.nn.Module): + """LLAMA MLP operation NPU backend.""" + + def __init__( + self, + parameters: List[torch.Tensor], + cached_cos, + cached_sin, + layer_norm_0, + layer_norm_1, + num_heads: int, + # num_key_value_heads: int, + layer_idx: int, + rms_norm_eps, + intermediate_size, + max_seq_len: int = 128, + transpose_value: bool = False, + ): + super().__init__() + self.op_parameters = parameters + self.op_id = str(uuid.uuid4()) + self.layer_idx = layer_idx + self.max_seq_len = max_seq_len + self.transpose_value = transpose_value + # self.rotary_emb = rotary_emb + if isinstance(parameters[0], tuple): # weight, scale from QuantizedLinear + np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8 + else: # FP16 Linear + np_dtype = np.float16 + + self.backend_cls_prefill = partial( + LowBitLlamaMultiDecoderlayer, + num_heads=num_heads, + # num_key_value_heads=num_key_value_heads, + num_layers=1, + cached_cos=cached_cos, + cached_sin=cached_sin, + input_layernorm_weights=None, + post_attn_layernorm_weights=None, + max_seq_len=max_seq_len, + rms_norm_eps=rms_norm_eps, + intermediate_size=intermediate_size, + mode="prefill", + transpose_value=self.transpose_value, + dtype=np_dtype, + ) + self.layer_norm_0 = layer_norm_0 + self.layer_norm_1 = layer_norm_1 + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + # cache_position: Optional[torch.LongTensor] = None, + # **kwargs, + ) -> torch.Tensor: + """Torch module forward method. 
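+
+        Runs the single-layer prefill graph on the NPU via `run_model` and
+        stores the returned key/value states in the fused KV cache.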
+ + Args: + x (torch.Tensor): Input tensor + + Returns: + torch.Tensor: result + """ + + seq_len = hidden_states.shape[1] + + backend_cls = self.backend_cls_prefill + inputs = (hidden_states.to(torch.float16), attention_mask, position_ids) + inputs += (self.layer_norm_0, self.layer_norm_1) + hidden_states, past_key, past_value = run_model( + inputs, self.op_parameters, backend_cls, self.op_id, replica=2 + ) + cache_kwargs = { + # "cache_position": cache_position, + "max_seq_len": self.max_seq_len, + "transpose": self.transpose_value, + } + key_states, value_states = past_key_value.update( + past_key, past_value, self.layer_idx, cache_kwargs + ) + + outputs = (hidden_states,) + outputs += (past_key_value,) + return outputs + + +def run_decode( + model, + rank, + world_size, + port, + layer_start, + layer_end, + intra_stages, + max_seq_len, + transpose_value_cache, + input_queue, + result_queue, +): + + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = port + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + + print("start init process group, rank: ", rank, "world_size: ", world_size) + + dist.init_process_group() + my_rank = dist.get_rank() + my_size = dist.get_world_size() + logger.info(f"rank: {my_rank}, size: {my_size}") + + num_heads = model.model.layers[layer_start].self_attn.num_heads + # num_key_value_heads = model.model.layers[layer_start].self_attn.num_key_value_heads + head_dim = model.model.layers[layer_start].self_attn.head_dim + rms_norm_eps = model.config.rms_norm_eps + intermediate_size = model.config.intermediate_size + deocderlayers = [] + layer_weights = [] + input_layer_norm_weights = [] + post_attn_layernorm_weights = [] + layer_indexs = range(layer_start, layer_end) + for layer_idx in layer_indexs: + curr_layer = model.model.layers[layer_idx] + attn_layer = curr_layer.self_attn + mlp_layer = curr_layer.mlp + + weights = [ + (attn_layer.W_pack.weight, attn_layer.W_pack.scale), + (attn_layer.o_proj.weight, attn_layer.o_proj.scale), + (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale), + (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale), + (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale), + ] + + cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) + cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) + layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16) + layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16) + + layer_weights.extend(weights) + input_layer_norm_weights.append(layer_norm_0) + post_attn_layernorm_weights.append(layer_norm_1) + + multi_decoder = FusedLlamaLowBitMultiDecoderlayer( + parameters=layer_weights, + input_laynorm_weights=input_layer_norm_weights, + post_attn_layernorm_weights=post_attn_layernorm_weights, + layer_indexes=layer_indexs, + intra_stages=intra_stages, + cached_cos=cached_cos, + cached_sin=cached_sin, + num_heads=num_heads, + head_dim=head_dim, + # num_key_value_heads=num_key_value_heads, + rms_norm_eps=rms_norm_eps, + intermediate_size=intermediate_size, + max_seq_len=max_seq_len, + transpose_value=transpose_value_cache, + do_print=False, + ) + + dist.barrier() + + past_key_values = None + + control = torch.empty((), dtype=torch.int) + hidden_states = torch.empty((1, 1, head_dim * num_heads), dtype=torch.float16) + with torch.inference_mode(): + while True: + + dist.broadcast(control, src=0) + if control.item() == -2: + break + elif control.item() == -1: + past_key_values = input_queue.get() + else: 
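+                    # control >= 0: run one decode step. Rebuild the causal mask
+                    # padded to max_seq_len + 1, receive hidden_states from the
+                    # previous pipeline rank, run the fused decoders, then send
+                    # the result on to the next rank.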
+ t0 = time.perf_counter() + # past_seen_tokens = past_key_values.get_seq_length() + # attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64) + # cache_position = torch.arange( + # past_seen_tokens, past_seen_tokens + 1, device=hidden_states.device + # ) + # + # position_ids = position_ids = cache_position.unsqueeze(0) + # causal_mask = model.model._update_causal_mask( + # attention_mask, hidden_states, cache_position, past_seen_tokens + # ) + past_key_values_length = past_key_values.get_seq_length() + seq_length_with_past = 1 + past_key_values_length + position_ids = torch.arange( + past_key_values_length, seq_length_with_past, dtype=torch.long + ) + position_ids = position_ids.unsqueeze(0).view(-1, 1) + attention_mask = torch.ones((1, seq_length_with_past), dtype=torch.bool) + attention_mask = model.model._prepare_decoder_attention_mask( + attention_mask, (1, 1), hidden_states, past_key_values_length + ) + + pad_len = multi_decoder.max_seq_len + 1 - attention_mask.size(-1) + + pad_mask = (0, pad_len) + padded_causal_mask = F.pad( + attention_mask.to(torch.float16), pad_mask, value=torch.finfo(torch.float16).min + ) + padded_causal_mask[:, :, :, -1] = 0.0 + dist.recv(hidden_states, src=rank - 1) + t1 = time.perf_counter() + layer_outputs, elapse = multi_decoder( + hidden_states, + attention_mask=padded_causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=False, + use_cache=True, + ) + t2 = time.perf_counter() + hidden_states = layer_outputs[0] + t3 = time.perf_counter() + dist.send(hidden_states, dst=(rank + 1) % world_size) + t4 = time.perf_counter() + multi_decoder.post_forward(past_key_values) + + +class DecodeRunner: + def __init__(self, model, max_seq_len, intra_pp=2, inter_pp=2, transpose_value_cache=True): + self.model = model + self.max_seq_len = max_seq_len + self.transpose_value_cache = transpose_value_cache + world_size = inter_pp + 1 + intra_stages = intra_pp + num_layers = self.model.model.config.num_hidden_layers + + port = "54791" + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = port + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = str(world_size) + + self.input_queues = [] + self.output_queues = [] + self.decoder_processes = [] + + for rank in range(1, world_size): + input_q = mp.Queue() + output_q = mp.Queue() + start_layer = (rank - 1) * (num_layers // (world_size - 1)) + end_layer = (rank) * (num_layers // (world_size - 1)) + if rank == world_size - 1: + end_layer = num_layers + p = mp.Process( + target=run_decode, + args=( + self.model, + rank, + world_size, + port, + start_layer, + end_layer, + intra_stages, + self.max_seq_len, + self.transpose_value_cache, + input_q, + output_q, + ), + ) + p.daemon = True + p.start() + self.input_queues.append(input_q) + self.output_queues.append(output_q) + self.decoder_processes.append(p) + + dist.init_process_group() + my_rank = dist.get_rank() + self.world_size = dist.get_world_size() + logger.info(f"rank: {my_rank}, size: {self.world_size}") + + dist.barrier() + self.cache_past_key_value = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ): + t0 = time.perf_counter() + + if self.cache_past_key_value != past_key_value: + control = torch.tensor(-1, dtype=torch.int) + dist.broadcast(control, 
src=0) + for i in range(len(self.decoder_processes)): + self.input_queues[i].put(past_key_value) + + control = torch.tensor(0, dtype=torch.int) + dist.broadcast(control, src=0) + hidden_states = hidden_states.to(torch.float16) + dist.send(hidden_states, dst=1) + past_key_value.expand(self.transpose_value_cache) + dist.recv(hidden_states, src=self.world_size - 1) + t1 = time.perf_counter() + return hidden_states, past_key_value + + def shutdown(self): + control = torch.tensor(-2, dtype=torch.int) + dist.broadcast(control, src=0) + for p in self.decoder_processes: + p.join(3) + for p in self.decoder_processes: + if p.exitcode is None: + p.kill() + + def __del__(self): + self.shutdown() + + +def run_prefill( + model, max_output_len, max_prompt_len, transpose_value_cache, input_queue, result_queue +): + + layer_start = 0 + layer_end = len(model.model.layers) + num_heads = model.model.layers[layer_start].self_attn.num_heads + # num_key_value_heads = model.model.layers[layer_start].self_attn.num_key_value_heads + head_dim = model.model.layers[layer_start].self_attn.head_dim + rms_norm_eps = model.config.rms_norm_eps + intermediate_size = model.config.intermediate_size + deocderlayers = [] + layer_weights = [] + input_layer_norm_weights = [] + post_attn_layernorm_weights = [] + layer_indexs = range(layer_start, layer_end) + for layer_idx in layer_indexs: + curr_layer = model.model.layers[layer_idx] + attn_layer = curr_layer.self_attn + mlp_layer = curr_layer.mlp + + weights = [ + (attn_layer.W_pack.weight, attn_layer.W_pack.scale), + (attn_layer.o_proj.weight, attn_layer.o_proj.scale), + (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale), + (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale), + (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale), + ] + + cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16) + cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16) + + layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16) + layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16) + + new_decoderlayer = FusedLlamaLowBitDecoderlayer( + weights, + num_heads=num_heads, + # num_key_value_heads=num_key_value_heads, + cached_cos=cached_cos, + cached_sin=cached_sin, + layer_norm_0=layer_norm_0, + layer_norm_1=layer_norm_1, + layer_idx=layer_idx, + rms_norm_eps=rms_norm_eps, + intermediate_size=intermediate_size, + max_seq_len=max_output_len, + transpose_value=transpose_value_cache, + ) + + layer_weights.extend(weights) + input_layer_norm_weights.append(layer_norm_0) + post_attn_layernorm_weights.append(layer_norm_1) + model.model.layers[layer_idx] = new_decoderlayer + deocderlayers.append(new_decoderlayer) + + print("finish creating all decode layers in prefill") + result_queue.put("loading finish") + + while True: + + result = input_queue.get() + if result == "stop": + break + + hidden_states, position_ids, causal_mask, past_key_values = result + with torch.inference_mode(): + for decoder_layer in deocderlayers: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=False, + use_cache=True, + # cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + next_decoder_cache = layer_outputs[1] + + result_queue.put((hidden_states, next_decoder_cache)) + + +class PrefillRunner: + def __init__(self, model, max_output_len, max_prompt_len, transpose_value_cache): + self.model = model + self.max_output_len = 
max_output_len + self.max_prompt_len = max_prompt_len + self.transpose_value_cache = transpose_value_cache + + self.prefill_result_queue = mp.Queue() + self.prefill_input_queue = mp.Queue() + + self.p = mp.Process( + target=run_prefill, + args=( + model, + max_output_len, + max_prompt_len, + transpose_value_cache, + self.prefill_input_queue, + self.prefill_result_queue, + ), + ) + self.p.daemon = True + self.p.start() + output = self.prefill_result_queue.get() + print(Fore.GREEN + f"prefill process output: {output}") + print(Style.RESET_ALL) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ): + seq_len = hidden_states.size(1) + invalidInputError( + seq_len <= self.max_prompt_len, + ( + f"seq_len: {seq_len} should be less than or equal" + " to max_prompt_len {self.max_prompt_len}" + ), + ) + pad_len = self.max_prompt_len - seq_len + hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0) + position_ids = F.pad(position_ids, (0, pad_len), value=0) + attention_mask = F.pad( + attention_mask.to(torch.float16), + (0, pad_len, 0, pad_len), + value=torch.finfo(torch.float16).min, + ) + + args = (hidden_states, position_ids, attention_mask, past_key_value) + self.prefill_input_queue.put(args) + hidden_states, past_key_value = self.prefill_result_queue.get() + past_key_value.shrink(seq_len, self.transpose_value_cache) + hidden_states = hidden_states[:, :seq_len, :] + return hidden_states, past_key_value + + def shutdown(self): + self.prefill_input_queue.put("stop") + self.p.join(3) + if self.p.exitcode is None: + self.p.kill() + + def __del__(self): + self.shutdown() + + +def gen_baichuan_fused_model_forward(prefill_runner, decode_runner): + def baichuan_fused_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + t0 = time.perf_counter() + output_attentions = ( + output_attentions if output_attentions is not None else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + invalidInputError(False, "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + invalidInputError(False, "You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + # if past_key_values is not None: + # past_key_values_length = past_key_values.get_seq_length() + # seq_length_with_past = 
seq_length_with_past + past_key_values_length + + # ipex-llm changes start + from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache + + if use_cache and not isinstance(past_key_values, DynamicFusedNormalCache): + past_key_values = DynamicFusedNormalCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_seq_length() + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + if seq_length == 1: + layers_runner = decode_runner + else: + layers_runner = prefill_runner + layer_outputs = layers_runner.forward( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[1] + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # ipex-llm changes start + next_cache = next_decoder_cache if use_cache else None + # ipex-llm changes end + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + t1 = time.perf_counter() + # print("fused model forward time: ", t1 - t0) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + return baichuan_fused_model_forward From 22042ea23cd4c7fd73af1fe1e2138fa7187a05a0 Mon Sep 17 00:00:00 2001 From: jenniew Date: Mon, 26 Aug 2024 14:01:35 -0700 Subject: [PATCH 3/8] clean --- .../transformers/npu_models/baichuan_mp.py | 105 ------------------ .../transformers/npu_models/convert_mp.py | 2 - 2 files changed, 107 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py index f94f8715e7a..b436c317e5f 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py @@ -243,37 +243,6 @@ def __init__( print("start compiling") self.compile() - # def repeat_kv(self, 
hidden_states, n_rep, transpose=False): - # if n_rep == 1: - # return hidden_states - # if not transpose: - # hidden_states = self.reshape( - # hidden_states, - # [self.batch_size, self.num_key_value_heads, 1, self.kv_seq_len, self.head_dim], - # ) - # hidden_states = self.broadcast( - # hidden_states, - # [self.batch_size, self.num_key_value_heads, n_rep, self.kv_seq_len, self.head_dim], - # ) - # hidden_states = self.reshape( - # hidden_states, - # [self.batch_size, n_rep * self.num_key_value_heads, self.kv_seq_len, self.head_dim], - # ) - # else: - # hidden_states = self.reshape( - # hidden_states, - # [self.batch_size, self.num_key_value_heads, 1, self.head_dim, self.kv_seq_len], - # ) - # hidden_states = self.broadcast( - # hidden_states, - # [self.batch_size, self.num_key_value_heads, n_rep, self.head_dim, self.kv_seq_len], - # ) - # hidden_states = self.reshape( - # hidden_states, - # [self.batch_size, n_rep * self.num_key_value_heads, self.head_dim, self.kv_seq_len], - # ) - # return hidden_states - def build_decoder( self, hidden_states, @@ -325,44 +294,6 @@ def build_decoder( else: value_states = self.transpose(value_states, [0, 2, 1, 3]) - # query_states = self.linear( - # input_2d, - # self.num_heads * self.head_dim, - # self.hidden_size, - # bias=False, - # wt_dtype=self.dtype, - # ) - # key_states = self.linear( - # input_2d, - # self.num_key_value_heads * self.head_dim, - # self.hidden_size, - # bias=False, - # wt_dtype=self.dtype, - # ) - # value_states = self.linear( - # input_2d, - # self.num_key_value_heads * self.head_dim, - # self.hidden_size, - # bias=False, - # wt_dtype=self.dtype, - # ) - - # query_states = self.reshape( - # query_states, [self.batch_size, self.seq_len, self.num_heads, self.head_dim] - # ) - # key_states = self.reshape( - # key_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim] - # ) - # value_states = self.reshape( - # value_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim] - # ) - # - # query_states = self.transpose(query_states, [0, 2, 1, 3]) - # key_states = self.transpose(key_states, [0, 2, 1, 3]) - # if self.transpose_value: - # value_states = self.transpose(value_states, [0, 2, 3, 1]) - # else: - # value_states = self.transpose(value_states, [0, 2, 1, 3]) cos = self.unsqueeze(self.squeeze(self.cos), [0]) sin = self.unsqueeze(self.squeeze(self.sin), [0]) query_states, key_states = self.apply_rotary_pos_emb( @@ -378,20 +309,6 @@ def build_decoder( else: value_states = self.concat(past_value, value_states, axis=-2) - # key_states = self.repeat_kv(key_states, self.num_key_value_groups) - # value_states = self.repeat_kv(value_states, self.num_key_value_groups, self.transpose_value) - - # if query_states.size(2) == key_states.size(2): - # # first token - # from intel_npu_acceleration_library.functional import scaled_dot_product_attention - # attn_output = scaled_dot_product_attention( - # query_states, - # key_states, - # value_states, - # attn_mask=attention_mask - # ) - # attn_weights = None - # else: attn_weight = self.matmul(query_states, key_states, False, True) / (math.sqrt(self.head_dim)) attn_weight = self.eltwise_add(attn_weight, attention_mask) attn_weight = self.convert_to_fp32(attn_weight) @@ -399,15 +316,6 @@ def build_decoder( attn_weight = self.convert_to_fp16(attn_weight) attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value) - # attn_weight = self.matmul(query_states, key_states, False, True) / ( - # math.sqrt(self.head_dim) - # ) - # 
attn_weight = self.eltwise_add(attn_weight, attention_mask) - # attn_weight = self.convert_to_fp32(attn_weight) - # attn_weight = self.softmax(attn_weight, -1) - # attn_weight = self.convert_to_fp16(attn_weight) - # attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value) - attn_output = self.transpose(attn_output, [0, 2, 1, 3]) attn_output = self.reshape(attn_output, [self.batch_size, self.seq_len, self.hidden_size]) @@ -878,16 +786,6 @@ def run_decode( past_key_values = input_queue.get() else: t0 = time.perf_counter() - # past_seen_tokens = past_key_values.get_seq_length() - # attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64) - # cache_position = torch.arange( - # past_seen_tokens, past_seen_tokens + 1, device=hidden_states.device - # ) - # - # position_ids = position_ids = cache_position.unsqueeze(0) - # causal_mask = model.model._update_causal_mask( - # attention_mask, hidden_states, cache_position, past_seen_tokens - # ) past_key_values_length = past_key_values.get_seq_length() seq_length_with_past = 1 + past_key_values_length position_ids = torch.arange( @@ -1210,9 +1108,6 @@ def baichuan_fused_model_forward( seq_length_with_past = seq_length past_key_values_length = 0 - # if past_key_values is not None: - # past_key_values_length = past_key_values.get_seq_length() - # seq_length_with_past = seq_length_with_past + past_key_values_length # ipex-llm changes start from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py index d04a5c9c0be..1b7a9f4505c 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py @@ -77,12 +77,10 @@ def optimize_llm( max_prompt_len=max_prompt_len, transpose_value_cache=transpose_value_cache, ) - qwen2_model_forward = gen_qwen2_fused_model_forward( prefill_runner=prefill_runner, decode_runner=decode_runner ) convert_forward(model, Qwen2Model, qwen2_model_forward) - from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM from ipex_llm.transformers.npu_models.qwen2_mp import qwen2_casullm_forward convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward) From e20b00c85f6e06292afdcfedbb0dace5764e1729 Mon Sep 17 00:00:00 2001 From: jenniew Date: Tue, 27 Aug 2024 13:56:51 -0700 Subject: [PATCH 4/8] refactor --- .../transformers/npu_models/baichuan_mp.py | 326 ++++++------------ 1 file changed, 106 insertions(+), 220 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py index b436c317e5f..f57293a1e87 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py @@ -48,7 +48,7 @@ import torch.multiprocessing as mp from transformers.cache_utils import Cache from transformers.modeling_outputs import BaseModelOutputWithPast - +from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory @torch.no_grad() def run_model( @@ -111,7 +111,7 @@ def run_model( return results -class LowBitLlamaMultiDecoderlayer(NNFactory): +class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory): def __init__( self, # batch_size: int, @@ -120,7 +120,6 @@ def __init__( hidden_shape: Sequence[int], *shapes, num_heads: int, - # num_key_value_heads: int, num_layers: int, cached_cos, cached_sin, @@ -135,7 
+134,12 @@ def __init__( rms_norm_eps, intermediate_size, ): - super().__init__(profile, device) + super().__init__(max_seq_len=max_seq_len, + transpose_value=transpose_value, + dtype=dtype, + profile=profile, + device=device) + # super().__init__(profile, device) self.max_seq_len = max_seq_len self.intermediate_size = intermediate_size self.dtype = dtype @@ -145,6 +149,7 @@ def __init__( self.mode = mode self.rms_norm_eps = rms_norm_eps self.transpose_value = transpose_value + self.num_layers = num_layers cos = self.constant(self.cached_cos) self.cos = self.unsqueeze(cos, axis=0) @@ -158,34 +163,32 @@ def __init__( self.kv_seq_len = self.seq_len self.num_heads = num_heads - # self.num_key_value_heads = num_key_value_heads self.head_dim = self.hidden_size // self.num_heads - # self.num_key_value_groups = self.num_heads // self.num_key_value_heads # define input, the order self.parameter matters - input = self.parameter((self.batch_size, self.seq_len, self.hidden_size)) + input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size)) # Self Attention if mode == "decode": - attention_mask = self.parameter((self.batch_size, 1, 1, self.max_seq_len + 1)) + attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1)) else: - attention_mask = self.parameter((self.batch_size, 1, self.seq_len, self.seq_len)) + attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len)) - position_ids = self.parameter((self.batch_size, self.seq_len)) + position_ids = self.create_input_op((self.batch_size, self.seq_len)) past_keys = [] past_values = [] if mode == "decode": for i in range(num_layers): - past_key = self.parameter( + past_key = self.create_cache_op( (self.batch_size, self.num_heads, self.max_seq_len, self.head_dim) ) if transpose_value: - past_value = self.parameter( + past_value = self.create_cache_op( (self.batch_size, self.num_heads, self.head_dim, self.max_seq_len) ) else: - past_value = self.parameter( + past_value = self.create_cache_op( (self.batch_size, self.num_heads, self.max_seq_len, self.head_dim) ) past_keys.append(past_key) @@ -199,7 +202,7 @@ def __init__( post_attn_layernorm_weights = [] for i in range(num_layers): input_layernorm_weights.append( - self.parameter( + self.create_input_op( ( 1, self.hidden_size, @@ -207,7 +210,7 @@ def __init__( ) ) post_attn_layernorm_weights.append( - self.parameter( + self.create_input_op( ( 1, self.hidden_size, @@ -243,61 +246,50 @@ def __init__( print("start compiling") self.compile() - def build_decoder( - self, - hidden_states, - attention_mask, - position_ids, - input_layernorm_weight, - post_attention_layernorm_weight, - past_key=None, - past_value=None, - ): - - residual = hidden_states - - input_2d = self.reshape(hidden_states, (self.batch_size * self.seq_len, self.hidden_size)) - - # input layernorm - input_2d = self.convert_to_fp32(input_2d) - variance = self.reduce_mean( - self.power(input_2d, self.constant(np.array([[2]], dtype=np.float32))), - -1, - keep_dims=True, - ) - eps = self.constant(self.rms_norm_eps) - input_2d = self.eltwise_div(input_2d, self.sqrt(self.eltwise_add(variance, eps))) - input_layernorm_weight = self.convert_to_fp32(input_layernorm_weight) - input_2d = self.eltwise_mul(input_layernorm_weight, input_2d) - input_2d = self.convert_to_fp16(input_2d) - - # attention - proj = self.linear(input_2d, 3 * self.hidden_size, self.hidden_size, bias=False, wt_dtype=self.dtype) - # proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, 
-2).squeeze(-2) - proj = self.reshape(proj, [-1, 3, self.hidden_size]) # b*s, 3, h + def attention(self, + *, + hidden_states, + position_ids, + attention_mask, + past_key, + past_value, + cos, + sin, + mode, + num_heads, + head_dim, + seq_len, + q_bias=None, + k_bias=None, + v_bias=None): + hidden_size = num_heads * head_dim + proj = self.linear(hidden_states, 3 * hidden_size, hidden_size, bias=False, wt_dtype=self.dtype) + proj = self.reshape(proj, [-1, 3, hidden_size]) # b*s, 3, h proj = self.unsqueeze(proj, [0]) # b, s, 3, h proj = self.transpose(proj, [2, 1, 0, 3]) # 3, s, b, h proj = self.squeeze(proj) # 3, b*s, h - print("proj shape: ", proj.shape) - proj = self.unsqueeze(proj, [1]) - print("proj shape after unsqueeze", proj.shape) - # query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - query_states = self.reshape(proj[0, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim]) + query_states = self.reshape(proj[0, ...], [1, self.seq_len, num_heads, head_dim]) query_states = self.transpose(query_states, [0, 2, 1, 3]) - # key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.reshape(proj[1, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim]) + key_states = self.reshape(proj[1, ...], [1, self.seq_len, num_heads, head_dim]) key_states = self.transpose(key_states, [0, 2, 1, 3]) - # value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = self.reshape(proj[2, ...], [self.batch_size, self.seq_len, self.num_heads, self.head_dim]) + value_states = self.reshape(proj[2, ...], [1, self.seq_len, num_heads, head_dim]) if self.transpose_value: value_states = self.transpose(value_states, [0, 2, 3, 1]) else: value_states = self.transpose(value_states, [0, 2, 1, 3]) - cos = self.unsqueeze(self.squeeze(self.cos), [0]) - sin = self.unsqueeze(self.squeeze(self.sin), [0]) + cos = self.unsqueeze(self.squeeze(cos), [0]) + sin = self.unsqueeze(self.squeeze(sin), [0]) + query_states, key_states = self.apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids + q=query_states, + k=key_states, + cos=cos, + sin=sin, + position_ids=position_ids, + num_heads=num_heads, + seq_len=seq_len, + head_dim=head_dim, ) new_key_states = key_states new_value_states = value_states @@ -309,7 +301,7 @@ def build_decoder( else: value_states = self.concat(past_value, value_states, axis=-2) - attn_weight = self.matmul(query_states, key_states, False, True) / (math.sqrt(self.head_dim)) + attn_weight = self.matmul(query_states, key_states, False, True) / (math.sqrt(head_dim)) attn_weight = self.eltwise_add(attn_weight, attention_mask) attn_weight = self.convert_to_fp32(attn_weight) attn_weight = self.softmax(attn_weight, -1) @@ -317,92 +309,53 @@ def build_decoder( attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value) attn_output = self.transpose(attn_output, [0, 2, 1, 3]) - attn_output = self.reshape(attn_output, [self.batch_size, self.seq_len, self.hidden_size]) + attn_output = self.reshape(attn_output, [1, seq_len, hidden_size]) attn_output = self.linear( - attn_output, self.hidden_size, self.hidden_size, bias=False, wt_dtype=self.dtype + attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype ) + return attn_output, new_key_states, new_value_states - hidden_states = self.eltwise_add(residual, attn_output) + def build_decoder( + self, + hidden_states, + attention_mask, + position_ids, + 
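The attention helper above consumes Baichuan's packed QKV projection as a single 3 * hidden_size linear and then slices it apart with reshape/transpose ops on the graph. For reference, an eager-mode PyTorch sketch of the same split; the packed weight name follows Baichuan's W_pack convention and is an assumption here:

import torch
import torch.nn.functional as F

def split_packed_qkv(hidden_states, w_pack_weight, num_heads, head_dim):
    # hidden_states: [1, seq_len, hidden_size]; w_pack_weight: [3 * hidden_size, hidden_size]
    bsz, seq_len, hidden_size = hidden_states.shape
    proj = F.linear(hidden_states, w_pack_weight)           # [1, seq_len, 3 * hidden_size]
    proj = proj.view(bsz, seq_len, 3, num_heads, head_dim)
    query, key, value = proj.unbind(dim=2)                  # each [1, seq_len, heads, head_dim]
    # move heads ahead of the sequence axis, matching the transposes in the NPU graph
    return tuple(t.transpose(1, 2) for t in (query, key, value))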
input_layernorm_weight, + post_attention_layernorm_weight, + past_key=None, + past_value=None, + ): - # Fully Connected residual = hidden_states - # post_attention_layernorm forward - hidden_states = self.convert_to_fp32(hidden_states) - variance = self.reduce_mean( - self.power(hidden_states, self.constant(np.array([[[2]]], dtype=np.float32))), - -1, - keep_dims=True, - ) - hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps))) - post_attention_layernorm_weight = self.convert_to_fp32(post_attention_layernorm_weight) - hidden_states = self.eltwise_mul(post_attention_layernorm_weight, hidden_states) - hidden_states = self.convert_to_fp16(hidden_states) - - # mlp - # gate proj - mm1 = self.linear(hidden_states, self.intermediate_size, self.hidden_size, - bias=False, wt_dtype=self.dtype) - # up proj - mm2 = self.linear(hidden_states, self.intermediate_size, self.hidden_size, - bias=False, wt_dtype=self.dtype) # type: ignore[attr-defined] - mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined] - # down proj - hidden_states = self.linear(mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype) + input_2d = self.reshape(hidden_states, (self.batch_size * self.seq_len, self.hidden_size)) + input_2d = self.layer_norm(input_2d, input_layernorm_weight) + # attention + attn_output, new_key_states, new_value_states = self.attention( + hidden_states=input_2d, + position_ids=position_ids, + attention_mask=attention_mask, + past_key=past_key, + past_value=past_value, + cos=self.cos, + sin=self.sin, + mode=self.mode, + num_heads=self.num_heads, + head_dim=self.head_dim, + seq_len=self.seq_len, + ) + + hidden_states = self.eltwise_add(residual, attn_output) + residual = hidden_states + hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight) + hidden_states = self.mlp(hidden_states) hidden_states = self.eltwise_add(residual, hidden_states) hidden_states = self.convert_to_fp16(hidden_states) return hidden_states, new_key_states, new_value_states - def rotate_half(self, x): - x1 = self.slice( - x, - [0, 0, 0, 0], - [self.batch_size, self.num_heads, self.seq_len, self.head_dim // 2], - ) - x2 = self.slice( - x, - [0, 0, 0, self.head_dim // 2], - [self.batch_size, self.num_heads, self.seq_len, self.head_dim], - ) - return self.concat(self.negative(x2), x1, axis=-1) - - def apply_rotary_pos_emb2(self, q, k, cos, sin, position_ids): - - cos = self.squeeze(cos) # [seq_len, dim] - sin = self.squeeze(sin) # [seq_len, dim] - # cos = cos[position_ids] - cos = self.unsqueeze(cos, [0, 1]) # [bs, 1, seq_len, dim] - # sin = sin[position_ids] - sin = self.unsqueeze(sin, [0, 1]) # [bs, 1, seq_len, dim] - - q_embed = self.eltwise_add( - self.eltwise_mul(q, cos), self.eltwise_mul(self.rotate_half(q), sin) - ) - k_embed = self.eltwise_add( - self.eltwise_mul(k, cos), self.eltwise_mul(self.rotate_half(k), sin) - ) - - return q_embed, k_embed - - def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids): - position_ids = self.squeeze(position_ids) - cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0) - sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0) - cos = self.unsqueeze(cos, [1]) - sin = self.unsqueeze(sin, [1]) - - q_embed = self.eltwise_add( - self.eltwise_mul(q, cos), self.eltwise_mul(self.rotate_half(q), sin) - ) - k_embed = self.eltwise_add( - self.eltwise_mul(k, cos), self.eltwise_mul(self.rotate_half(k), sin) - ) - - return q_embed, k_embed - class 
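The rotate_half and apply_rotary_pos_emb helpers removed above are now inherited from LLMBaseNNFactory. The computation they encode is the standard rotary position embedding; in eager PyTorch it is roughly the following (a sketch, assuming cos/sin tables of shape [max_seq_len, head_dim]):

import torch

def rotate_half(x):
    # Swap the two halves of the head dimension and negate the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos_table, sin_table, position_ids):
    # q, k: [batch, num_heads, seq_len, head_dim]; position_ids: [batch, seq_len]
    cos = cos_table[position_ids].unsqueeze(1)   # [batch, 1, seq_len, head_dim]
    sin = sin_table[position_ids].unsqueeze(1)
    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed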
FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module): @@ -445,8 +398,6 @@ def __init__( self.intra_stages = intra_stages self.layer_indexes = layer_indexes - self.num_layers_1 = len(self.layer_indexes) // 2 - self.num_layers_0 = len(self.layer_indexes) - self.num_layers_1 num_layers = len(self.layer_indexes) // intra_stages self.layer_ranges = [] for i in range(intra_stages): @@ -481,16 +432,8 @@ def __init__( for i in range(intra_stages): start, end = self.layer_ranges[i] - num_intra_layers = end - start - self.backend_decoders[i].setWeights( - 3 + (num_intra_layers) * 2, self.op_id, *op_parameters[start * 5:end * 5] - ) - with FileLock(f"decoder_run.lock"): - backend_lib.run(self.backend_decoders[i]._mm) + self.backend_decoders[i].set_weights(self.op_id, op_parameters[start * 5:end * 5]) - self.kv_cache_c_parameter_handel = [] - self.kv_cache_parameters = [] - self.kv_cache_prefetched = False def forward( self, @@ -507,100 +450,43 @@ def forward( attention_mask, position_ids, ) + + for i in range(self.intra_stages): + start, end = self.layer_ranges[i] + self.backend_decoders[i].update_cache(past_key_value, self.layer_indexes[start:end]) - if len(self.kv_cache_parameters) > 0: - # the case kv cache changed - cached_prt = self.kv_cache_parameters[0].storage().data_ptr() - current_ptr = past_key_value.key_cache[self.layer_indexes[0]].storage().data_ptr() - if cached_prt != current_ptr: - self.kv_cache_parameters = [] - self.kv_cache_c_parameter_handel = [] - self.kv_cache_prefetched = False - - if len(self.kv_cache_parameters) == 0: - for idx in self.layer_indexes: - past_key = past_key_value.key_cache[idx] - past_value = past_key_value.value_cache[idx] - - invalidInputError( - past_key.dtype == torch.float16, f"past_key dtype is {past_key.dtype}" - ) - - new_size = (past_key.size(0), past_key.size(1), self.max_seq_len, past_key.size(3)) - past_key = past_key.as_strided(new_size, past_key.stride(), storage_offset=0) - invalidInputError(past_key.is_contiguous(), "past_key is not contiguous") - past_value = past_value.as_strided(new_size, past_value.stride(), storage_offset=0) - if self.transpose_value: - past_value = past_value.transpose(-1, -2) - invalidInputError(past_value.is_contiguous(), "past_value is not contiguous") - - self.kv_cache_parameters.append(past_key) - self.kv_cache_parameters.append(past_value) - - for i in range(self.intra_stages): - start, end = self.layer_ranges[i] - layer_kv_cache = self.kv_cache_parameters[start * 2:end * 2] - layer_kv_cache = [p.numpy() for p in layer_kv_cache] - handle = self.backend_decoders[i].create_parameters(layer_kv_cache) - self.kv_cache_c_parameter_handel.append(handle) - - x_np = [elem.to(torch.float16).numpy() for elem in inputs] - - with record_function(f"npu_factory"): - if not self.kv_cache_prefetched: - for i in range(self.intra_stages): - self.backend_decoders[i].load_wt_fn( - len(inputs), - self.backend_decoders[i]._mm, - self.kv_cache_c_parameter_handel[i], - ) - - array_type = ctypes.POINTER(ctypes.c_char) * self.intra_stages - models_ptr = array_type( - *[self.backend_decoders[i]._mm for i in range(self.intra_stages)] - ) - inputs_ptr = (ctypes.c_void_p * 3)( - x_np[0].ctypes.data_as(ctypes.c_void_p), - x_np[1].ctypes.data_as(ctypes.c_void_p), - x_np[2].ctypes.data_as(ctypes.c_void_p), - ) - t0 = time.perf_counter() - backend_lib.run_decoders(models_ptr, inputs_ptr, self.intra_stages, 3) - t1 = time.perf_counter() - - hidden_states = self.backend_decoders[-1].torch_out[0] + hidden_states, new_keys, new_values = 
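The multi-decoder above partitions its layer_indexes evenly across the intra_pp stages and hands each backend decoder the weights for its slice (five parameter groups per layer, hence the start * 5:end * 5 indexing). A small sketch of that range computation; the loop body is cut off by the hunk, so having the last stage absorb any remainder is an assumption:

def make_layer_ranges(layer_indexes, intra_stages):
    # Split the handled layers into contiguous, roughly equal slices,
    # one slice per intra-stage backend decoder.
    num_layers = len(layer_indexes) // intra_stages
    layer_ranges = []
    for i in range(intra_stages):
        start = i * num_layers
        end = (i + 1) * num_layers if i < intra_stages - 1 else len(layer_indexes)
        layer_ranges.append((start, end))
    return layer_ranges

# e.g. 16 layers on this rank with 2 intra stages -> [(0, 8), (8, 16)]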
LowBitLlamaMultiDecoderlayer.run_decoders( + inputs, + decoders=self.backend_decoders) if self.do_print: print("outputs:", hidden_states) outputs = (hidden_states,) - outputs += (past_key_value,) - return outputs, t1 - t0 + outputs += (past_key_value, new_keys, new_values) + return outputs - def post_forward(self, past_key_value): + def post_forward(self, past_key_value, new_keys, new_values): key_value_states = [] for i in range(self.intra_stages): for j in range(1, len(self.backend_decoders[i].torch_out)): key_value_states.append(self.backend_decoders[i].torch_out[j]) cache_kwargs = { - # "cache_position": cache_position, "max_seq_len": self.max_seq_len, "transpose": self.transpose_value, } + for i in range(len(self.layer_indexes)): key_states, value_states = past_key_value.update( - key_value_states[2 * i], - key_value_states[2 * i + 1], + new_keys[i], + new_values[i], self.layer_indexes[i], cache_kwargs, ) for i in range(self.intra_stages): - self.backend_decoders[i].load_wt_fn( - 3, self.backend_decoders[i]._mm, self.kv_cache_c_parameter_handel[i] - ) - self.kv_cache_prefetched = True + self.backend_decoders[i].load_cache_async() class FusedLlamaLowBitDecoderlayer(torch.nn.Module): @@ -660,8 +546,6 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - # cache_position: Optional[torch.LongTensor] = None, - # **kwargs, ) -> torch.Tensor: """Torch module forward method. @@ -681,7 +565,6 @@ def forward( inputs, self.op_parameters, backend_cls, self.op_id, replica=2 ) cache_kwargs = { - # "cache_position": cache_position, "max_seq_len": self.max_seq_len, "transpose": self.transpose_value, } @@ -806,7 +689,7 @@ def run_decode( padded_causal_mask[:, :, :, -1] = 0.0 dist.recv(hidden_states, src=rank - 1) t1 = time.perf_counter() - layer_outputs, elapse = multi_decoder( + layer_outputs = multi_decoder( hidden_states, attention_mask=padded_causal_mask, position_ids=position_ids, @@ -819,7 +702,10 @@ def run_decode( t3 = time.perf_counter() dist.send(hidden_states, dst=(rank + 1) % world_size) t4 = time.perf_counter() - multi_decoder.post_forward(past_key_values) + past_key_values = layer_outputs[1] + new_keys = layer_outputs[2] + new_values = layer_outputs[3] + multi_decoder.post_forward(past_key_values, new_keys, new_values) class DecodeRunner: From 1d5b3e429c5f97e502a57cfc37771a6d6a8ccec8 Mon Sep 17 00:00:00 2001 From: jenniew Date: Tue, 27 Aug 2024 14:15:19 -0700 Subject: [PATCH 5/8] merge --- .../NPU/HF-Transformers-AutoModels/LLM/baichuan2.py | 13 ------------- .../ipex_llm/transformers/npu_models/convert_mp.py | 9 --------- 2 files changed, 22 deletions(-) diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/baichuan2.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/baichuan2.py index bfd1e983369..f3f4cb109f6 100644 --- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/baichuan2.py +++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/baichuan2.py @@ -46,24 +46,15 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]], parser.add_argument( "--repo-id-or-model-path", type=str, -<<<<<<< HEAD - default="meta-llama/Llama-2-7b-chat-hf", - help="The huggingface repo id for the Llama2 model to be downloaded" -======= default="baichuan-inc/Baichuan2-7B-Chat", help="The huggingface repo id for the Baichuan2 model to be downloaded" ->>>>>>> 460bc96d3263f043f0df35c87e2731995db42f3f ", or the path to the huggingface checkpoint folder", ) parser.add_argument('--prompt', type=str, 
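The decode worker refactored above is one stage of an inter-layer pipeline: it receives hidden states from the previous rank, runs its slice of decoder layers on the NPU, forwards the result to the next rank, and only then folds the returned key/value states back into the fused cache. A stripped-down sketch of that loop (hypothetical helper; the real worker also handles control messages, mask padding, and timing):

import torch
import torch.distributed as dist

def decode_stage_loop(multi_decoder, past_key_values, hidden_size, rank, world_size,
                      attention_mask, position_ids, steps):
    hidden_states = torch.empty((1, 1, hidden_size), dtype=torch.float16)
    for _ in range(steps):
        dist.recv(hidden_states, src=rank - 1)           # output of the previous stage
        layer_outputs = multi_decoder(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_values,
        )
        hidden_states = layer_outputs[0]
        dist.send(hidden_states, dst=(rank + 1) % world_size)
        # fold the new KV states into the cache only after the send, off the critical path
        multi_decoder.post_forward(layer_outputs[1], layer_outputs[2], layer_outputs[3])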
default="What is AI?", help='Prompt to infer') parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict") parser.add_argument("--max-output-len", type=int, default=1024) -<<<<<<< HEAD - parser.add_argument("--max-prompt-len", type=int, default=768) -======= parser.add_argument("--max-prompt-len", type=int, default=512) ->>>>>>> 460bc96d3263f043f0df35c87e2731995db42f3f parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False) parser.add_argument("--intra-pp", type=int, default=2) parser.add_argument("--inter-pp", type=int, default=2) @@ -77,11 +68,7 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]], trust_remote_code=True, attn_implementation="eager", load_in_low_bit="sym_int4", -<<<<<<< HEAD - enable_mp=True, -======= optimize_model=True, ->>>>>>> 460bc96d3263f043f0df35c87e2731995db42f3f max_output_len=args.max_output_len, max_prompt_len=args.max_prompt_len, intra_pp=args.intra_pp, diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py index 7ba65d5bcca..5e75508513e 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py @@ -140,16 +140,12 @@ def optimize_llm( prefill_runner=prefill_runner, decode_runner=decode_runner ) convert_forward(model, module.MiniCPMModel, minicpm_model_forward) -<<<<<<< HEAD - elif model.config.model_type == "baichuan": -======= elif model.config.model_type == "baichuan" and model.config.num_hidden_layers == 32: # for Baichuan2-7B if intra_pp is None: intra_pp = 2 if inter_pp is None: inter_pp = 2 ->>>>>>> 460bc96d3263f043f0df35c87e2731995db42f3f from ipex_llm.transformers.npu_models.baichuan_mp import gen_baichuan_fused_model_forward from ipex_llm.transformers.npu_models.baichuan_mp import DecodeRunner, PrefillRunner decode_runner = DecodeRunner( @@ -166,13 +162,8 @@ def optimize_llm( transpose_value_cache=transpose_value_cache, ) baichuan_model_forward = gen_baichuan_fused_model_forward( -<<<<<<< HEAD - prefill_runner=prefill_runner, decode_runner=decode_runner - ) -======= prefill_runner=prefill_runner, decode_runner=decode_runner ) ->>>>>>> 460bc96d3263f043f0df35c87e2731995db42f3f modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) convert_forward(model, module.BaichuanModel, baichuan_model_forward) From c9df4ec8db1de1a5253f1440706db40fe9b79ab7 Mon Sep 17 00:00:00 2001 From: jenniew Date: Tue, 27 Aug 2024 15:01:52 -0700 Subject: [PATCH 6/8] style --- .../transformers/npu_models/baichuan_mp.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py index 149eddfa831..1f8eedea7ff 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py @@ -50,6 +50,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory + @torch.no_grad() def run_model( x: Union[torch.Tensor, List[torch.Tensor]], @@ -177,7 +178,6 @@ def __init__( position_ids = self.create_input_op((self.batch_size, self.seq_len)) # self.num_key_value_heads = num_key_value_heads - past_keys = [] past_values = [] if mode == "decode": @@ -265,7 +265,13 @@ def attention(self, k_bias=None, 
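The resolved convert_mp.py branch only takes the fused NPU path for Baichuan2-7B (32 hidden layers) and fills in pipeline defaults when the caller passes None. Taken in isolation, that gate amounts to roughly the following sketch:

def should_use_baichuan_mp(config, intra_pp=None, inter_pp=None):
    # Only Baichuan2-7B (32 decoder layers) is wired up for the fused NPU path here;
    # both pipeline depths default to 2 when the caller does not set them.
    if config.model_type == "baichuan" and config.num_hidden_layers == 32:
        intra_pp = 2 if intra_pp is None else intra_pp
        inter_pp = 2 if inter_pp is None else inter_pp
        return True, intra_pp, inter_pp
    return False, intra_pp, inter_pp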
v_bias=None): hidden_size = num_heads * head_dim - proj = self.linear(hidden_states, 3 * hidden_size, hidden_size, bias=False, wt_dtype=self.dtype) + proj = self.linear( + hidden_states, + 3 * hidden_size, + hidden_size, + bias=False, + wt_dtype=self.dtype + ) proj = self.reshape(proj, [-1, 3, hidden_size]) # b*s, 3, h proj = self.unsqueeze(proj, [0]) # b, s, 3, h proj = self.transpose(proj, [2, 1, 0, 3]) # 3, s, b, h @@ -349,7 +355,7 @@ def build_decoder( head_dim=self.head_dim, seq_len=self.seq_len, ) - + hidden_states = self.eltwise_add(residual, attn_output) residual = hidden_states hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight) @@ -437,7 +443,6 @@ def __init__( start, end = self.layer_ranges[i] self.backend_decoders[i].set_weights(self.op_id, op_parameters[start * 5:end * 5]) - def forward( self, hidden_states: torch.Tensor, @@ -453,7 +458,7 @@ def forward( attention_mask, position_ids, ) - + for i in range(self.intra_stages): start, end = self.layer_ranges[i] self.backend_decoders[i].update_cache(past_key_value, self.layer_indexes[start:end]) @@ -479,7 +484,7 @@ def post_forward(self, past_key_value, new_keys, new_values): "max_seq_len": self.max_seq_len, "transpose": self.transpose_value, } - + for i in range(len(self.layer_indexes)): key_states, value_states = past_key_value.update( new_keys[i], From bc2b7067e94057245b49d996eefab5430ef254e7 Mon Sep 17 00:00:00 2001 From: jenniew Date: Tue, 27 Aug 2024 15:03:20 -0700 Subject: [PATCH 7/8] update --- python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py index 1f8eedea7ff..4f60f6d0936 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py @@ -271,7 +271,7 @@ def attention(self, hidden_size, bias=False, wt_dtype=self.dtype - ) + ) proj = self.reshape(proj, [-1, 3, hidden_size]) # b*s, 3, h proj = self.unsqueeze(proj, [0]) # b, s, 3, h proj = self.transpose(proj, [2, 1, 0, 3]) # 3, s, b, h From d925efc1b00ec9559c3a6ead7245e2c4b71d9388 Mon Sep 17 00:00:00 2001 From: jenniew Date: Tue, 27 Aug 2024 23:17:35 -0700 Subject: [PATCH 8/8] update --- .../transformers/npu_models/baichuan_mp.py | 78 +++---------------- 1 file changed, 9 insertions(+), 69 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py index 4f60f6d0936..25cb790db99 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py @@ -48,71 +48,11 @@ import torch.multiprocessing as mp from transformers.cache_utils import Cache from transformers.modeling_outputs import BaseModelOutputWithPast +from ipex_llm.transformers.npu_models.mp_models_base import run_model from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory -@torch.no_grad() -def run_model( - x: Union[torch.Tensor, List[torch.Tensor]], - weights: List[torch.Tensor], - backend_cls: Any, - op_id: str, - replica: int = 1, -) -> torch.Tensor: - global _model_cache - import time - - t0 = time.perf_counter() - - # Use or not op_id depending on the class used - op_kwargs = {"op_id": op_id} if op_id else {} - - if not isinstance(x, (list, tuple)): - x = [x] - - # Reshape input - input_dtype = x[0].dtype - x_np 
= [set_contiguous(elem).to(torch.float16).numpy() for elem in x] - op_args = [] - op_args_flatten = [] - for w in weights: - if isinstance(w, tuple): # from QuantizedLinear - op_args.append((set_contiguous(w[0]).numpy(), set_contiguous(w[1]).numpy())) - op_args_flatten.append(op_args[-1][0]) - op_args_flatten.append(op_args[-1][1]) - else: - op_args.append(set_contiguous(w).to(torch.float16).numpy()) - op_args_flatten.append(op_args[-1]) - - shape_dtype_signature = "_".join( - ["_".join(str(dim) for dim in t.shape) + f"_{t.dtype}" for t in x_np + op_args_flatten] - ) - key = f"{backend_cls.func.__name__}_{shape_dtype_signature}" - models = _model_cache.get(key, None) - - input_shapes = [elem.shape for elem in x_np] - if models is None: - _model_cache[key] = deque([backend_cls(*input_shapes) for i in range(replica)]) - elif len(models) < 1: - _model_cache[key].append(backend_cls(*input_shapes)) - else: - _model_cache[key].rotate(1) - - # Get the model - model = _model_cache[key][0] - - with record_function(f"npu_factory_mul_{key}"): - ret = model.run(x_np, *op_args, **op_kwargs) - - if isinstance(ret, list): - results = [adapt_output_tensor(r, r.shape, input_dtype) for r in ret] - else: - results = adapt_output_tensor(ret, ret.shape, input_dtype) - - return results - - -class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory): +class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory): def __init__( self, # batch_size: int, @@ -366,7 +306,7 @@ def build_decoder( return hidden_states, new_key_states, new_value_states -class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module): +class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module): def __init__( self, @@ -421,7 +361,7 @@ def __init__( start, end = self.layer_ranges[i] lm_0 = input_laynorm_weights[start:end] lm_1 = post_attn_layernorm_weights[start:end] - decoder = LowBitLlamaMultiDecoderlayer( + decoder = LowBitBaichuanMultiDecoderlayer( [1, 1, num_heads * head_dim], input_layernorm_weights=lm_0, post_attn_layernorm_weights=lm_1, @@ -463,7 +403,7 @@ def forward( start, end = self.layer_ranges[i] self.backend_decoders[i].update_cache(past_key_value, self.layer_indexes[start:end]) - hidden_states, new_keys, new_values = LowBitLlamaMultiDecoderlayer.run_decoders( + hidden_states, new_keys, new_values = LowBitBaichuanMultiDecoderlayer.run_decoders( inputs, decoders=self.backend_decoders) @@ -497,7 +437,7 @@ def post_forward(self, past_key_value, new_keys, new_values): self.backend_decoders[i].load_cache_async() -class FusedLlamaLowBitDecoderlayer(torch.nn.Module): +class FusedBaichuanLowBitDecoderlayer(torch.nn.Module): """LLAMA MLP operation NPU backend.""" def __init__( @@ -528,7 +468,7 @@ def __init__( np_dtype = np.float16 self.backend_cls_prefill = partial( - LowBitLlamaMultiDecoderlayer, + LowBitBaichuanMultiDecoderlayer, num_heads=num_heads, # num_key_value_heads=num_key_value_heads, num_layers=1, @@ -643,7 +583,7 @@ def run_decode( input_layer_norm_weights.append(layer_norm_0) post_attn_layernorm_weights.append(layer_norm_1) - multi_decoder = FusedLlamaLowBitMultiDecoderlayer( + multi_decoder = FusedBaichuanLowBitMultiDecoderlayer( parameters=layer_weights, input_laynorm_weights=input_layer_norm_weights, post_attn_layernorm_weights=post_attn_layernorm_weights, @@ -846,7 +786,7 @@ def run_prefill( layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16) layer_norm_1 = curr_layer.post_attention_layernorm.weight.to(torch.float16) - new_decoderlayer = FusedLlamaLowBitDecoderlayer( + new_decoderlayer = FusedBaichuanLowBitDecoderlayer( 
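The run_model helper deleted above (now imported from mp_models_base) keeps a pool of compiled NPU graphs keyed by the shapes and dtypes of the inputs, so repeated calls with the same signature reuse an existing replica instead of recompiling. Reduced to its core, the caching idea looks like the sketch below; the real helper also flattens quantized (weight, scale) tuples into the key and adapts outputs back to torch tensors:

from collections import deque

_model_cache = {}

def get_compiled_model(backend_cls, input_shapes, input_dtypes, replica=1):
    # Key the cache on the backend factory plus every input shape/dtype, so a new
    # graph is only built when the calling signature actually changes.
    signature = "_".join(
        "x".join(str(d) for d in shape) + f"_{dtype}"
        for shape, dtype in zip(input_shapes, input_dtypes)
    )
    key = f"{backend_cls.__name__}_{signature}"
    models = _model_cache.get(key)
    if models is None:
        _model_cache[key] = deque(backend_cls(*input_shapes) for _ in range(replica))
    else:
        # Round-robin over the compiled replicas instead of recompiling.
        models.rotate(1)
    return _model_cache[key][0]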
weights, num_heads=num_heads, # num_key_value_heads=num_key_value_heads,