
Commit

remove old rope usage (#12552)
MeouSker77 authored Dec 16, 2024
1 parent a86487c commit 5ae0006
Showing 3 changed files with 0 additions and 203 deletions.
67 changes: 0 additions & 67 deletions python/llm/src/ipex_llm/transformers/layers/rope_embedding.py

This file was deleted.
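
The deleted rope_embedding.py supplied a fused "fast rope" kernel (apply_fast_rope_embedding) used on XPU during training. For context, the sketch below shows a minimal, generic rotary position embedding in the rotate_half convention used by Llama-style models; it only illustrates what such a kernel computes and is not the deleted implementation.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the head dimension in half and rotate the halves: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_sketch(q: torch.Tensor, k: torch.Tensor,
                      cos: torch.Tensor, sin: torch.Tensor,
                      position_ids: torch.LongTensor):
    # q, k: (bsz, num_heads, seq_len, head_dim); cos/sin: (max_seq_len, head_dim).
    cos = cos[position_ids].unsqueeze(1)  # -> (bsz, 1, seq_len, head_dim)
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed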

120 changes: 0 additions & 120 deletions python/llm/src/ipex_llm/transformers/models/llama.py
@@ -2500,126 +2500,6 @@ def custom_forward(*inputs):
)


# For training
def llama_attention_fast_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
device = hidden_states.device
use_fast_rope = should_use_fast_rope(self, hidden_states, position_ids)

    # Inference decode path: a single new token with an existing KV cache
if use_cache and past_key_value is not None and q_len == 1:
A, past_key_value = llama_attention_forward_4_31(
self,
hidden_states,
past_key_value,
position_ids,
)
return A, None, past_key_value

if self.config.pretraining_tp > 1:
key_value_slicing = ((self.num_key_value_heads * self.head_dim) //
self.config.pretraining_tp)
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

query_states = [F.linear(hidden_states, query_slices[i])
for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)

key_states = [F.linear(hidden_states, key_slices[i])
for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)

value_states = [F.linear(hidden_states, value_slices[i])
for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)

else:
if hasattr(self, "q_proj"):
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
else:
qkv = self.qkv_proj(hidden_states)
qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
query_states, key_states, value_states = qkv.split([self.num_heads,
self.num_key_value_heads,
self.num_key_value_heads], dim=2)

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]

if use_fast_rope:
from ipex_llm.transformers.layers.rope_embedding import apply_fast_rope_embedding
query_states, key_states = apply_fast_rope_embedding(query_states,
key_states,
position_ids,
"llama")
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "llama")

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

cache_position = None
attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
attention_mask, cache_position,
bsz, q_len, kv_seq_len,
self.head_dim, self.num_heads, output_attentions)

attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
if attn_output.size() != attn_output_size:
invalidInputError(False,
f"`attn_output` should be of size {attn_output_size},"
f" but is {attn_output.size()}")

attn_output = attn_output.transpose(1, 2).contiguous()

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp,
dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value


def llama_model_forward_4_41_internal(
self,
input_ids: torch.LongTensor = None,
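
The deleted attention path also depended on the repeat_kv helper to expand grouped-query key/value heads before computing attention. A sketch of that standard pattern (following the shape convention of Hugging Face's Llama implementation; shown here for reference only):

import torch

def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (bsz, num_kv_heads, seq_len, head_dim) to
    # (bsz, num_kv_heads * n_rep, seq_len, head_dim) so that grouped KV heads
    # line up one-to-one with the query heads.
    bsz, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        bsz, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)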
16 changes: 0 additions & 16 deletions python/llm/src/ipex_llm/transformers/qlora.py
@@ -296,7 +296,6 @@ def get_peft_model(*args, **kwargs):

if model.device.type == "xpu":
cast_lora_weight(model, torch.bfloat16)
_optimize_post(model)
torch.xpu.synchronize()

return model
@@ -390,18 +389,3 @@ def cast_lora_weight(model, dtype=torch.bfloat16):
if hasattr(module, 'weight'):
if module.weight.dtype == torch.float32:
module = module.to(dtype)


def _optimize_post(model):
import transformers
from packaging import version
from ipex_llm.transformers.convert import convert_forward
from ipex_llm.transformers.models.llama import llama_attention_fast_forward

trans_version = transformers.__version__
if version.parse(trans_version) >= version.parse("4.31.0"):
LOG.info("Optimizing Llama finetuning....")
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaAttention,
llama_attention_fast_forward,)
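
The removed _optimize_post hook patched transformers' LlamaAttention with llama_attention_fast_forward via convert_forward. A minimal sketch of that forward-replacement pattern, assuming convert_forward simply rebinds forward on every matching module (the actual ipex_llm implementation may differ):

import types
import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_class: type, new_forward) -> None:
    # Walk the module tree and rebind forward on each instance of target_class.
    for module in model.modules():
        if isinstance(module, target_class):
            module.forward = types.MethodType(new_forward, module)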
