diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py index 93f1ff36448..76187872b38 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py @@ -110,13 +110,20 @@ def __init__( # define input, the order self.parameter matters input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size)) + # llama2 use ov sdp, other models need to test + use_prefill_sdp = self.intermediate_size == 11008 + # Self Attention if mode == "decode": attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1), dtype=np.int64) else: - attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len), - dtype=np.int64) + if use_prefill_sdp: + attention_mask = None + else: + attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, + self.seq_len), + dtype=np.int64) position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64) @@ -177,6 +184,7 @@ def __init__( post_attention_layernorm_weight=post_attn_layernorm_weights[i], past_key=past_keys[i], past_value=past_values[i], + use_prefill_sdp=use_prefill_sdp, ) curr_key_values.append((new_key_states, new_value_states)) @@ -202,6 +210,7 @@ def build_decoder( post_attention_layernorm_weight, past_key=None, past_value=None, + use_prefill_sdp=False, ): residual = hidden_states @@ -220,6 +229,7 @@ def build_decoder( num_key_value_heads=self.num_key_value_heads, head_dim=self.head_dim, seq_len=self.seq_len, + use_prefill_sdp=use_prefill_sdp, ) hidden_states = self.eltwise_add(residual, attn_output) residual = hidden_states @@ -427,6 +437,7 @@ def __init__( ) self.layer_norm_0 = layer_norm_0 self.layer_norm_1 = layer_norm_1 + self.use_prefill_sdp = intermediate_size == 11008 def forward( self, @@ -451,9 +462,13 @@ def forward( seq_len = hidden_states.shape[1] backend_cls = self.backend_cls_prefill - inputs = (hidden_states.to(torch.float16), - attention_mask.to(torch.int64), - position_ids.to(torch.int64)) + if self.use_prefill_sdp: + inputs = (hidden_states.to(torch.float16), + position_ids.to(torch.int64)) + else: + inputs = (hidden_states.to(torch.float16), + attention_mask.to(torch.int64), + position_ids.to(torch.int64)) inputs += (self.layer_norm_0, self.layer_norm_1) hidden_states, past_key, past_value = run_model( inputs, self.op_parameters, backend_cls, self.op_id, replica=2 diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index 1550d6837f6..3ac026aa687 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -135,10 +135,10 @@ def attention(self, seq_len, q_bias=None, k_bias=None, - v_bias=None): + v_bias=None, + use_prefill_sdp=False): hidden_size = num_heads * head_dim num_key_value_groups = num_heads // num_key_value_heads - groupsize = hidden_size // self.n_splits_linear if self.n_splits_linear == 1: query_states = self.linear( hidden_states, @@ -200,8 +200,13 @@ def attention(self, query_states = self.transpose(query_states, [0, 2, 1, 3]) key_states = self.transpose(key_states, [0, 2, 1, 3]) + use_ov_sdp = (mode == "prefill") and use_prefill_sdp if self.transpose_value: - value_states = self.transpose(value_states, [0, 2, 3, 1]) + new_value_states = self.transpose(value_states, [0, 2, 3, 1]) + if use_ov_sdp: + value_states = self.transpose(value_states, [0, 2, 1, 3]) + else: + value_states = new_value_states else: value_states = self.transpose(value_states, [0, 2, 1, 3]) @@ -216,7 +221,6 @@ def attention(self, head_dim=head_dim, ) new_key_states = key_states - new_value_states = value_states if mode == "decode": key_states = self.concat(past_key, key_states, axis=-2) @@ -238,16 +242,24 @@ def attention(self, num_key_value_heads=num_key_value_heads, kv_seq_len=kv_seq_len, head_dim=head_dim, - transpose=self.transpose_value) - attn_weight = self.matmul(query_states, key_states, False, True) / ( - math.sqrt(head_dim) - ) - attention_mask = self.convert_to_fp16(attention_mask) - attn_weight = self.eltwise_add(attn_weight, attention_mask) - attn_weight = self.convert_to_fp32(attn_weight) - attn_weight = self.softmax(attn_weight, -1) - attn_weight = self.convert_to_fp16(attn_weight) - attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value) + transpose=(self.transpose_value and (not use_ov_sdp))) + if use_ov_sdp: + value_states = self.convert_to_fp32(value_states) + key_states = self.convert_to_fp32(key_states) + query_states = self.convert_to_fp32(query_states) + attn_output = self.scaled_dot_product_attention( + query_states, key_states, value_states, None, True) + attn_output = self.convert_to_fp16(attn_output) + else: + attn_weight = self.matmul(query_states, key_states, False, True) / ( + math.sqrt(head_dim) + ) + attention_mask = self.convert_to_fp16(attention_mask) + attn_weight = self.eltwise_add(attn_weight, attention_mask) + attn_weight = self.convert_to_fp32(attn_weight) + attn_weight = self.softmax(attn_weight, -1) + attn_weight = self.convert_to_fp16(attn_weight) + attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value) attn_output = self.transpose(attn_output, [0, 2, 1, 3]) attn_output = self.reshape(attn_output, [1, seq_len, hidden_size])