Skip to content

Commit

Permalink
LLM: add flash attention support for llama (#9518)
Browse files Browse the repository at this point in the history
* add initial flash attention for llama

* accelerate fp32 first token by changing to fp16 in advance

* support fp32
  • Loading branch information
rnwang04 authored Nov 23, 2023
1 parent d488b4b commit 8c97e33
Showing 1 changed file with 69 additions and 25 deletions.
94 changes: 69 additions & 25 deletions python/llm/src/bigdl/llm/transformers/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,16 @@ def llama_attention_forward_4_31(
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
device = hidden_states.device
# for flash attention
original_dtype = hidden_states.dtype
if not self.training and not hidden_states.requires_grad:
fsdp_flag = check_flash_attention_available(hidden_states)
else:
fsdp_flag = False
if fsdp_flag and q_len > 1:
attention_dtype = torch.float16 # use fp16 for flash attention
else:
attention_dtype = original_dtype

if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
Expand Down Expand Up @@ -194,31 +204,23 @@ def llama_attention_forward_4_31(

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
dtype=hidden_states.dtype)
dtype=attention_dtype)
value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
dtype=hidden_states.dtype)

attn_weights = torch.matmul(query_states,
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

attn_weights_size = (bsz, self.num_heads, q_len, kv_seq_len)
if attn_weights.size() != attn_weights_size:
invalidInputError(False,
f"Attention weights should be of size {attn_weights_size}, "
f"but is {attn_weights.size()}")

if attention_mask is not None:
attn_mask_size = (bsz, 1, q_len, kv_seq_len)
if attention_mask.size() != attn_mask_size:
invalidInputError(False,
f"Attention mask should be of size {attn_mask_size}, "
f"but is {attention_mask.size()}")
attn_weights = attn_weights + attention_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
dtype=attention_dtype)

if fsdp_flag and q_len > 1:
# now only use flash attention for first token
attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
key_states,
value_states,
is_causal=True)
attn_weights = None
else:
# otherwise, use native attention
attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
attention_mask,
bsz, q_len, kv_seq_len,
self.head_dim, self.num_heads)

attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
if attn_output.size() != attn_output_size:
Expand All @@ -241,4 +243,46 @@ def llama_attention_forward_4_31(
if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value
return attn_output.to(original_dtype), attn_weights, past_key_value


def check_flash_attention_available(query):
# check whether ipex flash attention can be used
if query.device.type != "xpu":
# ipex flash attention only support for xpu
return False
ipex_version = get_ipex_version()
if ipex_version <= "2.0.110+xpu":
# ipex flash attention is supported from ipex 2.1
return False
if not torch.xpu.has_xetla():
# ipex flash attention is only supported for xetla
# may update this later
return False
return True


def native_sdp(query, key, value, attention_mask,
bsz, q_len, kv_seq_len, head_dim, num_heads):
attn_weights = torch.matmul(query,
key.transpose(2, 3)) / math.sqrt(head_dim)

attn_weights_size = (bsz, num_heads, q_len, kv_seq_len)
if attn_weights.size() != attn_weights_size:
invalidInputError(False,
f"Attention weights should be of size {attn_weights_size}, "
f"but is {attn_weights.size()}")

if attention_mask is not None:
attn_mask_size = (bsz, 1, q_len, kv_seq_len)
if attention_mask.size() != attn_mask_size:
invalidInputError(False,
f"Attention mask should be of size {attn_mask_size}, "
f"but is {attention_mask.size()}")
attn_weights = attn_weights + attention_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(value.dtype)
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights

0 comments on commit 8c97e33

Please sign in to comment.