diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py index 0bb0b643357..7703f9f3158 100644 --- a/python/llm/src/ipex_llm/transformers/models/minicpmv.py +++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py @@ -53,28 +53,39 @@ def siglip_attention_forward( qkv = qkv.transpose(1, 2) query_states, key_states, value_states = qkv.chunk(3, dim=1) - query_states, key_states, value_states = padding_qkv_hd( - query_states, key_states, value_states, - 72, 80 - ) - - if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype): + from ipex_llm.transformers.utils import get_xpu_device_type + if ( + self.head_dim == 72 + and get_xpu_device_type(query_states) in ["arc", "flex"] and + query_states.dtype in [torch.float, torch.half] + ): import xe_addons attn_weights = None - attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(), - value_states.contiguous(), attention_mask) + attn_output = xe_addons.siglip_sdp_non_causal(query_states, key_states, + value_states, attention_softmax) else: - attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3)) - if attention_mask is not None: - attn_weights = attn_weights + attention_mask + query_states, key_states, value_states = padding_qkv_hd( + query_states, key_states, value_states, + 72, 80 + ) + + if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype): + import xe_addons + attn_weights = None + attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(), + value_states.contiguous(), attention_mask) + else: + attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3)) + if attention_mask is not None: + attn_weights = attn_weights + attention_mask - attn_weights = attention_softmax(attn_weights) + attn_weights = attention_softmax(attn_weights) - attn_weights = torch.nn.functional.dropout(attn_weights, - p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) + attn_weights = torch.nn.functional.dropout(attn_weights, + p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output[:, :, :, :self.head_dim] + attn_output = attn_output[:, :, :, :self.head_dim] attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)