diff --git a/python/llm/src/ipex_llm/transformers/layers/rope_embedding.py b/python/llm/src/ipex_llm/transformers/layers/rope_embedding.py
deleted file mode 100644
index be03c5abb48..00000000000
--- a/python/llm/src/ipex_llm/transformers/layers/rope_embedding.py
+++ /dev/null
@@ -1,67 +0,0 @@
-#
-# Copyright 2016 The BigDL Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import torch
-import logging
-from ipex_llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
-from ipex_llm.utils.common import invalidInputError
-
-LOG = logging.getLogger("ipex_llm.rope_embedding")
-
-
-# Fast RoPE for finetuning, split the q and k
-def apply_fast_rope_embedding(q, k, position_ids, model_family):
-    if q.device.type != "xpu":
-        invalidInputError(False,
-                          f"only xpu is supported in this function")
-    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
-                        "mixtral"]:
-        q_embed = FastRopeEmbedding.apply(q, position_ids)
-        k_embed = FastRopeEmbedding.apply(k, position_ids)
-        return q_embed, k_embed
-    else:
-        invalidInputError(False,
-                          f"{model_family} is not supported.")
-
-
-# Fast RoPE for finetuning, split the q and k
-class FastRopeEmbedding(torch.autograd.Function):
-    @staticmethod
-    @custom_fwd
-    def forward(ctx, x, position_ids):
-        import xe_addons
-        x_embed = torch.empty(x.shape, dtype=x.dtype, device=x.device)
-        xe_addons.apply_rotary_embedding_half_q_or_k(x, position_ids,
-                                                     x_embed, False)
-        ctx.save_for_backward(position_ids)
-        return x_embed
-
-    @staticmethod
-    @custom_bwd
-    def backward(ctx, grad_output):
-        import xe_addons
-        # LOG.info(f"backward, grad_output: {grad_output}")
-        position_ids, = ctx.saved_tensors
-        dx = torch.empty(grad_output.shape,
-                         dtype=grad_output.dtype,
-                         device=grad_output.device)
-        xe_addons.apply_rotary_embedding_half_q_or_k(grad_output,
-                                                     position_ids,
-                                                     dx,
-                                                     True)
-        # LOG.info(f"backward, dx: {dx}, position_ids: {position_ids},
-        #          requires_grad: {ctx.needs_input_grad}")
-        return dx, None
diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 4adbfcbdcc5..cafcaedfaec 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -2500,126 +2500,6 @@ def custom_forward(*inputs):
     )
 
 
-# For training
-def llama_attention_fast_forward(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-    padding_mask: Optional[torch.LongTensor] = None,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
-    device = hidden_states.device
-    use_fast_rope = should_use_fast_rope(self, hidden_states, position_ids)
-
-    # Check for inference
-    if use_cache and past_key_value is not None and q_len == 1:
-        A, past_key_value = llama_attention_forward_4_31(
-            self,
-            hidden_states,
-            past_key_value,
-            position_ids,
-        )
-        return A, None, past_key_value
-
-    if self.config.pretraining_tp > 1:
-        key_value_slicing = ((self.num_key_value_heads * self.head_dim) //
-                             self.config.pretraining_tp)
-        query_slices = self.q_proj.weight.split(
-            (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-        )
-        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-        query_states = [F.linear(hidden_states, query_slices[i])
-                        for i in range(self.config.pretraining_tp)]
-        query_states = torch.cat(query_states, dim=-1)
-
-        key_states = [F.linear(hidden_states, key_slices[i])
-                      for i in range(self.config.pretraining_tp)]
-        key_states = torch.cat(key_states, dim=-1)
-
-        value_states = [F.linear(hidden_states, value_slices[i])
-                        for i in range(self.config.pretraining_tp)]
-        value_states = torch.cat(value_states, dim=-1)
-
-    else:
-        if hasattr(self, "q_proj"):
-            query_states = self.q_proj(hidden_states)
-            key_states = self.k_proj(hidden_states)
-            value_states = self.v_proj(hidden_states)
-        else:
-            qkv = self.qkv_proj(hidden_states)
-            qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
-            query_states, key_states, value_states = qkv.split([self.num_heads,
-                                                                self.num_key_value_heads,
-                                                                self.num_key_value_heads], dim=2)
-
-    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
-                                 self.head_dim).transpose(1, 2)
-    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
-                                     self.head_dim).transpose(1, 2)
-
-    kv_seq_len = key_states.shape[-2]
-    if past_key_value is not None:
-        kv_seq_len += past_key_value[0].shape[-2]
-
-    if use_fast_rope:
-        from ipex_llm.transformers.layers.rope_embedding import apply_fast_rope_embedding
-        query_states, key_states = apply_fast_rope_embedding(query_states,
-                                                             key_states,
-                                                             position_ids,
-                                                             "llama")
-    else:
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                        cos, sin, position_ids, "llama")
-
-    if past_key_value is not None:
-        # reuse k, v, self_attention
-        key_states = torch.cat([past_key_value[0], key_states], dim=2)
-        value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-    past_key_value = (key_states, value_states) if use_cache else None
-
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-    cache_position = None
-    attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                           attention_mask, cache_position,
-                                           bsz, q_len, kv_seq_len,
-                                           self.head_dim, self.num_heads, output_attentions)
-
-    attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
-    if attn_output.size() != attn_output_size:
-        invalidInputError(False,
-                          f"`attn_output` should be of size {attn_output_size},"
-                          f" but is {attn_output.size()}")
-
-    attn_output = attn_output.transpose(1, 2).contiguous()
-
-    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-    if self.config.pretraining_tp > 1:
-        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
-        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp,
-                                                 dim=1)
-        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i])
-                           for i in range(self.config.pretraining_tp)])
-    else:
-        attn_output = self.o_proj(attn_output)
-
-    if not output_attentions:
-        attn_weights = None
-
-    return attn_output, attn_weights, past_key_value
-
-
 def llama_model_forward_4_41_internal(
     self,
     input_ids: torch.LongTensor = None,
diff --git a/python/llm/src/ipex_llm/transformers/qlora.py b/python/llm/src/ipex_llm/transformers/qlora.py
index a9696f9c699..1af0cf1849e 100644
--- a/python/llm/src/ipex_llm/transformers/qlora.py
+++ b/python/llm/src/ipex_llm/transformers/qlora.py
@@ -296,7 +296,6 @@ def get_peft_model(*args, **kwargs):
 
     if model.device.type == "xpu":
         cast_lora_weight(model, torch.bfloat16)
-        _optimize_post(model)
         torch.xpu.synchronize()
 
     return model
@@ -390,18 +389,3 @@ def cast_lora_weight(model, dtype=torch.bfloat16):
             if hasattr(module, 'weight'):
                 if module.weight.dtype == torch.float32:
                     module = module.to(dtype)
-
-
-def _optimize_post(model):
-    import transformers
-    from packaging import version
-    from ipex_llm.transformers.convert import convert_forward
-    from ipex_llm.transformers.models.llama import llama_attention_fast_forward
-
-    trans_version = transformers.__version__
-    if version.parse(trans_version) >= version.parse("4.31.0"):
-        LOG.info("Optimizing Llama finetuning....")
-        convert_forward(
-            model,
-            transformers.models.llama.modeling_llama.LlamaAttention,
-            llama_attention_fast_forward,)