From e243e7601f962fb87e80caaee8e056a80d63cd61 Mon Sep 17 00:00:00 2001
From: plusbang
Date: Fri, 23 Aug 2024 09:37:04 +0800
Subject: [PATCH 1/3] add

---
 .../src/ipex_llm/transformers/npu_model.py   |  3 +-
 .../transformers/npu_models/common.py        | 10 +++
 .../transformers/npu_models/convert_mp.py    |  7 ++
 .../transformers/npu_models/llama_mp.py      | 79 +++++++++++++++++++
 4 files changed, 98 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index 601c7cd1aac..63e16d971d3 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -148,7 +148,7 @@ def from_pretrained(cls, *args, **kwargs):
                     " than max_output_len ({max_output_len})"
                 ),
             )
-            from ipex_llm.transformers.npu_models.convert_mp import optimize_llm
+            from ipex_llm.transformers.npu_models.convert_mp import optimize_llm, optimize_llm_post

             with torch.no_grad():
                 cls.load_convert(qtype, model, "cpu", *args, **kwargs)
@@ -166,6 +166,7 @@ def from_pretrained(cls, *args, **kwargs):
                 intra_pp=intra_pp,
                 transpose_value_cache=transpose_value_cache,
             )
+            optimize_llm_post(model)
         else:
             from ipex_llm.transformers.npu_models.convert import optimize_llm
             optimize_llm(model)
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/common.py b/python/llm/src/ipex_llm/transformers/npu_models/common.py
index bb08b1abea5..32841838d6d 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/common.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/common.py
@@ -30,3 +30,13 @@ def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
     new_linear.in_features = new_weight.size(1)
     new_linear.out_features = new_weight.size(0)
     return new_linear
+
+
+def reshape_lm_head_input(x):
+    if x.dim() > 3:
+        x = x.reshape([-1, x.shape[-2], x.shape[-1]])
+    shape = list(x.size())
+    if shape[1] > 10:
+        shape[1] = 1
+        x = x[:, -1, :].view(shape)
+    return x
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
index 150788be4ec..6dc1671e667 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
@@ -77,3 +77,10 @@ def optimize_llm(
             prefill_runner=prefill_runner, decode_runner=decode_runner
         )
         convert_forward(model, Qwen2Model, qwen2_model_forward)
+
+
+def optimize_llm_post(model: torch.nn.Module):
+    if model.config.model_type == "llama":
+        from transformers.models.llama.modeling_llama import LlamaForCausalLM
+        from ipex_llm.transformers.npu_models.llama_mp import llama2_casullm_forward
+        convert_forward(model, LlamaForCausalLM, llama2_casullm_forward)
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
index 0e6d113cae3..46c4236f2f1 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -39,6 +39,9 @@
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from ipex_llm.transformers.npu_models.mp_models_base import run_model
 from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
+from ipex_llm.transformers.npu_models.common import reshape_lm_head_input
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from torch.nn import CrossEntropyLoss


 class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
@@ -944,3 +947,79 @@ def llama_fused_model_forward(
         )

     return llama_fused_model_forward
+
+
+def llama2_casullm_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None \
+        else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+    # ipex-llm change start
+    hidden_states = reshape_lm_head_input(hidden_states)
+    # ipex-llm change end
+    if self.config.pretraining_tp > 1:
+        lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp,
+                                                   dim=0)
+        logits = [F.linear(hidden_states, lm_head_slices[i])
+                  for i in range(self.config.pretraining_tp)]
+        logits = torch.cat(logits, dim=-1)
+    else:
+        logits = self.lm_head(hidden_states)
+    logits = logits.float()
+
+    loss = None
+    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        shift_labels = shift_labels.view(-1)
+        # Enable model parallelism
+        shift_labels = shift_labels.to(shift_logits.device)
+        loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )

From 5bb1ed93c06898bb9d21c74b69fe87513a6d5f7b Mon Sep 17 00:00:00 2001
From: plusbang
Date: Fri, 23 Aug 2024 10:17:23 +0800
Subject: [PATCH 2/3] fix and add qwen2

---
 .../src/ipex_llm/transformers/npu_model.py   |  3 +-
 .../transformers/npu_models/convert_mp.py    | 13 ++--
 .../transformers/npu_models/qwen2_mp.py      | 72 +++++++++++++++++++
 3 files changed, 79 insertions(+), 9 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index 63e16d971d3..601c7cd1aac 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -148,7 +148,7 @@ def from_pretrained(cls, *args, **kwargs):
                     " than max_output_len ({max_output_len})"
                 ),
             )
-            from ipex_llm.transformers.npu_models.convert_mp import optimize_llm, optimize_llm_post
+            from ipex_llm.transformers.npu_models.convert_mp import optimize_llm

             with torch.no_grad():
                 cls.load_convert(qtype, model, "cpu", *args, **kwargs)
@@ -166,7 +166,6 @@ def from_pretrained(cls, *args, **kwargs):
                 intra_pp=intra_pp,
                 transpose_value_cache=transpose_value_cache,
             )
-            optimize_llm_post(model)
         else:
             from ipex_llm.transformers.npu_models.convert import optimize_llm
             optimize_llm(model)
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
index 6dc1671e667..7056f1f9923 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
@@ -54,6 +54,9 @@ def optimize_llm(
             prefill_runner=prefill_runner, decode_runner=decode_runner
         )
         convert_forward(model, LlamaModel, llama_model_forward)
+        from transformers.models.llama.modeling_llama import LlamaForCausalLM
+        from ipex_llm.transformers.npu_models.llama_mp import llama2_casullm_forward
+        convert_forward(model, LlamaForCausalLM, llama2_casullm_forward)
     elif model.config.model_type == "qwen2" and model.config.intermediate_size == 8960:
         # for qwen2-1.5B
         from ipex_llm.transformers.npu_models.qwen2_mp import gen_qwen2_fused_model_forward
@@ -77,10 +80,6 @@ def optimize_llm(
             prefill_runner=prefill_runner, decode_runner=decode_runner
         )
         convert_forward(model, Qwen2Model, qwen2_model_forward)
-
-
-def optimize_llm_post(model: torch.nn.Module):
-    if model.config.model_type == "llama":
-        from transformers.models.llama.modeling_llama import LlamaForCausalLM
-        from ipex_llm.transformers.npu_models.llama_mp import llama2_casullm_forward
-        convert_forward(model, LlamaForCausalLM, llama2_casullm_forward)
+        from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
+        from ipex_llm.transformers.npu_models.qwen2_mp import qwen2_casullm_forward
+        convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward)
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
index 7a61ad9d24b..35d18947215 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
@@ -39,6 +39,9 @@
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from ipex_llm.transformers.npu_models.mp_models_base import run_model
 from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
+from ipex_llm.transformers.npu_models.common import reshape_lm_head_input
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from torch.nn import CrossEntropyLoss


 class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
@@ -981,3 +984,72 @@ def qwen2_fused_model_forward(
         )

     return qwen2_fused_model_forward
+
+
+def qwen2_casullm_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None \
+        else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+    # ipex-llm change start
+    hidden_states = reshape_lm_head_input(hidden_states)
+    # ipex-llm change end
+    logits = self.lm_head(hidden_states)
+    logits = logits.float()
+
+    loss = None
+    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        shift_labels = shift_labels.view(-1)
+        # Enable model parallelism
+        shift_labels = shift_labels.to(shift_logits.device)
+        loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )

From 60cc2607111d93c9c6f4f2e8d26dd32bf8971773 Mon Sep 17 00:00:00 2001
From: plusbang
Date: Fri, 23 Aug 2024 10:31:49 +0800
Subject: [PATCH 3/3] fix

---
 python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
index 35d18947215..ec5e701fd4b 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
@@ -1019,7 +1019,7 @@ def qwen2_casullm_forward(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
-        cache_position=cache_position,
+        # cache_position=cache_position,
     )

     hidden_states = outputs[0]
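
Note on the optimization shared by all three patches: the new reshape_lm_head_input helper trims the hidden states fed to lm_head down to the last token whenever the sequence length exceeds 10 (i.e. during prefill), so the logits projection in llama2_casullm_forward and qwen2_casullm_forward runs over a single position instead of the whole prompt. The following standalone sketch illustrates the intended shape behaviour; the tensor sizes are illustrative examples only and are not taken from the patches.

import torch


def reshape_lm_head_input(x):
    # Collapse any extra leading dimensions to [batch, seq_len, hidden_size].
    if x.dim() > 3:
        x = x.reshape([-1, x.shape[-2], x.shape[-1]])
    shape = list(x.size())
    # For long (prefill) sequences keep only the last position, since only
    # the next-token logits are needed from lm_head.
    if shape[1] > 10:
        shape[1] = 1
        x = x[:, -1, :].view(shape)
    return x


prefill_states = torch.randn(1, 128, 2048)   # hypothetical 128-token prompt
decode_states = torch.randn(1, 1, 2048)      # hypothetical single decode step
print(reshape_lm_head_input(prefill_states).shape)  # torch.Size([1, 1, 2048])
print(reshape_lm_head_input(decode_states).shape)   # torch.Size([1, 1, 2048])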