refactor sd 1.5 and qwen2-vl and fix (#12590)
MeouSker77 authored Dec 20, 2024
1 parent b050368 · commit 098eb33
Showing 4 changed files with 23 additions and 58 deletions.
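Note: the common thread across all four files is that per-backend xe_addons.sdp* calls and hand-rolled matmul/softmax fallbacks are replaced by the shared scaled_dot_product_attention helper from ipex_llm.transformers.models.common. The helper's body is not part of this diff; purely as a reading aid, a minimal stand-in consistent with the call sites below might look like the sketch here (signature inferred from the calls, dispatch behavior assumed, not the library's implementation):

# Assumed signature, inferred from the call sites in this commit:
#   scaled_dot_product_attention(query, key, value, mask, is_causal, scale=None)
# The real helper presumably dispatches to optimized XPU kernels; this stand-in
# only mirrors the interface with PyTorch's built-in SDPA.
import math
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value, mask=None,
                                 is_causal=False, scale=None):
    if scale is None:
        scale = 1 / math.sqrt(query.size(-1))
    # Repeat k/v heads when n_kv_heads < n_heads (GQA), so callers can pass the
    # cached key/value tensors directly; layout is [batch, heads, seq, head_dim].
    n_rep = query.size(1) // key.size(1)
    if n_rep > 1:
        key = key.repeat_interleave(n_rep, dim=1)
        value = value.repeat_interleave(n_rep, dim=1)
    if is_causal and mask is None:
        return F.scaled_dot_product_attention(query, key, value,
                                              is_causal=True, scale=scale)
    return F.scaled_dot_product_attention(query, key, value,
                                          attn_mask=mask, scale=scale)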
python/llm/src/ipex_llm/transformers/models/minicpmv.py (2 changes: 1 addition & 1 deletion)
@@ -75,7 +75,7 @@ def siglip_attention_forward(
 
     attn_weights = None
     attn_output = scaled_dot_product_attention(
-        query_states, key_states, value_states,
+        query_states, key_states.contiguous(), value_states.contiguous(),
        attention_mask, False, 1 / math.sqrt(self.head_dim)
     )
 
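Note: a likely reason for the added .contiguous() calls above, shown with toy tensors (nothing here comes from minicpmv.py itself): key/value sliced out of a fused QKV projection are non-contiguous views, and fused SDP kernels generally want contiguous inputs.

# Toy illustration: slicing a merged QKV tensor yields non-contiguous views.
import torch

qkv = torch.randn(1, 16, 3 * 64)       # [batch, seq, 3 * head_dim]
q, k, v = qkv.chunk(3, dim=-1)         # three views into the same storage
print(k.is_contiguous())               # False
print(k.contiguous().is_contiguous())  # True, at the cost of a copy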
python/llm/src/ipex_llm/transformers/models/qwen2.py (3 changes: 1 addition & 2 deletions)
@@ -583,8 +583,7 @@ def qwen2_attention_forward(
                                                      self.layer_idx, None)
 
     attn_weights = None
-    if query_states.device.type == 'xpu' \
-            and use_flash_attention(query_states, key_states, attention_mask):
+    if use_flash_attention(query_states, key_states, attention_mask):
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
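Note: the change above drops the explicit query_states.device.type == 'xpu' guard at the call site, presumably because use_flash_attention itself accounts for the device (its body is not shown in this commit). A hedged sketch of that kind of predicate, for orientation only, not the library's code:

import torch

def use_flash_attention(query, key, attention_mask=None):
    # Assumed behavior: refuse early on non-XPU devices so call sites need no
    # separate device check; real heuristics (shapes, dtypes, mask) are elided.
    if query.device.type != "xpu":
        return False
    return query.dtype in (torch.float16, torch.bfloat16)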
python/llm/src/ipex_llm/transformers/models/qwen2_vl.py (55 changes: 15 additions & 40 deletions)
@@ -43,8 +43,9 @@
 import torch
 
 from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, should_use_fuse_rope
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_sdp_non_causal
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 from ipex_llm.utils.common import invalidInputError
@@ -198,7 +199,6 @@ def qwen2_vision_attention_forward(
                       "unexpected input")
 
     if use_sdp_non_causal(self.head_dim, q.device, q.dtype):
-        import xe_addons
         image_num = len(seq_lens) - 1
         image_size = seq_lens[1] - seq_lens[0]
         guessed_seq_lens = torch.arange(0, (image_num + 1) * image_size, image_size,
@@ -209,7 +209,10 @@
             v = v.view(image_num, image_size, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
             # q, k, v: [image_num, num_heads, image_size, head_dim]
 
-            attn_output = xe_addons.sdp_non_causal(q, k.contiguous(), v.contiguous(), None)
+            attn_output = scaled_dot_product_attention(
+                q, k.contiguous(), v.contiguous(),
+                None, False
+            )
             attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
             attn_output = attn_output.view(seq_length, self.num_heads, self.head_dim)
             # attn_output: [seq_length, num_heads, head_dim]
@@ -226,7 +229,10 @@
                 tmp_q = q[:, :, start_idx:end_idx, :]
                 tmp_k = k[:, :, start_idx:end_idx, :]
                 tmp_v = v[:, :, start_idx:end_idx, :]
-                attn_output = xe_addons.sdp_non_causal(tmp_q, tmp_k, tmp_v, None)
+                attn_output = scaled_dot_product_attention(
+                    tmp_q, tmp_k, tmp_v,
+                    None, False
+                )
                 attn_output = attn_output.permute(0, 2, 1, 3)
                 # attn_output: [1, seq_length, num_heads, head_dim]
                 attn_outputs.append(attn_output)
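Note on the two vision paths touched above, with toy shapes rather than anything from the diff: when every image contributes the same number of patch tokens, the flat [total_tokens, num_heads, head_dim] tensors can be viewed and permuted into one batched [image_num, num_heads, image_size, head_dim] attention call, as in the first of the two hunks; otherwise the second hunk's per-image loop slices each image out and attends to it separately.

# Toy illustration of the batched layout used in the equal-size path.
import torch

image_num, image_size, num_heads, head_dim = 2, 4, 3, 8
flat = torch.randn(image_num * image_size, num_heads, head_dim)
batched = flat.view(image_num, image_size, num_heads, head_dim).permute(0, 2, 1, 3)
print(batched.shape)  # [image_num, num_heads, image_size, head_dim] = [2, 3, 4, 8]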
@@ -293,42 +299,11 @@ def qwen2_vl_attention_forward(
     key_states, value_states = past_key_value.update(key_states, value_states,
                                                       self.layer_idx, None)
 
-    kv_seq_len = key_states.size(2)
-    if attention_mask is not None:  # no matter the length, we just slice it
-        causal_mask = attention_mask[:, :, :, :kv_seq_len]
-
     attn_weights = None
-    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, causal_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states, causal_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, causal_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, causal_mask)
-    else:
-        if isinstance(past_key_value, DynamicFp8Cache):
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if causal_mask is not None:
-            attn_weights = attn_weights + causal_mask
-
-        # upcast attention to fp32
-        attn_weights = attention_softmax(attn_weights)
-        attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == key_states.size(2)
+    )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, -1)
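Note: in the decoder path above, the new second positional argument, q_len == key_states.size(2), enables causal attention only when the query spans the entire key/value length, i.e. during prefill; a single-token decode step sees a longer cache, so the condition is False and attention_mask alone governs masking. A toy check of the condition:

# Toy check of the is_causal condition q_len == key_states.size(2).
prefill_q_len = prefill_kv_len = 32     # prompt processing: query covers the cache
decode_q_len, decode_kv_len = 1, 33     # one new token against a longer cache
print(prefill_q_len == prefill_kv_len)  # True  -> causal masking
print(decode_q_len == decode_kv_len)    # False -> rely on attention_mask instead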
python/llm/src/ipex_llm/transformers/models/sd.py (21 changes: 6 additions & 15 deletions)
@@ -37,8 +37,8 @@
 from typing import Optional
 
 from ipex_llm.transformers.utils import get_xpu_device_type
-from ipex_llm.transformers.models.common import padding_qkv_hd, attention_softmax
-from ipex_llm.transformers.models.utils import use_sdp_non_causal
+from ipex_llm.transformers.models.common import padding_qkv_hd
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from diffusers.models.attention_processor import Attention
 
 
@@ -110,19 +110,10 @@ def __call__(
         if query.device.type == "xpu" and query.dtype in [torch.half, torch.float]:
             # padding head_dim 40 to 64
             query, key, value = padding_qkv_hd(query, key, value, 40, 64)
-
-            if use_sdp_non_causal(query.size(-1), query.device, query.dtype):
-                import xe_addons
-                hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
-                                                         value.contiguous(), attention_mask)
-            else:
-                scale = 1 / math.sqrt(head_dim)
-                attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
-                if attention_mask is not None:
-                    attn_weights = attn_weights + attention_mask
-                attn_weights = attention_softmax(attn_weights)
-                hidden_states = torch.matmul(attn_weights, value)
-
+            hidden_states = scaled_dot_product_attention(
+                query, key.contiguous(), value.contiguous(),
+                attention_mask, False, 1 / math.sqrt(head_dim)
+            )
             hidden_states = hidden_states[:, :, :, :head_dim]
         else:
             hidden_states = torch.nn.functional.scaled_dot_product_attention(
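Note: the sd.py path pads head_dim from 40 up to 64 and then passes an explicit 1 / math.sqrt(head_dim) scale. Assuming padding_qkv_hd zero-pads the head dimension (its body is not in this diff), that combination is a numerical no-op: zero padding leaves the q·k scores unchanged, the padded value channels come out as zeros and are sliced off by hidden_states[:, :, :, :head_dim], and the explicit scale keeps the original 1/sqrt(40) normalization instead of the padded default 1/sqrt(64). A toy check with plain PyTorch:

# Toy check that zero-padding head_dim 40 -> 64 plus an explicit scale is a no-op.
import torch
import torch.nn.functional as F

head_dim = 40
q, k, v = (torch.randn(1, 8, 16, head_dim) for _ in range(3))
pad = (0, 64 - head_dim)                       # pad the last dim on the right
qp, kp, vp = (F.pad(t, pad) for t in (q, k, v))

out = F.scaled_dot_product_attention(qp, kp, vp, scale=1 / head_dim ** 0.5)[..., :head_dim]
ref = F.scaled_dot_product_attention(q, k, v)  # default scale is already 1/sqrt(40)
print(torch.allclose(out, ref, atol=1e-5))     # True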
