From 8086554d3396b18b9d1cbac6db45c762d5dcc50b Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Fri, 12 Apr 2024 10:49:02 +0800
Subject: [PATCH] use new fp16 sdp in llama and mistral (#10734)

---
 .../src/ipex_llm/transformers/models/llama.py | 32 ++++++-------
 .../ipex_llm/transformers/models/mistral.py   | 46 +++++++++++--------
 2 files changed, 42 insertions(+), 36 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 47f5aec73de..99bddf04f04 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -647,12 +647,11 @@ def llama_attention_forward_4_31_original(
 
     past_key_value = (key_states, value_states) if use_cache else None
 
-    # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
-
     if not self.training and not hidden_states.requires_grad and \
             use_flash_attention(query_states, key_states, attention_mask):
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
                                                      key_states.to(device, dtype=torch.float16),
                                                      value_states.to(device, dtype=torch.float16),
@@ -660,13 +659,14 @@
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
             use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states, attention_mask):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states,
-                                                    value_states)
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
         # otherwise, use native attention
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
                                                attention_mask,
@@ -1305,12 +1305,11 @@ def llama_attention_forward_4_36_original(
             past_key_value.key_cache[self.layer_idx] = key_states
             past_key_value.value_cache[self.layer_idx] = value_states
-    # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
-
     if not self.training and not hidden_states.requires_grad and \
             use_flash_attention(query_states, key_states, attention_mask):
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
         # now only use flash attention for first token
         attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
                                                      key_states.to(device, dtype=torch.float16),
                                                      value_states.to(device, dtype=torch.float16),
@@ -1319,13 +1318,14 @@
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
             use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states,
-                                                    value_states)
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
         # otherwise, use native attention
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
                                                attention_mask,
diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index c81cafff3a9..c287ec0ff9f 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -495,13 +495,12 @@ def mistral_attention_forward_original(
     else:
         attention_dtype = original_dtype
 
-    # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
-                                                                     dtype=attention_dtype)
-    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
-                                                                         dtype=attention_dtype)
-
     if fsdp_flag:
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
+                                                                         dtype=attention_dtype)
+        value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
+                                                                             dtype=attention_dtype)
         attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
                                                      key_states,
                                                      value_states,
@@ -510,15 +509,19 @@
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states,
-                                                    value_states)
+        # new fp16 sdp doesn't require repeat_kv
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     else:
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
+                                                                         dtype=attention_dtype)
+        value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
+                                                                             dtype=attention_dtype)
         attn_output, attn_weights = compute_attn_outputs_weights(query_states,
                                                                  key_states,
                                                                  value_states,
@@ -885,13 +888,12 @@ def mistral_attention_forward_4_36_original(
     else:
         attention_dtype = original_dtype
 
-    # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
-                                                                     dtype=attention_dtype)
-    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
-                                                                         dtype=attention_dtype)
-
     if fsdp_flag:
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
+                                                                         dtype=attention_dtype)
+        value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
+                                                                             dtype=attention_dtype)
         attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
                                                      key_states,
                                                      value_states,
@@ -900,15 +902,19 @@
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states,
-                                                    value_states)
+        # new fp16 sdp doesn't require repeat_kv
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     else:
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
+                                                                         dtype=attention_dtype)
+        value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
+                                                                             dtype=attention_dtype)
         attn_output, attn_weights = compute_attn_outputs_weights(query_states,
                                                                  key_states,
                                                                  value_states,
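
Not part of the patch itself, just reviewer context: the sketch below mirrors the control flow both files converge on after this change, with the three branches side by side. repeat_kv now runs only on the flash-attention and native fallback paths, while the new linear_q4_0.sdp_fp16 kernel is handed the un-repeated key/value states together with the attention mask (matching the "new fp16 sdp doesn't require repeat_kv" comment above). The helper names sdp_dispatch and naive_sdp and the two boolean flags are illustrative stand-ins, not code from the repository; linear_q4_0 is ipex-llm's native extension and only imports on an XPU build, so that branch stays off by default here.

# Standalone sketch of the dispatch pattern this patch sets up; only the
# linear_q4_0.sdp_fp16 call shape and repeat_kv semantics come from the diff.
import torch
import torch.nn.functional as F


def repeat_kv(states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same semantics as transformers' repeat_kv:
    # (batch, kv_heads, seq, dim) -> (batch, kv_heads * n_rep, seq, dim)
    if n_rep == 1:
        return states
    b, kv_heads, s, d = states.shape
    states = states[:, :, None, :, :].expand(b, kv_heads, n_rep, s, d)
    return states.reshape(b, kv_heads * n_rep, s, d)


def naive_sdp(q, k, v, mask, head_dim):
    # Plain softmax(QK^T / sqrt(d) + mask) V, standing in for the
    # native_sdp / compute_attn_outputs_weights fallback in the diff.
    weights = torch.matmul(q, k.transpose(2, 3)) / (head_dim ** 0.5)
    if mask is not None:
        weights = weights + mask
    weights = torch.softmax(weights, dim=-1, dtype=torch.float32).to(q.dtype)
    return torch.matmul(weights, v), weights


def sdp_dispatch(q, k, v, mask, n_kv_groups, head_dim,
                 use_flash=False, use_fp16_sdp=False):
    # Mirrors the patched control flow: repeat_kv only on the flash and
    # native branches; the fp16 SDP kernel consumes un-repeated k/v + mask.
    if use_flash:
        k = repeat_kv(k, n_kv_groups)
        v = repeat_kv(v, n_kv_groups)
        out = F.scaled_dot_product_attention(q.half(), k.half(), v.half(),
                                             is_causal=True)
        return out.to(q.dtype), None
    if use_fp16_sdp:
        import linear_q4_0  # only available with an ipex-llm XPU build
        out = linear_q4_0.sdp_fp16(q, k, v, mask)
        return out.view(q.shape), None
    k = repeat_kv(k, n_kv_groups)
    v = repeat_kv(v, n_kv_groups)
    return naive_sdp(q, k, v, mask, head_dim)


if __name__ == "__main__":
    b, heads, kv_heads, seq, dim = 1, 8, 2, 16, 64
    q = torch.randn(b, heads, seq, dim)
    k = torch.randn(b, kv_heads, seq, dim)
    v = torch.randn(b, kv_heads, seq, dim)
    out, _ = sdp_dispatch(q, k, v, None, heads // kv_heads, dim)
    print(out.shape)  # torch.Size([1, 8, 16, 64])

Skipping repeat_kv on the sdp_fp16 branch avoids materializing the expanded key/value tensors that the fused kernel does not need, which appears to be the point of moving those calls into the other two branches.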