diff --git a/python/llm/src/ipex_llm/transformers/models/gpt2.py b/python/llm/src/ipex_llm/transformers/models/gpt2.py index 1a968bd052c..de78907470b 100644 --- a/python/llm/src/ipex_llm/transformers/models/gpt2.py +++ b/python/llm/src/ipex_llm/transformers/models/gpt2.py @@ -15,6 +15,7 @@ # import torch +from ipex_llm.transformers.models.common import scaled_dot_product_attention from ipex_llm.transformers.models.utils import use_sdp_non_causal @@ -44,10 +45,11 @@ def gpt2_attention_attn( else: attention_mask = attention_mask.expand(-1, -1, seq_len, seq_len) - import xe_addons attn_weights = None - attn_output = xe_addons.sdp_non_causal(query, key.contiguous(), - value.contiguous(), attention_mask) + attn_output = scaled_dot_product_attention( + query, key.contiguous(), value.contiguous(), + attention_mask, False + ) return attn_output, attn_weights # ipex-llm changes end diff --git a/python/llm/src/ipex_llm/transformers/models/internvl.py b/python/llm/src/ipex_llm/transformers/models/internvl.py index 43ce6f563fd..1e22ddacb4e 100644 --- a/python/llm/src/ipex_llm/transformers/models/internvl.py +++ b/python/llm/src/ipex_llm/transformers/models/internvl.py @@ -26,6 +26,7 @@ import torch from ipex_llm.utils.common.log4Error import invalidInputError +from ipex_llm.transformers.models.common import scaled_dot_product_attention from ipex_llm.transformers.models.utils import use_sdp_non_causal @@ -177,8 +178,10 @@ def intern_attention_forward(self, x: torch.Tensor) -> torch.Tensor: k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2) if use_sdp_non_causal(self.head_dim, q.device, q.dtype): - import xe_addons - x = xe_addons.sdp_non_causal(q, k.contiguous(), v.contiguous(), None) + x = scaled_dot_product_attention( + q, k.contiguous(), v.contiguous(), + None, False, self.scale + ) else: attn = ((q * self.scale) @ k.transpose(-2, -1)) attn = attn.softmax(dim=-1) diff --git a/python/llm/src/ipex_llm/transformers/models/mllama.py b/python/llm/src/ipex_llm/transformers/models/mllama.py index 9086fd2a247..2dc6896962f 100644 --- a/python/llm/src/ipex_llm/transformers/models/mllama.py +++ b/python/llm/src/ipex_llm/transformers/models/mllama.py @@ -32,7 +32,6 @@ # limitations under the License. -import math import torch from typing import Optional, Tuple, Union @@ -40,11 +39,10 @@ from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.models.mllama.modeling_mllama import MllamaVisionAttention from transformers.models.mllama.modeling_mllama import MllamaTextSelfAttention -from transformers.models.mllama.modeling_mllama import repeat_kv -from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, use_sdp_non_causal -from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache +from ipex_llm.transformers.models.utils import use_quantize_kv_cache from ipex_llm.transformers.models.utils import should_use_fuse_rope from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax +from ipex_llm.transformers.models.common import scaled_dot_product_attention from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache from ipex_llm.transformers.utils import invalidInputError @@ -67,27 +65,11 @@ def mllama_vision_attention_forward( qkv = qkv.transpose(1, 2) query, key, value = qkv.chunk(3, dim=1) - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key.shape[-2]] - else: - causal_mask = None - - if use_sdp_non_causal(self.head_dim, query.device, query.dtype): - import xe_addons - attn_output = xe_addons.sdp_non_causal(query, key.contiguous(), - value.contiguous(), causal_mask) - attn_weights = None - else: - attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - from ipex_llm.transformers.models.common import attention_softmax - attn_weights = attention_softmax(attn_weights) - - attn_output = torch.matmul(attn_weights, value) + attn_weights = None + attn_output = scaled_dot_product_attention( + query, key.contiguous(), value.contiguous(), + attention_softmax, False + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, -1) @@ -278,44 +260,11 @@ def mllama_cross_attention_forward( past_key_value.value_cache[self.layer_idx], ) - kv_seq_len = key_states.size(2) - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, :kv_seq_len] - else: - causal_mask = None - attn_weights = None - if use_sdp(q_len, kv_seq_len, self.head_dim, query_states): - import xe_addons - if isinstance(past_key_value, DynamicFp8Cache): - attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, causal_mask) - else: - attn_output = xe_addons.sdp(query_states, key_states, value_states, causal_mask) - elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): - import xe_addons - if isinstance(past_key_value, DynamicFp8Cache): - attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, - value_states, causal_mask) - else: - attn_output = xe_addons.sdp_causal(query_states, key_states, - value_states, causal_mask) - else: - if isinstance(past_key_value, DynamicFp8Cache): - key_states, value_states = restore_fp8_kv_cache(key_states, value_states, - query_states.dtype) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, - key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if causal_mask is not None: - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = attention_softmax(attn_weights) - attn_output = torch.matmul(attn_weights, value_states) + attn_output = scaled_dot_product_attention( + query_states, key_states, value_states, + attention_mask, q_len == key_states.size(2) + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, -1)