[NPU] Llama2 prefill use ov sdp (#12310)
* prefill use sdp

* add param

* update

* fix style

* fix style

* meet comments
cyita authored Nov 1, 2024
1 parent eda7649 commit 05c5d02
Showing 2 changed files with 46 additions and 19 deletions.
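
Note: this commit switches the Llama2 prefill attention graph on NPU from an explicitly built mask-add-softmax subgraph to OpenVINO's fused scaled_dot_product_attention, so the prefill graph no longer takes an attention_mask input. As a rough reference for what the two paths compute (a PyTorch sketch with illustrative shapes, not the OV graph code in this diff):

import math
import torch
import torch.nn.functional as F

# Illustrative shapes: batch, heads, prefill length, head_dim.
B, H, S, D = 1, 32, 16, 128
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

# Old prefill path: explicit additive causal mask, matmul, softmax, matmul.
mask = torch.triu(torch.full((S, S), float("-inf")), diagonal=1)
attn_weight = q @ k.transpose(-1, -2) / math.sqrt(D) + mask
manual = attn_weight.softmax(dim=-1) @ v

# New prefill path: fused SDP with causal masking, no mask tensor fed in.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

assert torch.allclose(manual, fused, atol=1e-4)
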
25 changes: 20 additions & 5 deletions python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -110,13 +110,20 @@ def __init__(
         # define input, the order self.parameter matters
         input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))

+        # llama2 use ov sdp, other models need to test
+        use_prefill_sdp = self.intermediate_size == 11008
+
         # Self Attention
         if mode == "decode":
             attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
                                                   dtype=np.int64)
         else:
-            attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len),
-                                                  dtype=np.int64)
+            if use_prefill_sdp:
+                attention_mask = None
+            else:
+                attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len,
+                                                       self.seq_len),
+                                                      dtype=np.int64)

         position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)

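Note: 11008 is the MLP intermediate_size of the 7B Llama/Llama-2 checkpoints (13B uses 13824), so only Llama2-7B-shaped graphs take the SDP prefill path for now, matching the "other models need to test" comment; with SDP enabled the prefill graph also drops its attention_mask input, since the fused op applies the causal mask internally. A minimal sketch of the same gate applied to a Hugging Face config (an illustration, not code from this commit; the model id in the comment is only an example):

from transformers import AutoConfig

def prefill_uses_ov_sdp(model_path: str) -> bool:
    # Mirrors the check added above: Llama(-2) 7B has intermediate_size == 11008.
    config = AutoConfig.from_pretrained(model_path)
    return getattr(config, "intermediate_size", None) == 11008

# e.g. prefill_uses_ov_sdp("meta-llama/Llama-2-7b-chat-hf") -> True
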
@@ -177,6 +184,7 @@ def __init__(
                 post_attention_layernorm_weight=post_attn_layernorm_weights[i],
                 past_key=past_keys[i],
                 past_value=past_values[i],
+                use_prefill_sdp=use_prefill_sdp,
             )
             curr_key_values.append((new_key_states, new_value_states))

@@ -202,6 +210,7 @@ def build_decoder(
         post_attention_layernorm_weight,
         past_key=None,
         past_value=None,
+        use_prefill_sdp=False,
     ):

         residual = hidden_states
@@ -220,6 +229,7 @@ def __init__(
             num_key_value_heads=self.num_key_value_heads,
             head_dim=self.head_dim,
             seq_len=self.seq_len,
+            use_prefill_sdp=use_prefill_sdp,
         )
         hidden_states = self.eltwise_add(residual, attn_output)
         residual = hidden_states
@@ -427,6 +437,7 @@ def __init__(
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
+        self.use_prefill_sdp = intermediate_size == 11008

     def forward(
         self,
@@ -451,9 +462,13 @@ def forward(
         seq_len = hidden_states.shape[1]

         backend_cls = self.backend_cls_prefill
-        inputs = (hidden_states.to(torch.float16),
-                  attention_mask.to(torch.int64),
-                  position_ids.to(torch.int64))
+        if self.use_prefill_sdp:
+            inputs = (hidden_states.to(torch.float16),
+                      position_ids.to(torch.int64))
+        else:
+            inputs = (hidden_states.to(torch.float16),
+                      attention_mask.to(torch.int64),
+                      position_ids.to(torch.int64))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         hidden_states, past_key, past_value = run_model(
             inputs, self.op_parameters, backend_cls, self.op_id, replica=2
40 changes: 26 additions & 14 deletions python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
@@ -135,10 +135,10 @@ def attention(self,
                   seq_len,
                   q_bias=None,
                   k_bias=None,
-                  v_bias=None):
+                  v_bias=None,
+                  use_prefill_sdp=False):
         hidden_size = num_heads * head_dim
         num_key_value_groups = num_heads // num_key_value_heads
-        groupsize = hidden_size // self.n_splits_linear
         if self.n_splits_linear == 1:
             query_states = self.linear(
                 hidden_states,
@@ -200,8 +200,13 @@ def attention(self,

         query_states = self.transpose(query_states, [0, 2, 1, 3])
         key_states = self.transpose(key_states, [0, 2, 1, 3])
+        use_ov_sdp = (mode == "prefill") and use_prefill_sdp
         if self.transpose_value:
-            value_states = self.transpose(value_states, [0, 2, 3, 1])
+            new_value_states = self.transpose(value_states, [0, 2, 3, 1])
+            if use_ov_sdp:
+                value_states = self.transpose(value_states, [0, 2, 1, 3])
+            else:
+                value_states = new_value_states
         else:
             value_states = self.transpose(value_states, [0, 2, 1, 3])

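Note: when transpose_value is set, the value cache the layer returns stays in the pre-transposed (batch, heads, head_dim, seq) layout the decode matmul expects, while the fused SDP op consumes the standard (batch, heads, seq, head_dim) layout; the hunk above therefore keeps new_value_states for the cache and builds a separately transposed value_states for SDP. A small layout sketch in PyTorch (illustrative shapes, not the OV graph):

import torch

B, H, S, D = 1, 32, 16, 128
v = torch.randn(B, S, H, D)          # as produced: (batch, seq, heads, head_dim)

v_for_cache = v.permute(0, 2, 3, 1)  # (B, H, D, S): pre-transposed for the matmul path
v_for_sdp = v.permute(0, 2, 1, 3)    # (B, H, S, D): layout the fused SDP consumes

# Same values, different ordering of the last two axes.
assert torch.equal(v_for_cache.transpose(-1, -2), v_for_sdp)
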
@@ -216,7 +221,6 @@ def attention(self,
             head_dim=head_dim,
         )
         new_key_states = key_states
-        new_value_states = value_states

         if mode == "decode":
             key_states = self.concat(past_key, key_states, axis=-2)
@@ -238,16 +242,24 @@ def attention(self,
                                           num_key_value_heads=num_key_value_heads,
                                           kv_seq_len=kv_seq_len,
                                           head_dim=head_dim,
-                                          transpose=self.transpose_value)
-        attn_weight = self.matmul(query_states, key_states, False, True) / (
-            math.sqrt(head_dim)
-        )
-        attention_mask = self.convert_to_fp16(attention_mask)
-        attn_weight = self.eltwise_add(attn_weight, attention_mask)
-        attn_weight = self.convert_to_fp32(attn_weight)
-        attn_weight = self.softmax(attn_weight, -1)
-        attn_weight = self.convert_to_fp16(attn_weight)
-        attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value)
+                                          transpose=(self.transpose_value and (not use_ov_sdp)))
+        if use_ov_sdp:
+            value_states = self.convert_to_fp32(value_states)
+            key_states = self.convert_to_fp32(key_states)
+            query_states = self.convert_to_fp32(query_states)
+            attn_output = self.scaled_dot_product_attention(
+                query_states, key_states, value_states, None, True)
+            attn_output = self.convert_to_fp16(attn_output)
+        else:
+            attn_weight = self.matmul(query_states, key_states, False, True) / (
+                math.sqrt(head_dim)
+            )
+            attention_mask = self.convert_to_fp16(attention_mask)
+            attn_weight = self.eltwise_add(attn_weight, attention_mask)
+            attn_weight = self.convert_to_fp32(attn_weight)
+            attn_weight = self.softmax(attn_weight, -1)
+            attn_weight = self.convert_to_fp16(attn_weight)
+            attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value)

         attn_output = self.transpose(attn_output, [0, 2, 1, 3])
         attn_output = self.reshape(attn_output, [1, seq_len, hidden_size])
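
Note: in the new prefill branch the query/key/value tensors are widened to fp32, the fused SDP op is called with no explicit mask and causal masking enabled (the None, True arguments), and the result is cast back to fp16; decode and non-Llama2 prefill keep the original mask-add-softmax sequence. A PyTorch stand-in for that dtype handling (a sketch only; function and argument names here are this sketch's, not the OV op's):

import torch
import torch.nn.functional as F

def prefill_sdp_fp16(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Upcast to fp32 for the fused attention (mirrors convert_to_fp32 above).
    q32, k32, v32 = q.float(), k.float(), v.float()
    # No attention-mask tensor; causal masking is handled inside the op.
    out = F.scaled_dot_product_attention(q32, k32, v32, attn_mask=None, is_causal=True)
    # Back to fp16 for the rest of the graph (mirrors convert_to_fp16).
    return out.half()

q = k = v = torch.randn(1, 32, 16, 128, dtype=torch.float16)
assert prefill_sdp_fp16(q, k, v).dtype == torch.float16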