[NPU] Llama2 prefill use ov sdp #12310

Merged (7 commits) on Nov 1, 2024
Changes from 6 commits
25 changes: 20 additions & 5 deletions python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -110,13 +110,20 @@ def __init__(
# define input; the order of self.parameter matters
input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))

# enabled for llama2 for now; other models still need to be tested
use_prefill_sdp = self.intermediate_size == 11008

# Self Attention
if mode == "decode":
attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
dtype=np.int64)
else:
attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len),
dtype=np.int64)
if use_prefill_sdp:
attention_mask = None
else:
attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len,
self.seq_len),
dtype=np.int64)

position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)

@@ -177,6 +184,7 @@ def __init__(
post_attention_layernorm_weight=post_attn_layernorm_weights[i],
past_key=past_keys[i],
past_value=past_values[i],
use_prefill_sdp=use_prefill_sdp,
)
curr_key_values.append((new_key_states, new_value_states))

@@ -202,6 +210,7 @@ def build_decoder(
post_attention_layernorm_weight,
past_key=None,
past_value=None,
use_prefill_sdp=False,
):

residual = hidden_states
@@ -220,6 +229,7 @@ def __init__(
num_key_value_heads=self.num_key_value_heads,
head_dim=self.head_dim,
seq_len=self.seq_len,
use_prefill_sdp=use_prefill_sdp,
)
hidden_states = self.eltwise_add(residual, attn_output)
residual = hidden_states
@@ -427,6 +437,7 @@ def __init__(
)
self.layer_norm_0 = layer_norm_0
self.layer_norm_1 = layer_norm_1
self.use_prefill_sdp = intermediate_size == 11008

def forward(
self,
@@ -451,9 +462,13 @@ def forward(
seq_len = hidden_states.shape[1]

backend_cls = self.backend_cls_prefill
inputs = (hidden_states.to(torch.float16),
attention_mask.to(torch.int64),
position_ids.to(torch.int64))
if self.use_prefill_sdp:
inputs = (hidden_states.to(torch.float16),
position_ids.to(torch.int64))
else:
inputs = (hidden_states.to(torch.float16),
attention_mask.to(torch.int64),
position_ids.to(torch.int64))
inputs += (self.layer_norm_0, self.layer_norm_1)
hidden_states, past_key, past_value = run_model(
inputs, self.op_parameters, backend_cls, self.op_id, replica=2
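Net effect of the llama_mp.py changes: when use_prefill_sdp is set (only for the Llama-2-7B geometry, intermediate_size == 11008), the prefill graph no longer declares an attention_mask input, and the host-side forward drops the mask tensor from what it feeds the NPU graph. A condensed sketch of that input packing follows; pack_prefill_inputs is a hypothetical helper standing in for the runner's forward, not part of the codebase.

import torch

def pack_prefill_inputs(hidden_states, attention_mask, position_ids,
                        layer_norm_0, layer_norm_1, use_prefill_sdp):
    # Sketch only: mirrors the branch above. With use_prefill_sdp the OV graph
    # applies the causal mask inside its SDP op, so no mask tensor is passed.
    if use_prefill_sdp:
        inputs = (hidden_states.to(torch.float16),
                  position_ids.to(torch.int64))
    else:
        inputs = (hidden_states.to(torch.float16),
                  attention_mask.to(torch.int64),
                  position_ids.to(torch.int64))
    return inputs + (layer_norm_0, layer_norm_1)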
41 changes: 27 additions & 14 deletions python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
@@ -135,10 +135,10 @@ def attention(self,
seq_len,
q_bias=None,
k_bias=None,
v_bias=None):
v_bias=None,
use_prefill_sdp=False):
hidden_size = num_heads * head_dim
num_key_value_groups = num_heads // num_key_value_heads
groupsize = hidden_size // self.n_splits_linear
if self.n_splits_linear == 1:
query_states = self.linear(
hidden_states,
@@ -200,8 +200,14 @@ def attention(self,

query_states = self.transpose(query_states, [0, 2, 1, 3])
key_states = self.transpose(key_states, [0, 2, 1, 3])
use_ov_sdp = (mode == "prefill") and use_prefill_sdp
print(f"-------------------- use_ov_sdp: {use_ov_sdp}, groupsize: {self.group_size}")
Contributor: remove this print ?
if self.transpose_value:
value_states = self.transpose(value_states, [0, 2, 3, 1])
new_value_states = self.transpose(value_states, [0, 2, 3, 1])
if use_ov_sdp:
value_states = self.transpose(value_states, [0, 2, 1, 3])
else:
value_states = new_value_states
else:
value_states = self.transpose(value_states, [0, 2, 1, 3])
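Note on the value-state layouts above: with transpose_value enabled, the copy written back to the KV cache (new_value_states) keeps the matmul-friendly [batch, kv_heads, head_dim, seq] layout, while the OV SDP op consumes the standard [batch, kv_heads, seq, head_dim] layout, hence the extra transpose on the use_ov_sdp branch. A small illustrative sketch with assumed shapes:

import torch

bsz, n_kv_heads, seq_len, head_dim = 1, 32, 16, 128
value_states = torch.randn(bsz, seq_len, n_kv_heads, head_dim)   # [B, S, H, D]

new_value_states = value_states.permute(0, 2, 3, 1)  # [B, H, D, S], written to the KV cache
sdp_value_states = value_states.permute(0, 2, 1, 3)  # [B, H, S, D], fed to SDP when use_ov_sdp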

@@ -216,7 +222,6 @@ def attention(self,
head_dim=head_dim,
)
new_key_states = key_states
new_value_states = value_states

if mode == "decode":
key_states = self.concat(past_key, key_states, axis=-2)
@@ -238,16 +243,24 @@ def attention(self,
num_key_value_heads=num_key_value_heads,
kv_seq_len=kv_seq_len,
head_dim=head_dim,
transpose=self.transpose_value)
attn_weight = self.matmul(query_states, key_states, False, True) / (
math.sqrt(head_dim)
)
attention_mask = self.convert_to_fp16(attention_mask)
attn_weight = self.eltwise_add(attn_weight, attention_mask)
attn_weight = self.convert_to_fp32(attn_weight)
attn_weight = self.softmax(attn_weight, -1)
attn_weight = self.convert_to_fp16(attn_weight)
attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value)
transpose=(self.transpose_value and (not use_ov_sdp)))
if use_ov_sdp:
value_states = self.convert_to_fp32(value_states)
key_states = self.convert_to_fp32(key_states)
query_states = self.convert_to_fp32(query_states)
attn_output = self.scaled_dot_product_attention(
query_states, key_states, value_states, None, True)
attn_output = self.convert_to_fp16(attn_output)
else:
attn_weight = self.matmul(query_states, key_states, False, True) / (
math.sqrt(head_dim)
)
attention_mask = self.convert_to_fp16(attention_mask)
attn_weight = self.eltwise_add(attn_weight, attention_mask)
attn_weight = self.convert_to_fp32(attn_weight)
attn_weight = self.softmax(attn_weight, -1)
attn_weight = self.convert_to_fp16(attn_weight)
attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value)

attn_output = self.transpose(attn_output, [0, 2, 1, 3])
attn_output = self.reshape(attn_output, [1, seq_len, hidden_size])
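For reference, the use_ov_sdp branch above computes the same attention output as the explicit mask-add-softmax branch it replaces. A standalone PyTorch sketch of that equivalence with assumed shapes; torch.nn.functional.scaled_dot_product_attention stands in here for the OV scaled_dot_product_attention op used above.

import math
import torch
import torch.nn.functional as F

# Assumed Llama-2-7B-like shapes: batch 1, 32 heads, 16 prefill tokens, head_dim 128.
q = torch.randn(1, 32, 16, 128)
k = torch.randn(1, 32, 16, 128)
v = torch.randn(1, 32, 16, 128)

# Explicit path: softmax(QK^T / sqrt(d) + causal_mask) @ V.
causal_mask = torch.triu(torch.full((16, 16), float("-inf")), diagonal=1)
manual = torch.softmax(q @ k.transpose(-1, -2) / math.sqrt(128) + causal_mask, dim=-1) @ v

# Fused path: the causal mask is applied inside SDP, so no mask tensor is needed.
fused = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

assert torch.allclose(manual, fused, atol=1e-4)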