Commit 105e124 (parent 70ab1a6)

optimize phi3-v encoder npu performance and add multimodal example (#11553)

* phi3-v
* readme
Showing 4 changed files with 370 additions and 0 deletions.
python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal/README.md (75 additions, 0 deletions)
# Run Large Multimodal Model on Intel NPU
In this directory, you will find examples of how you can apply IPEX-LLM INT4 or INT8 optimizations on Large Multimodal Models on [Intel NPUs](../../../README.md). See the table below for verified models.

## Verified Models

| Model | Model Link |
|------------|----------------------------------------------------------------|
| Phi-3-Vision | [microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct) |

## 0. Requirements
To run these examples with IPEX-LLM on Intel NPUs, make sure you have installed the latest Intel NPU driver.
Go to https://www.intel.com/content/www/us/en/download/794734/intel-npu-driver-windows.html to download and unzip the driver.
Then go to **Device Manager** and find **Neural Processors** -> **Intel(R) AI Boost**.
Right-click, select **Update Driver**, and then manually select the folder unzipped from the driver package.

## Example: Predict Tokens using `generate()` API
In the example [generate.py](./generate.py), we show a basic use case for a phi-3-vision model to predict the next N tokens using the `generate()` API, with IPEX-LLM INT4 optimizations on Intel NPUs.
### 1. Install
#### 1.1 Installation on Windows
We suggest using conda to manage the environment:
```bash
conda create -n llm python=3.10 libuv
conda activate llm

# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/

# below command will install intel_npu_acceleration_library
pip install intel-npu-acceleration-library==1.3

pip install transformers==4.40
```
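
Optionally, a quick import check can confirm the environment before moving on. This is only a minimal sanity-check sketch, assuming the Python module names below match the pip packages installed above:

```python
# Optional sanity check (module names assumed to match the pip packages above).
import torch
import transformers
import intel_npu_acceleration_library  # from intel-npu-acceleration-library
import ipex_llm                        # from ipex-llm[xpu]

print(f"torch {torch.__version__}, transformers {transformers.__version__}")
```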

### 2. Runtime Configurations
For optimal performance, it is recommended to set several environment variables. Please check out the suggestions based on your device.
#### 2.1 Configurations for Windows

**The following environment variables are required**:

```cmd
set BIGDL_USE_NPU=1
```
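
If you launch the example from a Python wrapper or notebook instead of `cmd`, the same variable can be set in-process. This is only a sketch of that alternative; to be safe, set it before `ipex_llm` is imported:

```python
import os

# Same effect as `set BIGDL_USE_NPU=1`, set from Python before importing ipex_llm.
os.environ["BIGDL_USE_NPU"] = "1"
```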

### 3. Running examples

```
python ./generate.py
```

Arguments info:
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Phi-3-vision model (e.g. `microsoft/Phi-3-vision-128k-instruct`) to be downloaded, or the path to the huggingface checkpoint folder. It defaults to `'microsoft/Phi-3-vision-128k-instruct'`. For more verified models, please see the list in [Verified Models](#verified-models).
- `--image-url-or-path IMAGE_URL_OR_PATH`: argument defining the image to be inferred. It defaults to `'http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg'`.
- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). It defaults to `'What is in the image?'`.
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.
- `--load_in_low_bit`: argument defining the `load_in_low_bit` format used. It defaults to `sym_int4`; `sym_int8` can also be used. An example invocation combining these arguments is shown below.
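
For example, to run with INT8 weights on a local image and a custom prompt (the image path below is just a placeholder):

```
python ./generate.py --load_in_low_bit sym_int8 --image-url-or-path ./my_image.jpg --prompt "Describe the image." --n-predict 64
```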

#### Sample Output
#### [microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct)

```log
Inference time: xxxx s
-------------------- Prompt --------------------
Message: [{'role': 'user', 'content': '<|image_1|>\nWhat is in the image?'}]
Image link/path: http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
-------------------- Output --------------------
What is in the image?
The image shows a young girl holding a white teddy bear. She is wearing a pink dress with a heart on it. The background includes a stone
```

The sample input image (fetched from the [COCO dataset](https://cocodataset.org/#explore?id=264959)) is:

<a href="http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg"><img width=400px src="http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg" ></a>
python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal/generate.py (93 additions, 0 deletions)
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import time
import torch
import argparse
import requests

from PIL import Image
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoProcessor

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for phi-3-vision model')
    parser.add_argument('--repo-id-or-model-path', type=str, default="microsoft/Phi-3-vision-128k-instruct",
                        help='The huggingface repo id for the phi-3-vision model to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--image-url-or-path', type=str,
                        default="http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg",
                        help='The URL or path to the image to infer')
    parser.add_argument('--prompt', type=str, default="What is in the image?",
                        help='Prompt to infer')
    parser.add_argument('--n-predict', type=int, default=32,
                        help='Max tokens to predict')
    parser.add_argument('--load_in_low_bit', type=str, default="sym_int4",
                        help='Load in low bit to use')

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
    image_path = args.image_url_or_path

    # Load model in SYM_INT4,
    # which converts the relevant layers in the model into SYM_INT4 format
    # You could also try `'sym_int8'` for INT8
    # `_attn_implementation="eager"` is required for phi-3-vision
    # `modules_to_not_convert=["vision_embed_tokens"]` and `model = model.half()` are for acceleration and are optional
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 trust_remote_code=True,
                                                 load_in_low_bit=args.load_in_low_bit,
                                                 _attn_implementation="eager",
                                                 modules_to_not_convert=["vision_embed_tokens"])

    # Load processor
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    # here the message formatting refers to https://huggingface.co/microsoft/Phi-3-vision-128k-instruct#sample-inference-code
    messages = [
        {"role": "user", "content": "<|image_1|>\n{prompt}".format(prompt=args.prompt)},
    ]
    prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    if os.path.exists(image_path):
        image = Image.open(image_path)
    else:
        image = Image.open(requests.get(image_path, stream=True).raw)

    # Generate predicted tokens
    with torch.inference_mode():
        # start inference
        st = time.time()

        inputs = processor(prompt, [image], return_tensors="pt")
        output = model.generate(**inputs,
                                eos_token_id=processor.tokenizer.eos_token_id,
                                num_beams=1,
                                do_sample=False,
                                max_new_tokens=args.n_predict,
                                temperature=0.0)
        end = time.time()
        print(f'Inference time: {end-st} s')
        output_str = processor.decode(output[0],
                                      skip_special_tokens=True,
                                      clean_up_tokenization_spaces=False)
        print('-'*20, 'Prompt', '-'*20)
        print(f'Message: {messages}')
        print(f'Image link/path: {image_path}')
        print('-'*20, 'Output', '-'*20)
        print(output_str)
python/llm/src/ipex_llm/transformers/npu_models/phi3_v.py (190 additions, 0 deletions)
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
# which is licensed under Apache License 2.0:
#
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import importlib
from torch import nn
from typing import Optional, Tuple, List
from transformers.models.clip.modeling_clip import CLIPAttention
from ipex_llm.utils.common.log4Error import invalidInputError


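# Fuse the separate q/k/v projections of each CLIPAttention module in the vision
# encoder into a single qkv_proj linear layer, so one matmul replaces three;
# this supports the phi3-v encoder NPU optimization introduced in this commit.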
def merge_qkv(module: torch.nn.Module):
    if isinstance(module, CLIPAttention):
        new_weight = torch.cat([
            module.q_proj.weight.data,
            module.k_proj.weight.data,
            module.v_proj.weight.data,
        ], dim=0)

        if module.q_proj.bias is not None:
            qkv_proj = torch.nn.Linear(0, 0, bias=True)
            new_bias = torch.cat([
                module.q_proj.bias.data,
                module.k_proj.bias.data,
                module.v_proj.bias.data,
            ], dim=0)
            qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False)
        else:
            qkv_proj = torch.nn.Linear(0, 0, bias=False)
        qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False)
        qkv_proj.in_features = new_weight.size(1)
        qkv_proj.out_features = new_weight.size(0)
        module.qkv_proj = qkv_proj

        del module.q_proj, module.k_proj, module.v_proj


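# Wrapper for the original Phi3VModel.forward: it swaps any legacy KV cache for
# IPEX-LLM's DynamicNormalCache, then delegates to the upstream implementation.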
def phi3v_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    pixel_values: Optional[torch.FloatTensor] = None,
    image_sizes: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    # ipex-llm changes start
    from ipex_llm.transformers.kv import DynamicNormalCache
    # IPEX-LLM OPT: kv cache and quantize kv cache
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    if use_cache:
        if not isinstance(past_key_values, DynamicNormalCache):
            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
    modeling_module_name = self.__class__.__module__
    module = importlib.import_module(modeling_module_name)
    return module.Phi3VModel.forward(
        self=self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        pixel_values=pixel_values,
        image_sizes=image_sizes,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )


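# CLIP encoder attention forward that consumes the fused qkv_proj created by
# merge_qkv above; apart from the fused projection it follows the upstream
# eager attention computation.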
def phi3v_encoder_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    causal_attention_mask: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, tgt_len, embed_dim = hidden_states.size()

    qkv = self.qkv_proj(hidden_states)
    qkv = qkv.view(bsz, tgt_len, self.num_heads * 3, self.head_dim)
    qkv = qkv.transpose(1, 2)
    query_states, key_states, value_states = qkv.split([self.num_heads,
                                                        self.num_heads,
                                                        self.num_heads], dim=1)

    proj_shape = (bsz * self.num_heads, -1, self.head_dim)
    query_states = query_states.reshape(*proj_shape)
    key_states = key_states.reshape(*proj_shape)
    value_states = value_states.reshape(*proj_shape)

    src_len = key_states.size(1)
    attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

    if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
        invalidInputError(
            False,
            f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)},"
            f" but is {attn_weights.size()}"
        )

    # apply the causal_attention_mask first
    if causal_attention_mask is not None:
        if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
            invalidInputError(
                False,
                f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                f" {causal_attention_mask.size()}"
            )
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) \
            + causal_attention_mask
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, tgt_len, src_len):
            invalidInputError(
                False,
                f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)},"
                f" but is {attention_mask.size()}"
            )
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    if output_attentions:
        # this operation is a bit awkward, but it's required to
        # make sure that attn_weights keeps its gradient.
        # In order to do so, attn_weights have to be reshaped
        # twice and have to be reused in the following
        attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
    else:
        attn_weights_reshaped = None

    attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

    attn_output = torch.bmm(attn_probs, value_states)

    if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
        invalidInputError(
            False,
            f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)},"
            f" but is {attn_output.size()}"
        )

    attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

    attn_output = self.out_proj(attn_output)

    return attn_output, attn_weights_reshaped