Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor yuan2 and starcoder2 and fix #12589

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/llama32.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def llama_attention_forward(
attn_weights = None
attn_output = scaled_dot_product_attention(
query_states, key_states, value_states,
attention_mask, q_len == key_states.size(2), math.sqrt(self.head_dim)
attention_mask, q_len == key_states.size(2)
)

attn_output = attn_output.transpose(1, 2).contiguous()
Expand Down
9 changes: 3 additions & 6 deletions python/llm/src/ipex_llm/transformers/models/minicpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,13 @@

import torch
import warnings
import torch.nn as nn
from typing import Optional, Tuple, Union, List
import math
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, use_quantize_kv_cache
from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, get_compresskv_attn_mask
from ipex_llm.transformers.models.utils import use_quantize_kv_cache
from ipex_llm.transformers.models.utils import should_use_compresskv, should_use_fuse_rope
from ipex_llm.transformers.models.llama import repeat_kv
from ipex_llm.transformers.models.common import merge_qkv_base
from ipex_llm.transformers.models.common import scaled_dot_product_attention
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, \
DynamicCompressCache, DynamicCompressFp8Cache
from transformers.cache_utils import Cache
Expand Down Expand Up @@ -127,11 +125,10 @@ def minicpm_attention_forward(
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, None)

from ipex_llm.transformers.models.common import scaled_dot_product_attention
attn_weights = None
attn_output = scaled_dot_product_attention(
query_states, key_states, value_states,
attention_mask, q_len == kv_seq_len, math.sqrt(self.head_dim)
attention_mask, q_len == kv_seq_len
)

attn_output = attn_output.transpose(1, 2).contiguous()
Expand Down
8 changes: 5 additions & 3 deletions python/llm/src/ipex_llm/transformers/models/minicpmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from torch.nn.functional import linear
from ipex_llm.transformers.models.common import merge_qkv_base, padding_qkv_hd
from ipex_llm.transformers.models.common import attention_softmax
from ipex_llm.transformers.models.common import scaled_dot_product_attention
from transformers import AutoProcessor, TextIteratorStreamer
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor

Expand Down Expand Up @@ -72,10 +73,11 @@ def siglip_attention_forward(
72, 80
)

from ipex_llm.transformers.models.common import scaled_dot_product_attention
attn_weights = None
attn_output = scaled_dot_product_attention(query_states, key_states, value_states,
attention_mask, False, math.sqrt(self.head_dim))
attn_output = scaled_dot_product_attention(
query_states, key_states, value_states,
attention_mask, False, 1 / math.sqrt(self.head_dim)
)

attn_output = attn_output[:, :, :, :self.head_dim]

Expand Down
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ def qwen2_attention_forward(
else:
attn_output = scaled_dot_product_attention(
query_states, key_states, value_states,
attention_mask, q_len == kv_seq_len, math.sqrt(self.head_dim)
attention_mask, q_len == kv_seq_len
)

attn_output = attn_output.transpose(1, 2).contiguous()
Expand Down
50 changes: 9 additions & 41 deletions python/llm/src/ipex_llm/transformers/models/starcoder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,15 @@
import torch
import warnings

from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
from ipex_llm.transformers.models.utils import (
use_quantize_kv_cache, restore_fp8_kv_cache,
should_use_fuse_rope, use_sdp, use_sdp_causal
)
from ipex_llm.transformers.models.common import merge_qkv_base
from ipex_llm.transformers.models.common import scaled_dot_product_attention
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, should_use_fuse_rope
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
from ipex_llm.utils.common.log4Error import invalidInputError

from typing import Optional, Tuple, List
from transformers.cache_utils import Cache
from transformers.models.starcoder2.modeling_starcoder2 import repeat_kv, apply_rotary_pos_emb
from transformers.models.starcoder2.modeling_starcoder2 import apply_rotary_pos_emb
from transformers.models.starcoder2.modeling_starcoder2 import Starcoder2Model, Starcoder2Attention


Expand Down Expand Up @@ -103,41 +101,11 @@ def attention_forward(
self.layer_idx, None)

# IPEX-LLM OPT: sdp
if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
import xe_addons
if isinstance(past_key_value, DynamicFp8Cache):
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
attention_mask)
else:
attn_output = xe_addons.sdp(query_states, key_states, value_states,
attention_mask)
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
import xe_addons
if isinstance(past_key_value, DynamicFp8Cache):
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
value_states, attention_mask)
else:
attn_output = xe_addons.sdp_causal(query_states, key_states,
value_states, attention_mask)
else:
if isinstance(past_key_value, DynamicFp8Cache):
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states,
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attention_mask is not None:
attn_weights = attn_weights + attention_mask

# upcast attention to fp32
attn_weights = attention_softmax(attn_weights)
attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_weights = None
attn_output = scaled_dot_product_attention(
query_states, key_states, value_states,
attention_mask, q_len == kv_seq_len
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
Expand Down
40 changes: 9 additions & 31 deletions python/llm/src/ipex_llm/transformers/models/yuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
import torch

from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers.models.common import attention_softmax
from ipex_llm.transformers.models.common import scaled_dot_product_attention
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
mlp_fusion_check, fp16_fusion_check
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
from ipex_llm.transformers.models.utils import use_quantize_kv_cache
from ipex_llm.transformers.models.utils import SILU, update_past_key_value
from ipex_llm.transformers.models.utils import should_use_fuse_rope, use_sdp, use_sdp_causal
from ipex_llm.transformers.models.utils import should_use_fuse_rope


def merge_qk(module: torch.nn.Module):
Expand Down Expand Up @@ -214,34 +214,12 @@ def yuan_attention_forward(
)
past_key_value = (key_states, value_states, before_hidden_states) if use_cache else None

# IPEX-LLM OPT: sdp
if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
import xe_addons
if use_quantize_kv:
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
attention_mask)
else:
attn_output = xe_addons.sdp(query_states, key_states, value_states,
attention_mask)
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
import xe_addons
if use_quantize_kv:
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
value_states, attention_mask)
else:
attn_output = xe_addons.sdp_causal(query_states, key_states,
value_states, attention_mask)
else:
if use_quantize_kv:
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
attn_weights = torch.matmul(query_states,
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = attention_softmax(attn_weights)
attn_output = torch.matmul(attn_weights, value_states)
# IPEX-LLM OPT: sdpa
attn_weights = None
attn_output = scaled_dot_product_attention(
query_states, key_states, value_states,
attention_mask, q_len == kv_seq_len
)

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
Expand Down
Loading