optimize siglip attention on arc (#12569)
MeouSker77 authored Dec 18, 2024
1 parent 1a2ab12 commit a4eb561
Showing 1 changed file with 27 additions and 16 deletions.
43 changes: 27 additions & 16 deletions python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -53,28 +53,39 @@ def siglip_attention_forward(
     qkv = qkv.transpose(1, 2)
     query_states, key_states, value_states = qkv.chunk(3, dim=1)
 
-    query_states, key_states, value_states = padding_qkv_hd(
-        query_states, key_states, value_states,
-        72, 80
-    )
-
-    if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype):
+    from ipex_llm.transformers.utils import get_xpu_device_type
+    if (
+        self.head_dim == 72
+        and get_xpu_device_type(query_states) in ["arc", "flex"] and
+        query_states.dtype in [torch.float, torch.half]
+    ):
         import xe_addons
         attn_weights = None
-        attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(),
-                                               value_states.contiguous(), attention_mask)
+        attn_output = xe_addons.siglip_sdp_non_causal(query_states, key_states,
+                                                      value_states, attention_mask)
     else:
-        attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
+        query_states, key_states, value_states = padding_qkv_hd(
+            query_states, key_states, value_states,
+            72, 80
+        )
+
+        if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype):
+            import xe_addons
+            attn_weights = None
+            attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(),
+                                                   value_states.contiguous(), attention_mask)
+        else:
+            attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
+            if attention_mask is not None:
+                attn_weights = attn_weights + attention_mask
 
-        attn_weights = attention_softmax(attn_weights)
+            attn_weights = attention_softmax(attn_weights)
 
-        attn_weights = torch.nn.functional.dropout(attn_weights,
-                                                   p=self.dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
+            attn_weights = torch.nn.functional.dropout(attn_weights,
+                                                       p=self.dropout, training=self.training)
+            attn_output = torch.matmul(attn_weights, value_states)
 
-    attn_output = attn_output[:, :, :, :self.head_dim]
+        attn_output = attn_output[:, :, :, :self.head_dim]
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
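
Context on the change: the fallback path kept above widens the 72-wide heads to 80 with padding_qkv_hd before the generic SDP kernel and slices the result back via attn_output[:, :, :, :self.head_dim], while the new Arc/Flex branch calls xe_addons.siglip_sdp_non_causal on the unpadded tensors. The snippet below is a minimal, self-contained sketch of why that zero-padding round-trip is numerically a no-op and therefore safe to skip; it assumes padding_qkv_hd zero-pads the last dimension (its implementation is not shown in this diff), uses plain PyTorch (2.1+ for the scale argument), and all shapes and names are illustrative.

    import math
    import torch
    import torch.nn.functional as F

    # Toy stand-ins for SigLIP vision attention: (batch, heads, tokens, head_dim).
    bsz, num_heads, seq_len, head_dim, padded_dim = 1, 4, 16, 72, 80
    scale = 1.0 / math.sqrt(head_dim)

    q = torch.randn(bsz, num_heads, seq_len, head_dim)
    k = torch.randn(bsz, num_heads, seq_len, head_dim)
    v = torch.randn(bsz, num_heads, seq_len, head_dim)

    # Direct attention at head_dim = 72 (what the new Arc/Flex kernel computes).
    ref = F.scaled_dot_product_attention(q, k, v, scale=scale)

    # Padded path: zero-pad the last dim from 72 to 80, attend, then slice the
    # extra channels off again, mirroring attn_output[:, :, :, :self.head_dim].
    pad = padded_dim - head_dim
    qp, kp, vp = (F.pad(t, (0, pad)) for t in (q, k, v))
    out = F.scaled_dot_product_attention(qp, kp, vp, scale=scale)[..., :head_dim]

    # Zero columns add nothing to q @ k.T and only produce zero output channels,
    # so the padded result matches the direct one up to floating-point error.
    assert torch.allclose(ref, out, atol=1e-5)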
