From 7aaf02f602785e73260e9fb6bc900444e7cc69c3 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Tue, 24 Dec 2024 14:16:30 +0800
Subject: [PATCH] refactor baichuan, glm4 and minicpm3 (#12600)

---
 .../ipex_llm/transformers/models/baichuan.py  | 42 +++-----------
 .../ipex_llm/transformers/models/chatglm4.py  | 58 +++----------------
 .../ipex_llm/transformers/models/chatglm4v.py | 57 +++-----------------
 .../ipex_llm/transformers/models/minicpm3.py  | 42 +++-----------
 4 files changed, 32 insertions(+), 167 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py
index 9d41279244b..a44909caf33 100644
--- a/python/llm/src/ipex_llm/transformers/models/baichuan.py
+++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py
@@ -24,16 +24,16 @@
 import torch.utils.checkpoint
 from torch.nn import functional as F
 from transformers.modeling_outputs import BaseModelOutputWithPast
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
-    should_use_compresskv, get_compresskv_attn_mask
+    should_use_compresskv
 from ipex_llm.transformers.models.utils import update_past_key_value
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
-from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU
 from ipex_llm.transformers.models.utils import mlp_fusion_check
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.kv import DynamicCompressFp8Cache, DynamicCompressCache
-from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
 import warnings
 import os
 
@@ -301,42 +301,16 @@ def baichuan_attention_forward_7b(
 
     # IPEX-LLM OPT: sdp
     attn_weights = None
-    if not self.training and not hidden_states.requires_grad and \
-            use_flash_attention(query_states, key_states, attention_mask):
+    if use_flash_attention(query_states, key_states, attention_mask):
         attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16),
                                                      key_states.to(dtype=torch.float16),
                                                      value_states.to(dtype=torch.float16),
                                                      is_causal=True).to(hidden_states.dtype)
-    elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
-        import xe_addons
-        if use_compresskv:
-            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
-                                            attention_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states,
-                                        attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
-        import xe_addons
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, attention_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, attention_mask)
     else:
-        if use_quantize_kv:
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        attn_weights = torch.matmul(query_states,
-                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-        # upcast attention to fp32
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                   dtype=torch.float32).to(value_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = scaled_dot_product_attention(
+            query_states, key_states, value_states,
+            attention_mask, q_len == kv_seq_len
+        )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4.py b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
index cf38f5eaa95..c3adc3720ee 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm4.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -20,15 +20,14 @@
 import os
 import torch
 from typing import Optional, Tuple, Union
-from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
-    use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36, \
-    get_compresskv_attn_mask
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
+from ipex_llm.transformers.models.utils import update_past_key_value
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache
+from ipex_llm.transformers.models.utils import should_use_compresskv, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
-from ipex_llm.transformers.models.chatglm2 import repeat_kv
 from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
-import math
+
 
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
@@ -241,49 +240,10 @@ def chatglm4_attention_forward(
         past_key_value = None
 
     # IPEX-LLM OPT: sdp
-    attn_weights = None
-    if use_sdp(q_len, kv_seq_len, head_dim, query_states):
-        import xe_addons
-        if use_compresskv:
-            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, attention_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, head_dim, query_states, self.training):
-        import xe_addons
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states,
-                                                   attention_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states, value_states,
-                                               attention_mask)
-    elif query_states.device.type == "cpu":
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, n_head // n_kv_head)
-        value_states = repeat_kv(value_states, n_head // n_kv_head)
-        if q_len == kv_seq_len:
-            attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query_states, key_states, value_states, is_causal=True
-            )
-        else:
-            attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query_states, key_states, value_states, attention_mask
-            )
-    else:
-        if use_quantize_kv:
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, n_head // n_kv_head)
-        value_states = repeat_kv(value_states, n_head // n_kv_head)
-        attn_weights = torch.matmul(query_states / math.sqrt(head_dim),
-                                    key_states.transpose(2, 3))
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                   dtype=torch.float32).to(value_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == kv_seq_len
+    )
 
     # context_layer's shape: [bsz, n_head, seq_len, head_dim] -> [seq_len, bsz, n_head * head_dim]
     attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, n_head * head_dim)
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4v.py b/python/llm/src/ipex_llm/transformers/models/chatglm4v.py
index 2028cae033f..8696846336b 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm4v.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm4v.py
@@ -20,10 +20,10 @@
 import torch
 from typing import Optional, Tuple, Union
 from ipex_llm.transformers.models.common import merge_qkv_base
-from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
+from ipex_llm.transformers.models.utils import update_past_key_value
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
-from ipex_llm.transformers.models.chatglm2 import repeat_kv
 from ipex_llm.utils.common import invalidInputError
 from transformers.modeling_outputs import BaseModelOutputWithPast
 import math
@@ -246,53 +246,10 @@ def chatglm4v_attention_forward(
         past_key_value = None
 
     # IPEX-LLM OPT: sdp
-    attn_weights = None
-    if use_sdp(q_len, kv_seq_len, head_dim, query_states):
-        import xe_addons
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, attention_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, head_dim, query_states, self.training):
-        import xe_addons
-        if use_quantize_kv:
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states,
-                                                   attention_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states, value_states,
-                                               attention_mask)
-    elif query_states.device.type == "cpu":
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, n_head // n_kv_head)
-        value_states = repeat_kv(value_states, n_head // n_kv_head)
-        if q_len == kv_seq_len:
-            attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query_states, key_states, value_states, is_causal=True
-            )
-        else:
-            attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query_states, key_states, value_states, attention_mask
-            )
-    else:
-        if use_quantize_kv:
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, n_head // n_kv_head)
-        value_states = repeat_kv(value_states, n_head // n_kv_head)
-        attn_weights = torch.matmul(query_states / math.sqrt(head_dim),
-                                    key_states.transpose(2, 3))
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-        if kv_seq_len >= 2048 or bsz >= 64:
-            # for memory considerations, do not upcast attention to fp32
-            # for long sequences or large batches
-            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
-        else:
-            # upcast attention to fp32
-            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                       dtype=torch.float32).to(value_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == kv_seq_len
+    )
 
     # context_layer's shape: [bsz, n_head, seq_len, head_dim] -> [seq_len, bsz, n_head * head_dim]
     attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, n_head * head_dim)
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpm3.py b/python/llm/src/ipex_llm/transformers/models/minicpm3.py
index 820cce22fc6..8cef25f0989 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpm3.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpm3.py
@@ -6,10 +6,10 @@
 from transformers.cache_utils import Cache
 
 from ipex_llm.utils.common.log4Error import invalidInputError
+from ipex_llm.transformers.models.common import scaled_dot_product_attention
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import rotate_half
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
 
 
@@ -25,7 +25,7 @@ def pre_compute_inv_freq(module: torch.nn.Module):
 
 def padding_v_head_dim(module: torch.nn.Module):
     if module.__class__.__name__ == "MiniCPMAttention":
-        k_head_dim = module.qk_rope_head_dim + module.qk_nope_head_dim
+        k_head_dim = module.q_head_dim
         v_head_dim = module.v_head_dim
         invalidInputError(k_head_dim >= v_head_dim,
                           f"unsupported k_head_dim and v_head_dim: {k_head_dim} {v_head_dim}")
@@ -183,37 +183,11 @@ def minicpm3_attention_forward(
                                                          self.layer_idx, None)
 
     attn_weights = None
-    if use_sdp(q_len, kv_seq_len, self.q_head_dim, query_states):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
-                                            attention_mask)
-        else:
-            attn_output = xe_addons.sdp(query_states, key_states, value_states,
-                                        attention_mask)
-        attn_output = attn_output[:, :, :, :self.v_head_dim]
-    elif use_sdp_causal(q_len, kv_seq_len, self.q_head_dim, query_states, False):
-        import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
-            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
-                                                   value_states, attention_mask)
-        else:
-            attn_output = xe_addons.sdp_causal(query_states, key_states,
-                                               value_states, attention_mask)
-        attn_output = attn_output[:, :, :, :self.v_head_dim]
-    else:
-        if isinstance(past_key_value, DynamicFp8Cache):
-            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
-                                                            query_states.dtype)
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
-
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights,
-                                             dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states[:, :, :, :self.v_head_dim])
+    attn_output = scaled_dot_product_attention(
+        query_states, key_states, value_states,
+        attention_mask, q_len == kv_seq_len, self.softmax_scale
+    )
+    attn_output = attn_output[:, :, :, :self.v_head_dim]
 
     attn_output = attn_output.transpose(1, 2).contiguous()
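
Note on the shared helper: scaled_dot_product_attention from ipex_llm.transformers.models.common is imported but not defined in this patch. Every call site now passes query/key/value, the additive attention mask, an is-causal flag (q_len == kv_seq_len), and, for MiniCPM3, an explicit softmax scale. The sketch below is an illustrative approximation of that calling contract only, not the actual ipex-llm implementation; the real helper presumably also centralizes the xe_addons SDP kernels, fp8 KV-cache restoration, and k/v-head broadcasting that the removed per-model branches used to do inline.

# Illustrative sketch only, NOT the actual ipex_llm.transformers.models.common code.
# It mirrors the signature used by the new call sites above:
#   scaled_dot_product_attention(query, key, value, attention_mask, is_causal, scale=None)
import math
from typing import Optional

import torch


def scaled_dot_product_attention(query: torch.Tensor,
                                 key: torch.Tensor,
                                 value: torch.Tensor,
                                 attention_mask: Optional[torch.Tensor],
                                 is_causal: bool,
                                 scale: Optional[float] = None) -> torch.Tensor:
    # Broadcast k/v heads when n_kv_heads < n_heads (GQA), as the removed
    # repeat_kv calls did for chatglm4 and chatglm4v.
    n_rep = query.size(1) // key.size(1)
    if n_rep > 1:
        key = key.repeat_interleave(n_rep, dim=1)
        value = value.repeat_interleave(n_rep, dim=1)

    # Default scale is 1/sqrt(head_dim); MiniCPM3 overrides it with self.softmax_scale.
    scale = scale if scale is not None else 1.0 / math.sqrt(query.size(-1))
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale

    if is_causal:
        # Call sites pass is_causal=True only when q_len == kv_seq_len (prefill).
        q_len, kv_len = query.size(2), key.size(2)
        causal = torch.ones(q_len, kv_len, dtype=torch.bool, device=query.device).tril()
        attn_weights = attn_weights.masked_fill(~causal, float("-inf"))
    elif attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    # Upcast the softmax to fp32, matching the removed reference paths.
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
                                               dtype=torch.float32).to(value.dtype)
    return torch.matmul(attn_weights, value)

On XPU the real helper is expected to route to the fused xe_addons.sdp / sdp_causal kernels (or their fp8 variants) rather than this reference math, which is what lets each model file drop its own use_sdp / use_sdp_causal branching.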