From a828648e59470f77fed7a9ddd3d2c47d0710c7e6 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Tue, 24 Dec 2024 15:02:12 +0800 Subject: [PATCH 1/2] refactor chatglm2, internlm, stablelm and qwen --- .../ipex_llm/transformers/models/chatglm2.py | 86 ++----------- .../ipex_llm/transformers/models/internlm.py | 114 +++--------------- .../src/ipex_llm/transformers/models/qwen.py | 85 +++---------- .../ipex_llm/transformers/models/stablelm.py | 47 ++------ 4 files changed, 53 insertions(+), 279 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py index 9e213e178e9..633b4c7acc3 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py @@ -18,17 +18,16 @@ # import os -import math import torch from typing import Optional, Tuple from transformers.modeling_outputs import BaseModelOutputWithPast -from ipex_llm.utils.common.log4Error import invalidInputError -from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value -from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal +from ipex_llm.transformers.models.common import scaled_dot_product_attention +from ipex_llm.transformers.models.utils import update_past_key_value +from ipex_llm.transformers.models.utils import use_quantize_kv_cache from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU -from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \ - use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36 +from ipex_llm.transformers.models.utils import use_quantize_kv_cache, +from ipex_llm.transformers.models.utils import should_use_compresskv, is_enough_kv_cache_room_4_36 from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256)) @@ -310,50 +309,10 @@ def chatglm2_attention_forward( value_states.permute(2, 0, 1, 3)) if use_cache else None # IPEX-LLM OPT: sdp - attn_weights = None - if use_sdp(q_len, kv_seq_len, head_dim, query_states): - import xe_addons - if use_compresskv and attention_mask is not None: - attention_mask = None - if use_quantize_kv: - attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, attention_mask) - else: - attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask) - elif use_sdp_causal(q_len, kv_seq_len, head_dim, query_states, self.training): - import xe_addons - if use_quantize_kv: - attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states, - attention_mask) - else: - attn_output = xe_addons.sdp_causal(query_states, key_states, value_states, - attention_mask) - elif query_states.device.type == "cpu": - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, n_head // n_kv_head) - value_states = repeat_kv(value_states, n_head // n_kv_head) - if q_len == kv_seq_len: - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, is_causal=True - ) - else: - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, key_states, value_states, attention_mask - ) - else: - if use_quantize_kv: - key_states, value_states = restore_fp8_kv_cache(key_states, value_states, - query_states.dtype) - # repeat k/v heads 
if n_kv_heads < n_heads - key_states = repeat_kv(key_states, n_head // n_kv_head) - value_states = repeat_kv(value_states, n_head // n_kv_head) - - attn_weights = torch.matmul(query_states, - key_states.transpose(2, 3)) / math.sqrt(head_dim) - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, - dtype=torch.float32).to(value_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) + attn_output = scaled_dot_product_attention( + query_states, key_states, value_states, + attention_mask, q_len == kv_seq_len + ) # context_layer's shape: [bsz, n_head, seq_len, head_dim] -> [seq_len, bsz, n_head * head_dim] attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(q_len, bsz, n_head * head_dim) @@ -541,29 +500,10 @@ def codegeex_attention_forward( # ================= # Output. [sq, b, h] # ================= - context_layer = None - if use_sdp(q_len, kv_seq_len, head_dim, query_layer): - import xe_addons - context_layer = xe_addons.sdp(query_layer, key_layer, value_layer, attention_mask) - elif use_sdp_causal(q_len, kv_seq_len, head_dim, query_layer, self.training): - import xe_addons - context_layer = xe_addons.sdp_causal(query_layer, key_layer, value_layer, attention_mask) - else: - # repeat k/v heads if n_kv_heads < n_heads - key_layer = repeat_kv(key_layer, n_head // n_kv_head) - value_layer = repeat_kv(value_layer, n_head // n_kv_head) - if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, - key_layer, - value_layer, - is_causal=True) - else: - if attention_mask is not None: - attention_mask = ~attention_mask - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, - key_layer, - value_layer, - attention_mask) + context_layer = scaled_dot_product_attention( + query_layer, key_layer, value_layer, + attention_mask, q_len == kv_seq_len + ) context_layer = context_layer.permute(2, 0, 1, 3).contiguous().view(q_len, bsz, diff --git a/python/llm/src/ipex_llm/transformers/models/internlm.py b/python/llm/src/ipex_llm/transformers/models/internlm.py index 68e47df6a47..9f71fb38cfa 100644 --- a/python/llm/src/ipex_llm/transformers/models/internlm.py +++ b/python/llm/src/ipex_llm/transformers/models/internlm.py @@ -36,18 +36,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
""" PyTorch InternLM model.""" -import math from typing import Optional, Tuple, List import torch import torch.utils.checkpoint -from torch import nn from ipex_llm.utils.common.log4Error import invalidInputError -from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax +from ipex_llm.transformers.models.common import merge_qkv_base +from ipex_llm.transformers.models.common import scaled_dot_product_attention from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb -from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache +from ipex_llm.transformers.models.utils import use_quantize_kv_cache from ipex_llm.transformers.models.utils import update_past_key_value -from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal from einops import rearrange @@ -98,35 +96,10 @@ def internlm_attention_forward( # IPEX-LLM OPT: sdp attn_weights = None - if use_sdp(q_len, kv_seq_len, self.head_dim, query_states): - import xe_addons - if use_quantize_kv: - attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, - attention_mask) - else: - attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask) - elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): - import xe_addons - if use_quantize_kv: - attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, - value_states, attention_mask) - else: - attn_output = xe_addons.sdp_causal(query_states, key_states, - value_states, attention_mask) - else: - if use_quantize_kv: - key_states, value_states = restore_fp8_kv_cache(key_states, value_states, - query_states.dtype) - - attn_weights = torch.matmul(query_states, - key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = attention_softmax(attn_weights) - attn_output = torch.matmul(attn_weights, value_states) + attn_output = scaled_dot_product_attention( + query_states, key_states, value_states, + attention_mask, q_len == kv_seq_len + ) attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -207,38 +180,10 @@ def internlm2_attention_forward( # IPEX-LLM OPT: sdp attn_weights = None - if use_sdp(q_len, kv_seq_len, self.head_dim, query_states): - import xe_addons - if use_quantize_kv: - attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, - attention_mask) - else: - attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask) - elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): - import xe_addons - if use_quantize_kv: - attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, - value_states, attention_mask) - else: - attn_output = xe_addons.sdp_causal(query_states, key_states, - value_states, attention_mask) - else: - if use_quantize_kv: - key_states, value_states = restore_fp8_kv_cache(key_states, value_states, - query_states.dtype) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, - key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, - dim=-1, dtype=torch.float32).to(query_states.dtype) - 
attn_output = torch.matmul(attn_weights, value_states) + attn_output = scaled_dot_product_attention( + query_states, key_states, value_states, + attention_mask, q_len == kv_seq_len + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -409,38 +354,11 @@ def internlm_xcomposser2_attention_forward( past_key_value = (key_states, value_states) if use_cache else None # IPEX-LLM OPT: sdp - if use_sdp(q_len, kv_seq_len, self.head_dim, query_states): - import xe_addons - if use_quantize_kv: - attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, - attention_mask) - else: - attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask) - elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): - import xe_addons - if use_quantize_kv: - attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, - value_states, attention_mask) - else: - attn_output = xe_addons.sdp_causal(query_states, key_states, - value_states, attention_mask) - else: - if use_quantize_kv: - key_states, value_states = restore_fp8_kv_cache(key_states, value_states, - query_states.dtype) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, - key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, - dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) + attn_weights = None + attn_output = scaled_dot_product_attention( + query_states, key_states, value_states, + attention_mask, q_len == kv_seq_len + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) diff --git a/python/llm/src/ipex_llm/transformers/models/qwen.py b/python/llm/src/ipex_llm/transformers/models/qwen.py index 5211536b5e1..b5fb5d7a83a 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen.py @@ -22,19 +22,19 @@ # LICENSE file in the root directory of this source tree. 
# -import math from typing import Optional, Tuple, Union, Callable, List import torch import torch.nn.functional as F import torch.utils.checkpoint from transformers.utils import logging +from ipex_llm.transformers.models.common import scaled_dot_product_attention from ipex_llm.transformers.models.utils import update_past_key_value, should_use_fuse_rope -from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, use_quantize_kv_cache +from ipex_llm.transformers.models.utils import use_quantize_kv_cache from ipex_llm.transformers.models.utils import rotate_half, SILU from ipex_llm.transformers.models.utils import mlp_fusion_check -from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal -from ipex_llm.utils.common import invalidInputError, invalidOperationError +from ipex_llm.transformers.models.utils import use_flash_attention +from ipex_llm.utils.common import invalidInputError from transformers.modeling_outputs import BaseModelOutputWithPast @@ -118,20 +118,13 @@ def qwen_attention_forward( # IPEX-LLM OPT: sdp attn_weights = None - if not self.training and not hidden_states.requires_grad and \ - use_flash_attention(query_states, key_states, attention_mask): + if use_flash_attention(query_states, key_states, attention_mask): attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16), key_states.to(dtype=torch.float16), value_states.to(dtype=torch.float16), is_causal=True).to(hidden_states.dtype) - elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): - import xe_addons - if use_quantize_kv: - attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states, None) - else: - attn_output = xe_addons.sdp_causal(query_states, key_states, value_states, None) else: - if q_len > 1: + if q_len > 1 and q_len != kv_seq_len: causal_mask = torch.tril( torch.ones((kv_seq_len, kv_seq_len), dtype=torch.bool, device=query_states.device) ).view(1, 1, kv_seq_len, kv_seq_len) @@ -146,29 +139,10 @@ def qwen_attention_forward( else: attention_mask = None - if use_sdp(q_len, kv_seq_len, self.head_dim, query_states): - import xe_addons - if use_quantize_kv: - attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, - attention_mask) - else: - attn_output = xe_addons.sdp(query_states, key_states, value_states, - attention_mask) - else: - if use_quantize_kv: - key_states, value_states = restore_fp8_kv_cache(key_states, value_states, - query_states.dtype) - attn_weights = torch.matmul(query_states, - key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - if self.softmax_in_fp32: - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, - dtype=torch.float32).to( - value_states.dtype) - else: - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - attn_output = torch.matmul(attn_weights, value_states) + attn_output = scaled_dot_product_attention( + query_states, key_states, value_states, + attention_mask, q_len == kv_seq_len + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) @@ -247,20 +221,14 @@ def qwen_attention_forward_registered( # IPEX-LLM OPT: sdp attn_weights = None - if not self.training and not hidden_states.requires_grad and \ - use_flash_attention(query_states, key_states, attention_mask): + + if use_flash_attention(query_states, key_states, attention_mask): attn_output = 
F.scaled_dot_product_attention(query_states.to(dtype=torch.float16), key_states.to(dtype=torch.float16), value_states.to(dtype=torch.float16), is_causal=True).to(hidden_states.dtype) - elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): - import xe_addons - if use_quantize_kv: - attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states, None) - else: - attn_output = xe_addons.sdp_causal(query_states, key_states, value_states, None) else: - if q_len > 1: + if q_len > 1 and q_len != kv_seq_len: causal_mask = registered_causal_mask[ :, :, kv_seq_len - q_len:kv_seq_len, :kv_seq_len ] @@ -272,29 +240,10 @@ def qwen_attention_forward_registered( else: attention_mask = None - if use_sdp(q_len, kv_seq_len, self.head_dim, query_states): - import xe_addons - if use_quantize_kv: - attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, - attention_mask) - else: - attn_output = xe_addons.sdp(query_states, key_states, value_states, - attention_mask) - else: - if use_quantize_kv: - key_states, value_states = restore_fp8_kv_cache(key_states, value_states, - query_states.dtype) - attn_weights = torch.matmul(query_states, - key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - if self.softmax_in_fp32: - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, - dtype=torch.float32).to( - value_states.dtype) - else: - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - attn_output = torch.matmul(attn_weights, value_states) + attn_output = scaled_dot_product_attention( + query_states, key_states, value_states, + attention_mask, q_len == kv_seq_len + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) diff --git a/python/llm/src/ipex_llm/transformers/models/stablelm.py b/python/llm/src/ipex_llm/transformers/models/stablelm.py index af6c5dee530..9965a25e7d1 100644 --- a/python/llm/src/ipex_llm/transformers/models/stablelm.py +++ b/python/llm/src/ipex_llm/transformers/models/stablelm.py @@ -37,18 +37,16 @@ # limitations under the License. 
# -import math from typing import Optional, Tuple, List import torch from transformers.cache_utils import Cache -from transformers.models.stablelm.modeling_stablelm import repeat_kv from transformers.models.stablelm.modeling_stablelm import StableLmAttention, StableLmModel -from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax +from ipex_llm.transformers.models.common import merge_qkv_base +from ipex_llm.transformers.models.common import scaled_dot_product_attention from ipex_llm.transformers.models.utils import apply_rotary_pos_emb -from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal -from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, use_quantize_kv_cache +from ipex_llm.transformers.models.utils import use_quantize_kv_cache from ipex_llm.transformers.models.utils import should_use_fuse_rope from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache @@ -143,41 +141,10 @@ def stablelm_attention_forward( # IPEX-LLM OPT: sdp attn_weights = None - if use_sdp(q_len, kv_seq_len, self.head_dim, query_states): - import xe_addons - if isinstance(past_key_value, DynamicFp8Cache): - attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, - attention_mask) - else: - attn_output = xe_addons.sdp(query_states, key_states, value_states, - attention_mask) - elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): - import xe_addons - if isinstance(past_key_value, DynamicFp8Cache): - attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, - value_states, attention_mask) - else: - attn_output = xe_addons.sdp_causal(query_states, key_states, - value_states, attention_mask) - else: - if isinstance(past_key_value, DynamicFp8Cache): - key_states, value_states = restore_fp8_kv_cache(key_states, value_states, - query_states.dtype) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, - key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = attention_softmax(attn_weights) - attn_weights = self.attention_dropout(attn_weights) - attn_output = torch.matmul(attn_weights, value_states) + attn_output = scaled_dot_product_attention( + query_states, key_states, value_states, + attention_mask, q_len == kv_seq_len + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) From 2baf33653d74feb63872a89ca03fb7f16862ce63 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Tue, 24 Dec 2024 15:21:14 +0800 Subject: [PATCH 2/2] fix --- python/llm/src/ipex_llm/transformers/models/chatglm2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py index 633b4c7acc3..c41744d21c2 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py @@ -26,7 +26,7 @@ from ipex_llm.transformers.models.utils import use_quantize_kv_cache from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU -from ipex_llm.transformers.models.utils import use_quantize_kv_cache, +from 
ipex_llm.transformers.models.utils import use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import should_use_compresskv, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache
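
Note on the refactor: every branch removed above (the xe_addons.sdp / sdp_fp8 / sdp_causal / sdp_fp8_causal dispatch, the restore_fp8_kv_cache step for quantized KV caches, the repeat_kv expansion for grouped-query attention, and the manual matmul plus softmax fallback) is folded into the shared scaled_dot_product_attention helper imported from ipex_llm.transformers.models.common, which is not part of this patch. The sketch below is only an illustration of the call contract the refactored forwards rely on, assuming query/key/value in [bsz, n_heads, seq_len, head_dim] layout, an optional additive attention mask, and a trailing boolean that is True during prefill (q_len == kv_seq_len). It implements just the plain PyTorch fallback; the real helper also routes to the xe_addons kernels on XPU and understands FP8 caches, so treat it as a reading aid rather than the actual implementation.

# Illustrative sketch only: the real helper lives in
# ipex_llm/transformers/models/common.py and is not shown in this patch.
# It additionally dispatches to xe_addons SDP kernels on XPU and handles
# quantized (FP8) KV caches; this minimal version covers only the plain
# PyTorch fallback path.
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand [bsz, n_kv_heads, seq_len, head_dim] to
    # [bsz, n_kv_heads * n_rep, seq_len, head_dim] for grouped-query attention.
    if n_rep == 1:
        return hidden_states
    bsz, n_kv_heads, seq_len, head_dim = hidden_states.shape
    hidden_states = hidden_states[:, :, None, :, :].expand(
        bsz, n_kv_heads, n_rep, seq_len, head_dim
    )
    return hidden_states.reshape(bsz, n_kv_heads * n_rep, seq_len, head_dim)


def scaled_dot_product_attention(query, key, value, attention_mask=None,
                                 is_causal=False):
    # query: [bsz, n_heads, q_len, head_dim]
    # key / value: [bsz, n_kv_heads, kv_len, head_dim]
    n_heads, n_kv_heads = query.size(1), key.size(1)
    key = repeat_kv(key, n_heads // n_kv_heads)
    value = repeat_kv(value, n_heads // n_kv_heads)

    if is_causal and attention_mask is None:
        # Prefill with no explicit mask: let SDPA build the causal mask itself.
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, is_causal=True
        )
    # Decode step, or an explicitly masked prefill: pass the additive mask through.
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask
    )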
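
For reference, after this series every refactored attention forward collapses to essentially the same call shape, visible in the chatglm2, internlm, qwen and stablelm hunks above (the codegeex hunk names the result context_layer instead of attn_output):

    attn_output = scaled_dot_product_attention(
        query_states, key_states, value_states,
        attention_mask, q_len == kv_seq_len
    )

The trailing q_len == kv_seq_len flag is presumably what lets the helper choose between a causal prefill path and a masked decode path, which is why the per-model files no longer need to import use_sdp, use_sdp_causal or restore_fp8_kv_cache.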