diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2_disaggregate_prefill.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2_disaggregate_prefill.py index 5013b574bd7..ff17fa3d31b 100644 --- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2_disaggregate_prefill.py +++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2_disaggregate_prefill.py @@ -45,6 +45,7 @@ from transformers.utils import logging logger = logging.get_logger(__name__) import gc +from colorama import Fore, Back, Style @torch.no_grad() @@ -257,11 +258,9 @@ def build_decoder(self, hidden_states, attention_mask, position_ids, # input layernorm input_2d = self.convert_to_fp32(input_2d) - # variance = self.reduce_mean(self.eltwise_mul(input_2d, input_2d), -1, keep_dims=True) variance = self.reduce_mean(self.power(input_2d, self.constant(np.array([[2]], dtype=np.float32))), -1, keep_dims=True) eps = self.constant(self.rms_norm_eps) input_2d = self.eltwise_div(input_2d, self.sqrt(self.eltwise_add(variance, eps))) - # input_layernorm_weight = self.constant(input_layernorm_weight) input_layernorm_weight = self.convert_to_fp32(input_layernorm_weight) input_2d = self.eltwise_mul(input_layernorm_weight, input_2d) input_2d = self.convert_to_fp16(input_2d) @@ -270,12 +269,6 @@ def build_decoder(self, hidden_states, attention_mask, position_ids, query_states = self.linear(input_2d, self.num_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=self.dtype) key_states = self.linear(input_2d, self.num_key_value_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=self.dtype) value_states = self.linear(input_2d, self.num_key_value_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=self.dtype) - - # cos = self.constant(self.cached_cos) - # cos = self.unsqueeze(cos, axis=0) - - # sin = self.constant(self.cached_sin) - # sin = self.unsqueeze(sin, axis=0) query_states = self.reshape(query_states, [self.batch_size, self.seq_len, self.num_heads, self.head_dim]) key_states = self.reshape(key_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim]) @@ -385,9 +378,12 @@ def __init__( intermediate_size, max_seq_len: int = 1024, transpose_value: bool = False, + do_print: bool = False ): super().__init__() + self.do_print = do_print + op_parameters = [] for w in parameters: if isinstance(w, tuple): # from QuantizedLinear @@ -456,14 +452,9 @@ def __init__( self.backend_cls_decode_1.setWeights(3+self.num_layers_1*2, self.op_id, *op_parameters[self.num_layers_0*7:]) - print("weight setted") backend_lib.run(self.backend_cls_decode_1._mm) - - self.op_parameters = None - gc.collect() - print("2nd inference done") self.kv_cache_c_parameter_handel = (None, None) @@ -488,17 +479,11 @@ def forward(self, Returns: torch.Tensor: result """ - seq_len = hidden_states.shape[1] - pad_len = self.max_seq_len + 1 - attention_mask.size(-1) - - pad_mask = (0, pad_len) - padded_attention_mask = F.pad(attention_mask.to(torch.float16), pad_mask, - value=torch.finfo(torch.float16).min) - padded_attention_mask[:,:,:,-1] = 0.0 + inputs = (hidden_states.to(torch.float16), - padded_attention_mask, - position_ids, - ) + attention_mask, + position_ids, + ) if self.kv_cache_parameters is None: self.kv_cache_parameters = [] @@ -509,7 +494,6 @@ def forward(self, cached_prt = self.kv_cache_parameters[0].storage().data_ptr() current_ptr = past_key_value.key_cache[self.layer_indexes[0]].storage().data_ptr() if cached_prt != current_ptr: - # print("kv cache changed") 
self.kv_cache_parameters = [] self.kv_cache_c_parameter_handel = (None, None) self.kv_cache_prefetched = False @@ -550,31 +534,35 @@ def forward(self, models_ptr = (ctypes.POINTER(ctypes.c_char) * 2)(self.backend_cls_decode_0._mm, self.backend_cls_decode_1._mm) inputs_ptr = (ctypes.c_void_p * 3)(x_np[0].ctypes.data_as(ctypes.c_void_p), x_np[1].ctypes.data_as(ctypes.c_void_p), x_np[2].ctypes.data_as(ctypes.c_void_p)) - + t0 = time.perf_counter() backend_lib.run_decoders(models_ptr, inputs_ptr, 2, 3) + t1 = time.perf_counter() + hidden_states = self.backend_cls_decode_1.torch_out[0] + + if self.do_print: + print("outputs:", hidden_states) + + outputs = (hidden_states,) + outputs += (past_key_value,) + return outputs, t1 - t0 + + def post_forward(self, past_key_value, cache_position): + key_value_states = [] for i in range(1, len(self.backend_cls_decode_0.torch_out)): key_value_states.append(self.backend_cls_decode_0.torch_out[i]) for i in range(1, len(self.backend_cls_decode_1.torch_out)): key_value_states.append(self.backend_cls_decode_1.torch_out[i]) - - hidden_states = self.backend_cls_decode_1.torch_out[0] - cache_kwargs = {"cache_position": cache_position, "max_seq_len":self.max_seq_len, "transpose": self.transpose_value} for i in range(len(self.layer_indexes)): key_states, value_states = past_key_value.update(key_value_states[2*i], key_value_states[2*i+1], self.layer_indexes[i], cache_kwargs) - - self.backend_cls_decode_0.load_wt_fn(len(inputs), self.backend_cls_decode_0._mm, self.kv_cache_c_parameter_handel[0]) - self.backend_cls_decode_1.load_wt_fn(len(inputs), self.backend_cls_decode_1._mm, self.kv_cache_c_parameter_handel[1]) + self.backend_cls_decode_0.load_wt_fn(3, self.backend_cls_decode_0._mm, self.kv_cache_c_parameter_handel[0]) + self.backend_cls_decode_1.load_wt_fn(3, self.backend_cls_decode_1._mm, self.kv_cache_c_parameter_handel[1]) self.kv_cache_prefetched = True - outputs = (hidden_states,) - outputs += (past_key_value,) - return outputs - class FusedLlamaLowBitDecoderlayer(torch.nn.Module): """LLAMA MLP operation NPU backend.""" @@ -641,8 +629,6 @@ def forward(self, torch.Tensor: result """ assert not output_attentions - # assert cache_position is None - # assert use_cache seq_len = hidden_states.shape[1] assert seq_len > 1, "seq_len must be 1 for decode mode" @@ -650,9 +636,7 @@ def forward(self, backend_cls = self.backend_cls_prefill inputs = (hidden_states.to(torch.float16), attention_mask, position_ids) inputs += (self.layer_norm_0, self.layer_norm_1) - # print("start run_model prefill") hidden_states, past_key, past_value = run_model(inputs, self.op_parameters, backend_cls, self.op_id, replica=2) - # print("end run model prefill") cache_kwargs = {"cache_position": cache_position, "max_seq_len":self.max_seq_len, "transpose": self.transpose_value} key_states, value_states = past_key_value.update(past_key, past_value, self.layer_idx, cache_kwargs) @@ -660,156 +644,9 @@ def forward(self, outputs += (past_key_value,) return outputs -from typing import Callable, List, Optional, Union, Tuple -from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList -@torch.no_grad() -def generate( - self, - inputs: Optional[torch.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]]=None, - synced_gpus: Optional[bool] = None, - assistant_model: 
Optional["PreTrainedModel"] = None, - streamer: Optional["BaseStreamer"] = None, - **kwargs, -): - # priority: `generation_config` argument > `model.generation_config` - if generation_config is None: - if ( - self.generation_config._from_model_config - and self.generation_config._original_object_hash == hash(self.generation_config) - and self.config._has_non_default_generation_parameters() - ): - new_generation_config = GenerationConfig.from_model_config(self.config) - if new_generation_config != self.generation_config: - self.generation_config = new_generation_config - generation_config = self.generation_config - - if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: - eos_token_id = generation_config.eos_token_id - if isinstance(eos_token_id, list): - eos_token_id = eos_token_id[0] - logger.warning("Setting `pad_token_id` to `eos_token_id`: " - f"{eos_token_id} for open-end generation.") - generation_config.pad_token_id = eos_token_id - - if generation_config is not None and generation_config.max_new_tokens is not None: - max_new_tokens = generation_config.pop("max_new_tokens") - else: - max_new_tokens = kwargs.pop("max_new_tokens", None) - - return pipeline_parallel_generate(self=self, - inputs=inputs, - max_new_tokens=max_new_tokens, - generation_config=generation_config, - **kwargs) - - -@torch.no_grad() -def pipeline_parallel_generate(self, - inputs: Optional[torch.Tensor] = None, - max_new_tokens: int = 32, - generation_config: Optional[GenerationConfig] = None, - **kwargs): - model_kwargs = generation_config.update(**kwargs) - inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( - inputs, generation_config.bos_token_id, model_kwargs - ) - bs = inputs_tensor.shape[0] - if model_kwargs.get("attention_mask", None) is None: - model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id) - if self.config.is_encoder_decoder: - input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( - batch_size=bs, - model_input_name=model_input_name, - model_kwargs=model_kwargs, - decoder_start_token_id=generation_config.decoder_start_token_id, - bos_token_id=generation_config.bos_token_id, - device=inputs_tensor.device, - ) - else: - input_ids = inputs_tensor if model_input_name == "input_ids" \ - else model_kwargs.pop("input_ids") - - self.first_token_time = 0 - self.next_token_time = [] - - pad_token_id = generation_config.pad_token_id - eos_token_id = generation_config.eos_token_id - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) \ - if eos_token_id is not None else None - - _input_ids = None - _past_key_values = None - - bs = input_ids.shape[0] - output_ids = input_ids.clone() - - step = 0 - # keep track of which sequences are already finished - unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) - this_peer_finished = False - while True: - if step >= max_new_tokens: - break - - if _input_ids is None: - _input_ids = input_ids - - model_inputs = self.prepare_inputs_for_generation(output_ids, **model_kwargs) - - tic = time.time() - outputs = self(**model_inputs) - logits = outputs.logits - next_ids = torch.argmax(logits[:, -1:, :], dim=-1) - _input_ids = next_ids - output_ids = torch.cat([output_ids, next_ids], dim=-1) - - model_kwargs = self._update_model_kwargs_for_generation( - outputs, 
model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - - # finished sentences should have their next token be a padding token - next_ids = next_ids.squeeze() - if eos_token_id is not None: - if pad_token_id is None: - invalidInputError(False, "If `eos_token_id` is defined, " - "make sure that `pad_token_id` is defined.") - next_ids = next_ids * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) - - _past_key_values = outputs.past_key_values - - toc = time.time() - if step == 0: - self.first_token_time = toc - tic - else: - self.next_token_time.append(toc - tic) - - # if eos_token was found in one sentence, set sentence to finished - if eos_token_id_tensor is not None: - unfinished_sequences = unfinished_sequences.mul( - next_ids.tile(eos_token_id_tensor.shape[0], 1) - .ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) - ) - # stop when each sentence is finished - if unfinished_sequences.max() == 0: - this_peer_finished = True - if this_peer_finished: - break - - step += 1 - if self.device.type == 'xpu': - torch.xpu.synchronize() - self.rest_cost_mean = np.mean(self.next_token_time) - return output_ids, _past_key_values, model_kwargs - +import time import types -def run_decode(model, rank, world_size, layer_start, layer_end, +def run_decode(model, rank, world_size, port, layer_start, layer_end, max_seq_len, transpose_value_cache, input_queue, result_queue): @@ -825,21 +662,6 @@ def run_decode(model, rank, world_size, layer_start, layer_end, my_size = dist.get_world_size() logger.info(f"rank: {my_rank}, size: {my_size}") - # model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, - # trust_remote_code=True, attn_implementation="eager", - # load_in_low_bit="sym_int4", pipeline_parallel_stages=world_size) - - - from ipex_llm.transformers.npu_models.pipeline_parallel import pipeline_parallel, pipeline_parallel_generate - model = pipeline_parallel(model, 2, - torch.float16, device="cpu") - - # add pipeline_parallel_generate to pretrained model dynamically - model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate, - model) - from ipex_llm.transformers.npu_models.convert import optimize_llm, optimize_llm_post - optimize_llm(model) - num_heads = model.model.layers[layer_start].self_attn.num_heads num_key_value_heads = model.model.layers[layer_start].self_attn.num_key_value_heads head_dim = model.model.layers[layer_start].self_attn.head_dim @@ -887,40 +709,150 @@ def run_decode(model, rank, world_size, layer_start, layer_end, rms_norm_eps=rms_norm_eps, intermediate_size=intermediate_size, max_seq_len=max_seq_len, - transpose_value=transpose_value_cache + transpose_value=transpose_value_cache, + do_print=False, # layer_start == 0, ) - if rank == 0: - print(model) - dist.barrier() - if rank == 1: - print(model) - - model.model.multi_decoder = multi_decoder - result_queue.put("loading success") + past_key_values = None + control = torch.empty((), dtype=torch.int) + hidden_states = torch.empty((1, 1, head_dim*num_heads), dtype=torch.float16) with torch.inference_mode(): while True: - - result = input_queue.get() - if result == "stop": + + dist.broadcast(control, src=0) + if control.item() == -2: break - input_ids, model_kwargs, n_predict = result - model_kwargs.pop("max_new_tokens", None) - output = model.generate(input_ids, **model_kwargs, num_beams=1, do_sample=False, max_new_tokens=n_predict) - result_queue.put(output) + elif control.item() == -1: + past_key_values = input_queue.get() + else: + t0 = time.perf_counter() + 
past_seen_tokens = past_key_values.get_seq_length() + attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64) + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + 1, + device=hidden_states.device) + + position_ids = position_ids = cache_position.unsqueeze(0) + causal_mask = model.model._update_causal_mask(attention_mask, hidden_states, + cache_position, past_seen_tokens) + pad_len = multi_decoder.max_seq_len + 1 - causal_mask.size(-1) + + pad_mask = (0, pad_len) + padded_causal_mask = F.pad(causal_mask.to(torch.float16), pad_mask, + value=torch.finfo(torch.float16).min) + padded_causal_mask[:,:,:,-1] = 0.0 + dist.recv(hidden_states, src=rank - 1) + t1 = time.perf_counter() + layer_outputs, elapse = multi_decoder(hidden_states, + attention_mask=padded_causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=False, + use_cache=True, + cache_position=cache_position,) + t2 = time.perf_counter() + hidden_states = layer_outputs[0] + t3 = time.perf_counter() + dist.send(hidden_states, dst=(rank + 1)%world_size) + t4 = time.perf_counter() + multi_decoder.post_forward(past_key_values, cache_position) +import time +class DecodeRunner: + def __init__(self, model, max_seq_len, transpose_value_cache): + self.model = model + self.max_seq_len = max_seq_len + self.transpose_value_cache = transpose_value_cache + + port = '54791' + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = port + os.environ['RANK'] = '0' + os.environ['WORLD_SIZE'] = '3' + + + + self.decode_result_queue_0 = mp.Queue() + self.decode_input_queue_0 = mp.Queue() + self.decode_p_0 = mp.Process(target=run_decode, args=(self.model, + 1, 3, port, + 0, 16, + self.max_seq_len, + self.transpose_value_cache, + self.decode_input_queue_0, + self.decode_result_queue_0)) + self.decode_p_0.start() + + self.decode_result_queue_1 = mp.Queue() + self.decode_input_queue_1 = mp.Queue() + self.decode_p_1 = mp.Process(target=run_decode, args=(self.model, + 2, 3, port, + 16, 32, + self.max_seq_len, + self.transpose_value_cache, + self.decode_input_queue_1, + self.decode_result_queue_1)) + self.decode_p_1.start() + + dist.init_process_group() + my_rank = dist.get_rank() + my_size = dist.get_world_size() + logger.info(f"rank: {my_rank}, size: {my_size}") + + output = self.decode_result_queue_0.get() + print(Fore.GREEN + f"decode process 0 output: {output}") + print(Style.RESET_ALL) + + output = self.decode_result_queue_1.get() + print(Fore.GREEN + f"decode process 1 output: {output}") + print(Style.RESET_ALL) + + self.cache_past_key_value = None -def run_prefill(model, max_seq_len, transpose_value_cache, input_queue, result_queue): + def forward(self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs,): + t0 = time.perf_counter() + + if self.cache_past_key_value != past_key_value: + control = torch.tensor(-1, dtype=torch.int) + dist.broadcast(control, src=0) + self.decode_input_queue_0.put(past_key_value) + self.decode_input_queue_1.put(past_key_value) + + control = torch.tensor(0, dtype=torch.int) + dist.broadcast(control, src=0) + hidden_states = hidden_states.to(torch.float16) + dist.send(hidden_states, dst=1) + past_key_value.expand() + dist.recv(hidden_states, src=2) + t1 = time.perf_counter() + return hidden_states, 
past_key_value + + def shutdown(self): + control = torch.tensor(-2, dtype=torch.int) + dist.broadcast(control, src=0) + self.decode_p_0.join() + self.decode_p_1.join() + def __del__(self): + self.shutdown() + +def run_prefill(model, max_seq_len, transpose_value_cache, input_queue, result_queue): print("finish loading prefill model") - from ipex_llm.transformers.npu_models.convert import optimize_llm, optimize_llm_post + # from ipex_llm.transformers.npu_models.convert import optimize_llm, optimize_llm_post - optimize_llm(model) + # optimize_llm(model) layer_start = 0 layer_end = len(model.model.layers) @@ -971,6 +903,7 @@ def run_prefill(model, max_seq_len, transpose_value_cache, input_queue, result_q input_layer_norm_weights.append(layer_norm_0) post_attn_layernorm_weights.append(layer_norm_1) model.model.layers[layer_idx] = new_decoderlayer + deocderlayers.append(new_decoderlayer) print("finish creating all decode layers in prefill") result_queue.put("loading finish") @@ -980,14 +913,180 @@ def run_prefill(model, max_seq_len, transpose_value_cache, input_queue, result_q result = input_queue.get() if result == "stop": break - input_ids, n_predict, max_seq_len, transpose_value_cache = result + + hidden_states, position_ids, causal_mask, past_key_values, cache_position = result with torch.inference_mode(): - + for decoder_layer in deocderlayers: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=False, + use_cache=True, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + next_decoder_cache = layer_outputs[1] + + result_queue.put((hidden_states, next_decoder_cache)) + + +class PrefillRunner: + def __init__(self, model, max_seq_len, transpose_value_cache): + self.model = model + self.max_seq_len = max_seq_len + self.transpose_value_cache = transpose_value_cache - output, past_key_value, model_kwargs = generate(model, input_ids, num_beams=1, do_sample=False, max_new_tokens=n_predict) + self.prefill_result_queue = mp.Queue() + self.prefill_input_queue = mp.Queue() - result_queue.put((output, past_key_value, model_kwargs)) + self.p = mp.Process(target=run_prefill, args=(model, + args.max_seq_len, + args.transpose_value_cache, + self.prefill_input_queue, + self.prefill_result_queue)) + self.p.start() + output = self.prefill_result_queue.get() + print(Fore.GREEN + f"prefill process output: {output}") + print(Style.RESET_ALL) + def forward(self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs,): + self.prefill_input_queue.put((hidden_states, position_ids, attention_mask, past_key_value, cache_position)) + return self.prefill_result_queue.get() + + def shutdown(self): + self.prefill_input_queue.put("stop") + self.p.join() + + def __del__(self): + self.shutdown() + + +from transformers.cache_utils import Cache +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.modeling_llama import repeat_kv, apply_rotary_pos_emb +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP +def gen_llama_fused_model_forward(prefill_runner, decode_runner): + def llama_fused_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] 
= None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + t0 = time.perf_counter() + output_attentions = ( + output_attentions if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + invalidInputError(False, + ("You cannot specify both input_ids and inputs_embeds at the same time, " + "and must specify either one")) + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + past_seen_tokens = 0 + + # ipex-llm changes start + from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache + if use_cache and not isinstance(past_key_values, DynamicFusedNormalCache): + past_key_values = DynamicFusedNormalCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + + if cache_position is None: + cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device) + # ipex-llm changes end + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, + cache_position, past_seen_tokens) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + seq_len = hidden_states.size(1) + + if seq_len == 1: + layers_runner = decode_runner + else: + layers_runner = prefill_runner + layer_outputs = layers_runner.forward(hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position,) + hidden_states = layer_outputs[0] + + next_decoder_cache = layer_outputs[1] + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # ipex-llm changes start + next_cache = next_decoder_cache if use_cache else None + # ipex-llm changes end + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, + all_hidden_states, all_self_attns] if v is not None) + t1 = time.perf_counter() + # print("fused model forward time: ", t1 - t0) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return llama_fused_model_forward + +def convert_forward(m, target_m, new_forward): + if m.__class__ == target_m: + bound_method = new_forward.__get__(m, m.__class__) + setattr(m, "forward", bound_method) + for _, sub_m in m.named_children(): + convert_forward(sub_m, target_m, new_forward) + +from 
transformers.models.llama.modeling_llama import LlamaModel import torch.multiprocessing as mp if __name__ == "__main__": @@ -997,7 +1096,7 @@ def run_prefill(model, max_seq_len, transpose_value_cache, input_queue, result_q ', or the path to the huggingface checkpoint folder') parser.add_argument('--prompt', type=str, default="Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun", help='Prompt to infer') - parser.add_argument('--n-predict', type=int, default=32, + parser.add_argument('--n-predict', type=int, default=64, help='Max tokens to predict') parser.add_argument('--max-seq-len', type=int, default=1024) parser.add_argument('--transpose-value-cache', action="store_true", default=False) @@ -1035,71 +1134,44 @@ def run_prefill(model, max_seq_len, transpose_value_cache, input_queue, result_q prompt = prompt_1024 input_ids = tokenizer.encode(prompt, return_tensors="pt") - input_ids = input_ids[:, :950] - - from colorama import Fore, Back, Style - - decode_result_queue_0 = mp.Queue() - decode_input_queue_0 = mp.Queue() - decode_p_0 = mp.Process(target=run_decode, args=(model, 0, 2, 0, 16, args.max_seq_len, args.transpose_value_cache, decode_input_queue_0, decode_result_queue_0)) - decode_p_0.start() - - decode_result_queue_1 = mp.Queue() - decode_input_queue_1 = mp.Queue() - decode_p_1 = mp.Process(target=run_decode, args=(model, 1, 2, 16, 32, args.max_seq_len, args.transpose_value_cache, decode_input_queue_1, decode_result_queue_1)) - decode_p_1.start() + input_ids = input_ids[:, :args.max_seq_len - args.n_predict] - output = decode_result_queue_0.get() - print(Fore.GREEN + f"decode process 0 output: {output}") - print(Style.RESET_ALL) - output = decode_result_queue_1.get() - print(Fore.GREEN + f"decode process 1 output: {output}") - print(Style.RESET_ALL) + decode_runner = DecodeRunner(model, args.max_seq_len, args.transpose_value_cache) + prefill_runner = PrefillRunner(model, args.max_seq_len, args.transpose_value_cache) - prefill_result_queue = mp.Queue() - prefill_input_queue = mp.Queue() + llama_model_forward = gen_llama_fused_model_forward(prefill_runner=prefill_runner, + decode_runner=decode_runner) + convert_forward(model, LlamaModel, llama_model_forward) - p = mp.Process(target=run_prefill, args=(model, args.max_seq_len, args.transpose_value_cache, prefill_input_queue , prefill_result_queue)) - p.start() + from ipex_llm.utils.benchmark_util_4_29 import BenchmarkWrapper - output = prefill_result_queue.get() + model = BenchmarkWrapper(model, do_print=True) - print(Fore.GREEN + f"prefill process output: {output}") - print(Style.RESET_ALL) - - prefill_input_queue.put((input_ids, 1, args.max_seq_len, args.transpose_value_cache)) - prefill_output, past_key_value, model_kwargs = prefill_result_queue.get() - - print("finish prefill") - print("output tokens", prefill_output) - print("past_key_value", past_key_value) - print("model_kwargs", model_kwargs) - - decode_input_ids = prefill_output[:, -1:] - - decode_input_queue_0.put((decode_input_ids, model_kwargs, args.n_predict)) - decode_input_queue_1.put((decode_input_ids, model_kwargs, args.n_predict)) - - output0 = decode_result_queue_0.get() - output1 = decode_result_queue_1.get() + with torch.inference_mode(): + # input_ids = tokenizer.encode(prompt, return_tensors="pt") + print("finish to load") + print('input length:', len(input_ids[0])) + for i in range(3): + st = time.time() + output = model.generate(input_ids, num_beams=1, do_sample=False, 
max_new_tokens=args.n_predict) + end = time.time() + print(f'Inference time: {end-st} s') - print("output0", output0) - print("output1", output1) + print('-'*20, 'Input', '-'*20) + input_str = tokenizer.decode(input_ids[0], skip_special_tokens=False) + print(input_str) - out_str = tokenizer.decode(input_ids[0], skip_special_tokens=False) - print("prompt: ", out_str) - output_str = tokenizer.decode(output0[0], skip_special_tokens=False) - print("output: ", output_str) + output_str = tokenizer.decode(output[0], skip_special_tokens=False) + print('-'*20, 'Output', '-'*20) + print(output_str) - prefill_input_queue.put("stop") - decode_input_queue_0.put("stop") - decode_input_queue_1.put("stop") + print('-'*80) + print('done') - p.join() - decode_p_0.join() - decode_p_1.join() + decode_runner.shutdown() + prefill_runner.shutdown() print("success shut down") diff --git a/python/llm/src/ipex_llm/transformers/npu_models/kv.py b/python/llm/src/ipex_llm/transformers/npu_models/kv.py index 36260e8e3e8..55bef013fdb 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/kv.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/kv.py @@ -91,6 +91,30 @@ def append_fused_kv_cache(cache_k, cache_v, key_states, value_states, transpose_ new_cache_v[:, :, :, start:end] = value_states return new_cache_k, new_cache_v.transpose(-1, -2) +def expand_fused_kv_cache(cache_k, cache_v, transpose_value=False): + if not transpose_value: + new_size = (cache_k.size(0), + cache_k.size(1), + cache_k.size(2) + 1, + cache_k.size(3)) + new_cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0) + new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0) + return new_cache_k, new_cache_v + else: + new_size_key = (cache_k.size(0), + cache_k.size(1), + cache_k.size(2) + 1, + cache_k.size(3)) + new_cache_k = cache_k.as_strided(new_size_key, cache_k.stride(), storage_offset=0) + new_size_value = (cache_v.size(0), + cache_v.size(1), + cache_v.size(3), + cache_v.size(2) + 1, + ) + raw_cache_v = cache_v.transpose(-1, -2) + new_cache_v = raw_cache_v.as_strided(new_size_value, raw_cache_v.stride(), storage_offset=0) + return new_cache_k, new_cache_v.transpose(-1, -2) + class DynamicFusedNormalCache(DynamicCache): # Experimental support for fused decoderlayer implementation on NPU @@ -147,6 +171,12 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: for idx, layer in self.key_cache.items(): return layer.shape[-2] + + def expand(self): + for idx, layer in self.key_cache.items(): + key_cache, value_cache = expand_fused_kv_cache(self.key_cache[idx], self.value_cache[idx]) + self.key_cache[idx] = key_cache + self.value_cache[idx] = value_cache @property def _seen_tokens(self):
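For reference, the input-layernorm block that `build_decoder` emits in this patch (reduce_mean of the squared input, divide by sqrt(variance + eps), multiply by the layernorm weight) is a standard RMSNorm computed in fp32 and cast back to fp16. A minimal PyTorch restatement of the same math; the helper name and signature are illustrative, not part of the patch:

```python
import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Same sequence as the graph ops: upcast to fp32, variance = mean(x ** 2),
    # divide by sqrt(variance + eps), scale by the layernorm weight, cast back to fp16.
    x32 = x.to(torch.float32)
    variance = x32.pow(2).mean(dim=-1, keepdim=True)
    x32 = x32 / torch.sqrt(variance + eps)
    return (weight.to(torch.float32) * x32).to(torch.float16)
```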
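The decode workers in `run_decode` are driven by a small control protocol over `dist.broadcast`: rank 0 sends a scalar int, where -2 means shut down, -1 means a fresh `past_key_values` object is about to arrive on the worker's multiprocessing queue, and anything else triggers one decode step. A condensed sketch of the worker side of that loop, with the per-step work collapsed into a placeholder callable:

```python
import torch
import torch.distributed as dist

def decode_worker_loop(input_queue, run_one_step):
    # run_one_step is a placeholder for the real per-token work in run_decode.
    control = torch.empty((), dtype=torch.int)
    past_key_values = None
    while True:
        dist.broadcast(control, src=0)       # rank 0 steers every worker
        if control.item() == -2:             # shut down
            break
        elif control.item() == -1:           # a new KV cache arrives out of band
            past_key_values = input_queue.get()
        else:                                # run exactly one decode step
            run_one_step(past_key_values)
```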
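Causal-mask padding, which used to live in `LowBitMultiDecoderlayer.forward`, now happens once per step inside `run_decode`: the mask is widened to the fixed `max_seq_len + 1` width the NPU graph expects, the padding is filled with the most negative fp16 value so those positions are masked out, and the last column (the current token) is reset to 0. The same transform as a standalone helper, assuming a `[batch, 1, q_len, kv_len]` mask:

```python
import torch
import torch.nn.functional as F

def pad_causal_mask(causal_mask: torch.Tensor, max_seq_len: int) -> torch.Tensor:
    # Widen the kv dimension to max_seq_len + 1 with fully-masked positions.
    pad_len = max_seq_len + 1 - causal_mask.size(-1)
    padded = F.pad(causal_mask.to(torch.float16), (0, pad_len),
                   value=torch.finfo(torch.float16).min)
    padded[:, :, :, -1] = 0.0    # the current token stays attendable
    return padded
```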
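`DecodeRunner.forward` is the rank-0 driver side of the same protocol: on a cache change it broadcasts -1 and hands the cache object to both workers over their queues, then per token it broadcasts 0, sends fp16 hidden states to rank 1, grows the cache view with `past_key_value.expand()`, and receives the result back from rank 2. A condensed sketch of one driver step, assuming the process group and worker queues are already set up as in `__init__`:

```python
import torch
import torch.distributed as dist

def driver_step(hidden_states, past_key_value, worker_queues, cache_changed: bool):
    if cache_changed:
        # Announce a fresh cache object, then hand it to every decode worker.
        dist.broadcast(torch.tensor(-1, dtype=torch.int), src=0)
        for q in worker_queues:
            q.put(past_key_value)
    dist.broadcast(torch.tensor(0, dtype=torch.int), src=0)   # "run one decode step"
    hidden_states = hidden_states.to(torch.float16)
    dist.send(hidden_states, dst=1)       # feed the first pipeline stage
    past_key_value.expand()               # expose the cache slot the workers will fill
    dist.recv(hidden_states, src=2)       # collect the result from the last stage
    return hidden_states
```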
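`PrefillRunner` talks to `run_prefill` with a plain request/response pair of multiprocessing queues: `forward` puts the prefill inputs on the input queue and blocks on the result queue, while the worker loops over all decoder layers and pushes `(hidden_states, cache)` back. A toy, runnable illustration of that pattern; the echo worker stands in for the real prefill loop and nothing here comes from the patch:

```python
import multiprocessing as mp

def echo_worker(in_q, out_q):
    # Stand-in for run_prefill: block on a request, compute, push the answer back.
    while True:
        request = in_q.get()
        if request == "stop":
            break
        out_q.put(request)     # the real worker runs every decoder layer here

if __name__ == "__main__":
    in_q, out_q = mp.Queue(), mp.Queue()
    worker = mp.Process(target=echo_worker, args=(in_q, out_q))
    worker.start()
    in_q.put("hello")
    print(out_q.get())         # synchronous request/response, like PrefillRunner.forward
    in_q.put("stop")
    worker.join()
```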
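`convert_forward` swaps in the generated forward via the descriptor protocol: `new_forward.__get__(m, m.__class__)` turns a plain function into a method bound to that particular instance, so `self` is filled in automatically when `generate` later calls the model's `forward`. A tiny self-contained demonstration of the same trick on a toy class (nothing here comes from the patch):

```python
class Greeter:
    def __init__(self, name):
        self.name = name

def shout(self):
    return f"HELLO, {self.name.upper()}!"

g = Greeter("llama")
g.greet = shout.__get__(g, Greeter)   # bind the free function to this one instance
print(g.greet())                      # -> HELLO, LLAMA!
```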
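In the `__main__` section the prompt is now truncated to `max_seq_len - n_predict` tokens (instead of the previous hard-coded 950) so that the prompt plus every generated token fits inside the fixed-size KV buffers the graphs were compiled for. A hedged helper expressing the same budget check; the script itself just slices `input_ids` directly:

```python
def clamp_prompt(input_ids, max_seq_len: int, n_predict: int):
    # Leave room for n_predict generated tokens inside the preallocated cache.
    budget = max_seq_len - n_predict
    return input_ids[:, :budget] if input_ids.size(1) > budget else input_ids
```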
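Finally, `expand_fused_kv_cache` in kv.py grows only the *view* of a preallocated cache buffer by one token using `as_strided`; because the storage already has room for `max_seq_len` entries, no copy takes place and the decode workers can write the new K/V into the slot the driver just exposed via `DynamicFusedNormalCache.expand`. A small runnable demonstration of the viewing trick on a dummy buffer (shapes are illustrative):

```python
import torch

# Dummy preallocated cache: [batch, heads, max_seq_len, head_dim], 5 tokens "in use".
buf = torch.zeros(1, 2, 16, 4)
cache_k = buf.as_strided((1, 2, 5, 4), buf.stride())

# Expose one more token slot without copying, as in the non-transposed branch
# of expand_fused_kv_cache:
new_size = (cache_k.size(0), cache_k.size(1), cache_k.size(2) + 1, cache_k.size(3))
new_cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0)

assert new_cache_k.data_ptr() == buf.data_ptr()   # same storage, no allocation
assert new_cache_k.shape == (1, 2, 6, 4)
```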