diff --git a/python/llm/src/ipex_llm/transformers/models/bloom.py b/python/llm/src/ipex_llm/transformers/models/bloom.py index 4967aa1897c..162d5168b5a 100644 --- a/python/llm/src/ipex_llm/transformers/models/bloom.py +++ b/python/llm/src/ipex_llm/transformers/models/bloom.py @@ -37,7 +37,6 @@ import torch import torch.utils.checkpoint from torch.nn import functional as F -from ipex_llm.transformers.models.utils import use_fused_layer_norm from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache import os diff --git a/python/llm/src/ipex_llm/transformers/models/llama32.py b/python/llm/src/ipex_llm/transformers/models/llama32.py index e6a9c53bb14..15c156a192c 100644 --- a/python/llm/src/ipex_llm/transformers/models/llama32.py +++ b/python/llm/src/ipex_llm/transformers/models/llama32.py @@ -42,14 +42,12 @@ from typing import Optional, Tuple, Union from transformers.cache_utils import Cache from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.llama.modeling_llama import repeat_kv from transformers.models.llama.modeling_llama import apply_rotary_pos_emb from ipex_llm.utils.common import invalidInputError -from ipex_llm.transformers.models.common import attention_softmax -from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal +from ipex_llm.transformers.models.common import scaled_dot_product_attention from ipex_llm.transformers.models.utils import should_use_fuse_rope -from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache +from ipex_llm.transformers.models.utils import use_quantize_kv_cache from ipex_llm.transformers.models.utils import should_use_compresskv, \ is_enough_kv_cache_room_4_36 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, DynamicCompressCache, \ @@ -233,44 +231,11 @@ def llama_attention_forward( key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, None) - kv_seq_len = key_states.size(2) - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, :kv_seq_len] - else: - causal_mask = None - attn_weights = None - if use_sdp(q_len, kv_seq_len, self.head_dim, query_states): - import xe_addons - if isinstance(past_key_value, DynamicFp8Cache): - attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, causal_mask) - else: - attn_output = xe_addons.sdp(query_states, key_states, value_states, causal_mask) - elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): - import xe_addons - if isinstance(past_key_value, DynamicFp8Cache): - attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, - value_states, causal_mask) - else: - attn_output = xe_addons.sdp_causal(query_states, key_states, - value_states, causal_mask) - else: - if isinstance(past_key_value, DynamicFp8Cache): - key_states, value_states = restore_fp8_kv_cache(key_states, value_states, - query_states.dtype) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, - key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if causal_mask is not None: - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = attention_softmax(attn_weights) - attn_output = torch.matmul(attn_weights, value_states) + attn_output = scaled_dot_product_attention( + query_states, key_states, value_states, + attention_mask, q_len == key_states.size(2), math.sqrt(self.head_dim) + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, -1) diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py index 8774fd11059..d746f079991 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen2.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py @@ -46,11 +46,12 @@ from torch.nn.functional import scaled_dot_product_attention as sdpa from ipex_llm.transformers.models.common import merge_qkv_base +from ipex_llm.transformers.models.common import scaled_dot_product_attention from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check from ipex_llm.transformers.models.utils import should_use_fuse_rope -from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \ - should_use_compresskv, is_enough_kv_cache_room_4_36, get_compresskv_attn_mask -from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal +from ipex_llm.transformers.models.utils import use_quantize_kv_cache, \ + should_use_compresskv, is_enough_kv_cache_room_4_36 +from ipex_llm.transformers.models.utils import use_flash_attention from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache, \ DynamicCompressCache, DynamicCompressFp8Cache from ipex_llm.utils.common import invalidInputError @@ -532,7 +533,6 @@ def qwen2_attention_forward( # [CompressKV] from ipex_llm.transformers.kv import DynamicCompressCache use_compresskv = isinstance(past_key_value, DynamicCompressCache) - use_quantizekv = isinstance(past_key_value, DynamicFp8Cache) if hasattr(self, 'qkv_proj') and self.qkv_proj is not None: qkv = self.qkv_proj(hidden_states) @@ -583,18 +583,8 @@ def qwen2_attention_forward( self.layer_idx, None) attn_weights = None - if query_states.device.type == "cpu": - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_output = sdpa(query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=self.is_causal and attention_mask is None and q_len > 1) - elif not self.training and not hidden_states.requires_grad and \ - use_flash_attention(query_states, key_states, attention_mask): + if query_states.device.type == 'xpu' \ + and use_flash_attention(query_states, key_states, attention_mask): # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -602,42 +592,11 @@ def qwen2_attention_forward( key_states.to(device, dtype=torch.float16), value_states.to(device, dtype=torch.float16), is_causal=True).to(hidden_states.dtype) - elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states): - import xe_addons - if use_compresskv: - attention_mask = get_compresskv_attn_mask(key_states, attention_mask) - if use_quantizekv: - attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, - attention_mask) - else: - attn_output = xe_addons.sdp(query_states, key_states, value_states, - attention_mask) - elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): - import xe_addons - if use_quantizekv: - attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, - value_states, attention_mask) - else: - attn_output = xe_addons.sdp_causal(query_states, key_states, - value_states, attention_mask) else: - if use_quantizekv: - key_states, value_states = restore_fp8_kv_cache(key_states, value_states, - query_states.dtype) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, - key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - # upcast attention to fp32 - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, - dtype=torch.float32).to(query_states.dtype) - attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, - training=self.training) - attn_output = torch.matmul(attn_weights, value_states) + attn_output = scaled_dot_product_attention( + query_states, key_states, value_states, + attention_mask, q_len == kv_seq_len, math.sqrt(self.head_dim) + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py index 8d085ee8a32..5dec7a940f1 100644 --- a/python/llm/src/ipex_llm/transformers/models/utils.py +++ b/python/llm/src/ipex_llm/transformers/models/utils.py @@ -358,16 +358,6 @@ def use_xmx(x: torch.Tensor, qtype: int): ) -def use_fused_layer_norm(x: torch.Tensor, training: bool): - device = get_xpu_device_type(x) - return ( - not training - and not x.requires_grad - and device in ["arc", "flex", "pvc", "mtl", "lnl"] # fused layer norm cannot run on UHD - and x.numel() // x.size(-1) == 1 # fused layer norm is slower in first token - ) - - def fp16_fusion_check(proj, x, training): # only use fp16 fusion on PVC inference if proj is None: