Add npu sdp and update npu readme #11562

Merged 1 commit on Jul 11, 2024
@@ -13,6 +13,7 @@ In this directory, you will find examples on how you could apply IPEX-LLM INT4 o
| Phi-3 | [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) |
| Stablelm | [stabilityai/stablelm-zephyr-3b](https://huggingface.co/stabilityai/stablelm-zephyr-3b) |
| Baichuan2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat) |
| Deepseek | [deepseek-ai/deepseek-coder-6.7b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct) |

## 0. Requirements
To run these examples with IPEX-LLM on Intel NPUs, make sure you have installed the latest Intel NPU driver.
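
For orientation, below is a minimal sketch of how one of the verified models from the table above could be loaded and run with INT4 optimizations on an Intel NPU. The `ipex_llm.transformers.npu_model` entry point and the `load_in_low_bit="sym_int4"` argument are assumptions taken from the example scripts in this directory; refer to each model's example for the exact invocation.

```python
# Hedged sketch only: the module path and arguments below are assumptions
# based on the NPU example scripts, not a definitive API reference.
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers.npu_model import AutoModelForCausalLM  # assumed NPU entry point

model_path = "deepseek-ai/deepseek-coder-6.7b-instruct"  # any verified model from the table above

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_low_bit="sym_int4",  # INT4 weight quantization (assumed flag name)
    trust_remote_code=True,
)

inputs = tokenizer("def quicksort(arr):", return_tensors="pt")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
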
python/llm/src/ipex_llm/transformers/npu_models/baichuan.py (64 additions, 0 deletions)
@@ -33,6 +33,9 @@


import torch
from torch.nn import functional as F
import importlib
from typing import Optional, Tuple
from ipex_llm.transformers.npu_models.common import merge_linear


@@ -51,3 +54,64 @@ def baichuan_mlp_forward(self, x):
gate_proj, up_proj = gate_up_proj.chunk(2, dim=-1)
down_proj = self.down_proj(self.act_fn(gate_proj) * up_proj)
return down_proj


def baichuan_attention_fwd(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
modeling_module_name = self.__class__.__module__
module = importlib.import_module(modeling_module_name)
apply_rotary_pos_emb = module.apply_rotary_pos_emb

bsz, q_len, _ = hidden_states.size()

proj = self.W_pack(hidden_states)
proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)
# [bsz, nh, t, hd]

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None

if query_states.size(2) == key_states.size(2):
# first token
from intel_npu_acceleration_library.functional import scaled_dot_product_attention
attn_output = scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask
)
attn_weights = None
else:
with torch.backends.cuda.sdp_kernel(enable_flash=True,
enable_math=True, enable_mem_efficient=True):
attn_output = F.scaled_dot_product_attention(query_states, key_states,
value_states, attn_mask=attention_mask)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value
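
Note on the branch above: `baichuan_attention_fwd` sends the prefill pass (where the query and key sequence lengths match) to the fused `scaled_dot_product_attention` from `intel_npu_acceleration_library`, and falls back to PyTorch's `F.scaled_dot_product_attention` for later decode steps. Below is a self-contained sketch of that shape-based dispatch, with plain PyTorch standing in for the NPU kernel:

```python
# Illustration of the q_len == kv_len check used above; plain PyTorch SDPA
# stands in for the NPU library's fused kernel on the prefill branch.
import torch
import torch.nn.functional as F

bsz, n_heads, head_dim, prompt_len = 1, 32, 128, 16

# Prefill: the query spans the whole prompt, so q_len == kv_len and the
# branch that calls the NPU fused kernel would be taken.
q = torch.randn(bsz, n_heads, prompt_len, head_dim)
k = torch.randn(bsz, n_heads, prompt_len, head_dim)
v = torch.randn(bsz, n_heads, prompt_len, head_dim)
assert q.size(2) == k.size(2)
prefill_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Decode: one new token attends over the accumulated KV cache, so
# q_len (1) != kv_len (prompt_len + 1) and the PyTorch fallback is used.
q_step = torch.randn(bsz, n_heads, 1, head_dim)
k_cache = torch.cat([k, torch.randn(bsz, n_heads, 1, head_dim)], dim=2)
v_cache = torch.cat([v, torch.randn(bsz, n_heads, 1, head_dim)], dim=2)
assert q_step.size(2) != k_cache.size(2)
decode_out = F.scaled_dot_product_attention(q_step, k_cache, v_cache)
```
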
python/llm/src/ipex_llm/transformers/npu_models/convert.py (12 additions, 0 deletions)
@@ -174,9 +174,11 @@ def optimize_llm(model: torch.nn.Module):
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.npu_models.baichuan import baichuan_mlp_forward, merge_mlp
from ipex_llm.transformers.npu_models.baichuan import baichuan_attention_fwd
model.apply(merge_mlp)

convert_forward(model, module.MLP, baichuan_mlp_forward)
convert_forward(model, module.Attention, baichuan_attention_fwd)

elif model.config.model_type == "phi3_v":
modeling_module_name = model.__class__.__module__
@@ -189,3 +191,13 @@ def optimize_llm(model: torch.nn.Module):
from transformers.models.clip.modeling_clip import CLIPAttention
convert_forward(model, CLIPAttention, phi3v_encoder_attention_forward)
convert_forward(model, module.Phi3VModel, phi3v_model_forward)

from ipex_llm.transformers.npu_models.phi3 import phi3_attention_forward
convert_forward(model, module.Phi3Attention, phi3_attention_forward)

elif model.config.model_type == "phi3":
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.npu_models.phi3 import phi3_attention_forward

convert_forward(model, module.Phi3Attention, phi3_attention_forward)
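
For context on the calls above: `optimize_llm` relies on `convert_forward` to replace the `forward` of every matching submodule with the NPU-friendly implementation. The helper below is a generic sketch of that pattern for illustration only, not the actual `ipex_llm` implementation:

```python
# Generic forward-swapping sketch; not the actual ipex_llm convert_forward.
import types
import torch
from torch import nn

def convert_forward(model: nn.Module, target_cls: type, new_forward) -> None:
    """Bind new_forward as the forward method of every submodule of target_cls."""
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)

# Example: swap nn.ReLU's forward for a custom function on a toy model.
def clamp_forward(self, x):
    return torch.clamp(x, min=0.0)

net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
convert_forward(net, nn.ReLU, clamp_forward)
out = net(torch.randn(2, 4))
```
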
python/llm/src/ipex_llm/transformers/npu_models/phi3.py (157 additions, 0 deletions)
@@ -0,0 +1,157 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
# which is licensed under Apache License 2.0:
#
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional, Tuple, List
import torch
from torch import nn
import math
import importlib
from transformers.cache_utils import Cache
from ipex_llm.utils.common.log4Error import invalidInputError


def phi3_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
modeling_module_name = self.__class__.__module__
module = importlib.import_module(modeling_module_name)
apply_rotary_pos_emb, repeat_kv = module.apply_rotary_pos_emb, module.repeat_kv
bsz, q_len, _ = hidden_states.size()

qkv = self.qkv_proj(hidden_states)
query_pos = self.num_heads * self.head_dim
query_states = qkv[..., :query_pos]
key_states = qkv[..., query_pos:query_pos + self.num_key_value_heads * self.head_dim]
value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim:]

query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
            invalidInputError(
                False,
                "The cache structure has changed since version v4.36. "
                f"If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, "
                "please make sure to initialize the attention class "
                "with a layer index."
            )
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)

if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
else:
causal_mask = None

if query_states.size(2) == key_states.size(2):
# first token
from intel_npu_acceleration_library.functional import scaled_dot_product_attention
attn_output = scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=self.is_causal and causal_mask is None and q_len > 1,
)
attn_weights = None
else:

attn_weights = torch.matmul(query_states,
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            invalidInputError(
                False,
                f"Attention weights should be of "
                f"size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                invalidInputError(
                    False,
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
                    f" but is {attention_mask.size()}"
                )
attn_weights = attn_weights + attention_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(value_states.dtype)
attn_weights = nn.functional.dropout(attn_weights,
p=self.attention_dropout, training=self.training)

attn_output = torch.matmul(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
invalidInputError(
False,
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

attn_output = self.o_proj(attn_output)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value