diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2_disaggregate_prefill.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2_disaggregate_prefill.py
index f6692051f5f6..943096fff31d 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2_disaggregate_prefill.py
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2_disaggregate_prefill.py
@@ -806,9 +806,9 @@ def pipeline_parallel_generate(self,
     if self.device.type == 'xpu':
         torch.xpu.synchronize()
     self.rest_cost_mean = np.mean(self.next_token_time)
-    return output_ids, _past_key_values
-
+    return output_ids, _past_key_values, model_kwargs
+import types


 def run_decode(model, rank, world_size, layer_start, layer_end,
                max_seq_len, transpose_value_cache, input_queue, result_queue):
@@ -829,6 +829,14 @@ def run_decode(model, rank, world_size, layer_start, layer_end,
     #     trust_remote_code=True, attn_implementation="eager",
     #     load_in_low_bit="sym_int4", pipeline_parallel_stages=world_size)
+
+    from ipex_llm.transformers.npu_models.pipeline_parallel import pipeline_parallel, pipeline_parallel_generate
+    model = pipeline_parallel(model, 2,
+                              torch.float16, device="cpu")
+
+    # add pipeline_parallel_generate to pretrained model dynamically
+    model.pipeline_parallel_generate = types.MethodType(pipeline_parallel_generate,
+                                                        model)

     from ipex_llm.transformers.npu_models.convert import optimize_llm, optimize_llm_post
     optimize_llm(model)
@@ -882,6 +890,12 @@ def run_decode(model, rank, world_size, layer_start, layer_end,
         transpose_value=transpose_value_cache
     )

+    if rank == 0:
+        print(model)
+    dist.barrier()
+    if rank == 1:
+        print(model)
+
     model.model.multi_decoder = multi_decoder
     result_queue.put("loading success")
@@ -893,12 +907,12 @@ def run_decode(model, rank, world_size, layer_start, layer_end,
         result = input_queue.get()
         if result == "stop":
             break
-        input_ids, past_key_value, n_predict = result
-        output = model.generate(input_ids, num_beams=1, do_sample=False, max_new_tokens=n_predict, past_key_values=past_key_value)
+        input_ids, model_kwargs, n_predict = result
+        model_kwargs.pop("max_new_tokens", None)
+        output = model.generate(input_ids, **model_kwargs, num_beams=1, do_sample=False, max_new_tokens=n_predict)
         result_queue.put(output)

-
 def run_prefill(model, max_seq_len, transpose_value_cache, input_queue, result_queue):
@@ -970,20 +984,20 @@ def run_prefill(model, max_seq_len, transpose_value_cache, input_queue, result_q

     with torch.inference_mode():
-        output, past_key_value = generate(model, input_ids, num_beams=1, do_sample=False, max_new_tokens=n_predict)
+        output, past_key_value, model_kwargs = generate(model, input_ids, num_beams=1, do_sample=False, max_new_tokens=n_predict)

-    result_queue.put((output, past_key_value))
+    result_queue.put((output, past_key_value, model_kwargs))


 import torch.multiprocessing as mp

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for npu model')
-    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
+    parser.add_argument('--repo-id-or-model-path', type=str, default="D:\llm-models\Llama-2-7b-chat-hf",
                         help='The huggingface repo id for the Llama2 model to be downloaded'
                              ', or the path to the huggingface checkpoint folder')
     parser.add_argument('--prompt', type=str, default="Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun",
                         help='Prompt to infer')
-    parser.add_argument('--n-predict', type=int, default=1,
+    parser.add_argument('--n-predict', type=int, default=32,
                         help='Max tokens to predict')
     parser.add_argument('--max-seq-len', type=int, default=1024)
     parser.add_argument('--transpose-value-cache', action="store_true", default=False)
@@ -1030,20 +1044,21 @@ def run_prefill(model, max_seq_len, transpose_value_cache, input_queue, result_q
     output = prefill_result_queue.get()
-    print(Fore.RED + f"prefill process output: {output}")
+    print(Fore.GREEN + f"prefill process output: {output}")
     print(Style.RESET_ALL)

-    prefill_input_queue.put((input_ids, args.n_predict, args.max_seq_len, args.transpose_value_cache))
-    prefill_output, past_key_value = prefill_result_queue.get()
+    prefill_input_queue.put((input_ids, 1, args.max_seq_len, args.transpose_value_cache))
+    prefill_output, past_key_value, model_kwargs = prefill_result_queue.get()

     print("finish prefill")
     print("output tokens", prefill_output)
     print("past_key_value", past_key_value)
+    print("model_kwargs", model_kwargs)

     decode_input_ids = prefill_output[:, -1:]

-    decode_input_queue_0.put((decode_input_ids, past_key_value, args.n_predict))
-    decode_input_queue_1.put((decode_input_ids, past_key_value, args.n_predict))
+    decode_input_queue_0.put((decode_input_ids, model_kwargs, args.n_predict))
+    decode_input_queue_1.put((decode_input_ids, model_kwargs, args.n_predict))

     output0 = decode_result_queue_0.get()
     output1 = decode_result_queue_1.get()
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py
index bc6153c2576f..68f4046858e6 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py
@@ -41,7 +41,6 @@ def wrapper(model: torch.nn.Module, qtype, device, *args, **kwargs):
         """
         for name, layer in model.named_children():
-            print(f"converting layers {name}")
             new_layer = func(layer, qtype, device, *args, **kwargs)
             if new_layer:
                 model.add_module(name, new_layer)
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama.py b/python/llm/src/ipex_llm/transformers/npu_models/llama.py
index a322d731e51c..a81baa586e53 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/llama.py
@@ -233,6 +233,9 @@ def llama_fused_model_forward(
     if position_ids is None:
         position_ids = cache_position.unsqueeze(0)

+    # print("attention_mask", attention_mask)
+    # print("cache_position", cache_position)
+    # print('past_seen_tokens', past_seen_tokens)
     causal_mask = self._update_causal_mask(attention_mask, inputs_embeds,
                                            cache_position, past_seen_tokens)
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/npu_models/pipeline_parallel.py
index 69614439338c..7a2fa19e7c85 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/pipeline_parallel.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/pipeline_parallel.py
@@ -325,13 +325,18 @@ def pipeline_parallel_generate(self,
         if _input_ids is None:
             _input_ids = input_ids
-
-        model_inputs = self.prepare_inputs_for_generation(output_ids, **model_kwargs)
-
+
+        model_inputs = self.prepare_inputs_for_generation(_input_ids, **model_kwargs)
+        if local_rank == 1:
+            from colorama import Fore, Back, Style
+            # print(Fore.GREEN + f"model_inputs: ", model_inputs, Style.RESET_ALL)
+            # print(Fore.GREEN + f"_input_ids: ", _input_ids, Style.RESET_ALL)
+
         tic = time.time()
         if local_rank == 0:
             outputs = self(**model_inputs)
         else:
+            # _input_ids = model_inputs.pop("input_ids")
             _inputs_shape = _input_ids.shape + (self.config.hidden_size,)
             if step == 0 and self.config.model_type == "chatglm" \
                     and hasattr(self.config, "vision_config"):