From b654b8a67fb0d08fdcc328179890987da54fc48e Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Thu, 19 Dec 2024 14:11:02 +0800 Subject: [PATCH] optimize new minicpm model --- .../ipex_llm/transformers/models/common.py | 10 ++-- .../ipex_llm/transformers/models/minicpm.py | 47 ++----------------- .../ipex_llm/transformers/models/minicpmv.py | 20 ++------ 3 files changed, 15 insertions(+), 62 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py index fa5f94d7724..0c140c5c68f 100644 --- a/python/llm/src/ipex_llm/transformers/models/common.py +++ b/python/llm/src/ipex_llm/transformers/models/common.py @@ -217,8 +217,8 @@ def prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, de return mask -def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - mask: torch.Tensor = None, +def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, + value: torch.Tensor, mask: torch.Tensor = None, is_causal: bool = False, scale: float = None) -> torch.Tensor: bsz, n_heads, seq_length, head_dim = query.shape _, n_kv_heads, kv_length, _ = key.shape @@ -268,7 +268,7 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: attn_output = xe_addons.sdp(query, key, value, mask) else: if key.dtype == torch.uint8: - attn_output = xe_addons.sdp_fp8(query, key, value, mask) + attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask) else: attn_output = xe_addons.sdp_non_causal(query, key, value, mask) @@ -281,6 +281,8 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: key = repeat_kv(key, n_heads // n_kv_heads) value = repeat_kv(value, n_heads // n_kv_heads) - return torch.nn.functional.scaled_dot_product_attention( + attn_output = torch.nn.functional.scaled_dot_product_attention( query, key, value, mask, is_causal=is_causal, scale=scale ) + attn_output = attn_output.to(dtype) # workaround ipex 2.1's bug + return attn_output diff --git a/python/llm/src/ipex_llm/transformers/models/minicpm.py b/python/llm/src/ipex_llm/transformers/models/minicpm.py index 30a262776ff..3bc95d6c3c7 100644 --- a/python/llm/src/ipex_llm/transformers/models/minicpm.py +++ b/python/llm/src/ipex_llm/transformers/models/minicpm.py @@ -127,49 +127,12 @@ def minicpm_attention_forward( key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, None) + from ipex_llm.transformers.models.common import scaled_dot_product_attention attn_weights = None - if use_sdp(q_len, kv_seq_len, self.head_dim, query_states): - import xe_addons - # [CompressKV] - if use_compresskv: - attention_mask = get_compresskv_attn_mask(key_states, attention_mask) - - if use_quantizekv: - attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, - attention_mask) - else: - attn_output = xe_addons.sdp(query_states, key_states, value_states, - attention_mask) - elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training): - import xe_addons - if use_quantizekv: - attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, - value_states, attention_mask) - else: - attn_output = xe_addons.sdp_causal(query_states, key_states, - value_states, attention_mask) - else: - if use_quantizekv: - key_states, value_states = restore_fp8_kv_cache(key_states, value_states, - query_states.dtype) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul( - query_states, key_states.transpose(2, 3) - ) / math.sqrt(self.head_dim) - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(query_states.dtype) - attn_weights = nn.functional.dropout( - attn_weights, p=self.attention_dropout, training=self.training - ) - attn_output = torch.matmul(attn_weights, value_states) + attn_output = scaled_dot_product_attention( + query_states, key_states, value_states, + attention_mask, q_len == kv_seq_len, math.sqrt(self.head_dim) + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py index 7f9bba681cc..dc996691c71 100644 --- a/python/llm/src/ipex_llm/transformers/models/minicpmv.py +++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py @@ -28,7 +28,6 @@ from torch.nn.functional import linear from ipex_llm.transformers.models.common import merge_qkv_base, padding_qkv_hd from ipex_llm.transformers.models.common import attention_softmax -from ipex_llm.transformers.models.utils import use_sdp_non_causal from transformers import AutoProcessor, TextIteratorStreamer from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor @@ -73,21 +72,10 @@ def siglip_attention_forward( 72, 80 ) - if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype): - import xe_addons - attn_weights = None - attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(), - value_states.contiguous(), attention_mask) - else: - attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3)) - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - attn_weights = attention_softmax(attn_weights) - - attn_weights = torch.nn.functional.dropout(attn_weights, - p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) + from ipex_llm.transformers.models.common import scaled_dot_product_attention + attn_weights = None + attn_output = scaled_dot_product_attention(query_states, key_states, value_states, + attention_mask, False, math.sqrt(self.head_dim)) attn_output = attn_output[:, :, :, :self.head_dim]