From 105e1247525458fbe3fa0d956dc8601b18495540 Mon Sep 17 00:00:00 2001
From: Zhao Changmin
Date: Thu, 11 Jul 2024 13:59:14 +0800
Subject: [PATCH] optimize phi3-v encoder npu performance and add multimodal
 example (#11553)

* phi3-v

* readme
---
 .../Multimodal/README.md                      |  75 +++++++
 .../Multimodal/generate.py                    |  93 +++++++++
 .../transformers/npu_models/convert.py        |  12 ++
 .../transformers/npu_models/phi3_v.py         | 190 ++++++++++++++++++
 4 files changed, 370 insertions(+)
 create mode 100644 python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal/README.md
 create mode 100644 python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal/generate.py
 create mode 100644 python/llm/src/ipex_llm/transformers/npu_models/phi3_v.py

diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal/README.md b/python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal/README.md
new file mode 100644
index 00000000000..4977079b2e1
--- /dev/null
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal/README.md
@@ -0,0 +1,75 @@
+# Run Large Multimodal Model on Intel NPU
+In this directory, you will find examples on how you could apply IPEX-LLM INT4 or INT8 optimizations on Large Multimodal Models on [Intel NPUs](../../../README.md). See the table below for verified models.
+
+## Verified Models
+
+| Model        | Model Link                                                                                           |
+|--------------|------------------------------------------------------------------------------------------------------|
+| Phi-3-Vision | [microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct) |
+
+## 0. Requirements
+To run these examples with IPEX-LLM on Intel NPUs, make sure to install the newest Intel NPU driver.
+Go to https://www.intel.com/content/www/us/en/download/794734/intel-npu-driver-windows.html to download and unzip the driver.
+Then go to **Device Manager** and find **Neural Processors** -> **Intel(R) AI Boost**.
+Right-click it and select **Update Driver**, then manually select the folder unzipped from the driver.
+
+## Example: Predict Tokens using `generate()` API
+In the example [generate.py](./generate.py), we show a basic use case for a phi-3-vision model to predict the next N tokens using the `generate()` API, with IPEX-LLM INT4 optimizations on Intel NPUs.
+### 1. Install
+#### 1.1 Installation on Windows
+We suggest using conda to manage the environment:
+```bash
+conda create -n llm python=3.10 libuv
+conda activate llm
+
+# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
+pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+
+# below command will install intel_npu_acceleration_library
+pip install intel-npu-acceleration-library==1.3
+
+pip install transformers==4.40
+```
+
+### 2. Runtime Configurations
+For optimal performance, it is recommended to set several environment variables. Please check out the suggestions based on your device.
+#### 2.1 Configurations for Windows
+
+**The following environment variables are required**:
+
+```cmd
+set BIGDL_USE_NPU=1
+```
+
+### 3. Running examples
+
+```cmd
+python ./generate.py
+```
+
+Arguments info:
+- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Phi-3-vision model (e.g. `microsoft/Phi-3-vision-128k-instruct`) to be downloaded, or the path to the huggingface checkpoint folder.
It defaults to `'microsoft/Phi-3-vision-128k-instruct'`; for more verified models, please see the list in [Verified Models](#verified-models).
+- `--image-url-or-path IMAGE_URL_OR_PATH`: argument defining the image to be inferred. It defaults to `'http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg'`.
+- `--prompt PROMPT`: argument defining the prompt to be inferred (with the integrated prompt format for chat). It defaults to `'What is in the image?'`.
+- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.
+- `--load_in_low_bit`: argument defining the `load_in_low_bit` format used. It defaults to `sym_int4`; `sym_int8` can also be used.
+
+#### Sample Output
+#### [microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct)
+
+```log
+Inference time: xxxx s
+-------------------- Prompt --------------------
+Message: [{'role': 'user', 'content': '<|image_1|>\nWhat is in the image?'}]
+Image link/path: http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
+-------------------- Output --------------------
+
+
+What is in the image?
+ The image shows a young girl holding a white teddy bear. She is wearing a pink dress with a heart on it. The background includes a stone
+```
+
+The sample input image (fetched from the [COCO dataset](https://cocodataset.org/#explore?id=264959)) is:
+
+

diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal/generate.py b/python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal/generate.py
new file mode 100644
index 00000000000..230ee1a0e89
--- /dev/null
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal/generate.py
@@ -0,0 +1,93 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +import os +import time +import torch +import argparse +import requests + +from PIL import Image +from ipex_llm.transformers.npu_model import AutoModelForCausalLM +from transformers import AutoProcessor + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for phi-3 model') + parser.add_argument('--repo-id-or-model-path', type=str, default="microsoft/Phi-3-vision-128k-instruct", + help='The huggingface repo id for the phi-3-vision model to be downloaded' + ', or the path to the huggingface checkpoint folder') + parser.add_argument('--image-url-or-path', type=str, + default="http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg", + help='The URL or path to the image to infer') + parser.add_argument('--prompt', type=str, default="What is in the image?", + help='Prompt to infer') + parser.add_argument('--n-predict', type=int, default=32, + help='Max tokens to predict') + parser.add_argument('--load_in_low_bit', type=str, default="sym_int4", + help='Load in low bit to use') + + + args = parser.parse_args() + model_path = args.repo_id_or_model_path + image_path = args.image_url_or_path + + # Load model in SYM_INT4, + # which convert the relevant layers in the model into SYM_INT4 format + # You could also try `'sym_int8'` for INT8 + # `_attn_implementation="eager"` is required for phi-3-vision + # `modules_to_not_convert=["vision_embed_tokens"]` and `model = model.half()` are for acceleration and are optional + model = AutoModelForCausalLM.from_pretrained(model_path, + trust_remote_code=True, + load_in_low_bit=args.load_in_low_bit, + _attn_implementation="eager", + modules_to_not_convert=["vision_embed_tokens"]) + + # Load processor + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + + # here the message formatting refers to https://huggingface.co/microsoft/Phi-3-vision-128k-instruct#sample-inference-code + messages = [ + {"role": "user", "content": "<|image_1|>\n{prompt}".format(prompt=args.prompt)}, + ] + prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + if os.path.exists(image_path): + image = Image.open(image_path) + else: + image = Image.open(requests.get(image_path, stream=True).raw) + + # Generate predicted tokens + with torch.inference_mode(): + # start inference + st = time.time() + + inputs = processor(prompt, [image], return_tensors="pt") + output = model.generate(**inputs, + eos_token_id=processor.tokenizer.eos_token_id, + num_beams=1, + do_sample=False, + max_new_tokens=args.n_predict, + temperature=0.0) + end = time.time() + print(f'Inference time: {end-st} s') + output_str = processor.decode(output[0], + skip_special_tokens=True, + clean_up_tokenization_spaces=False) + print('-'*20, 'Prompt', '-'*20) + print(f'Message: {messages}') + print(f'Image link/path: {image_path}') + print('-'*20, 'Output', '-'*20) + print(output_str) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py index 03a1f18d7be..6d3c95ee0bf 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py @@ -177,3 +177,15 @@ def optimize_llm(model: torch.nn.Module): model.apply(merge_mlp) convert_forward(model, module.MLP, baichuan_mlp_forward) + + elif model.config.model_type == "phi3_v": + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + from 
ipex_llm.transformers.npu_models.phi3_v import merge_qkv + from ipex_llm.transformers.npu_models.phi3_v import phi3v_encoder_attention_forward + from ipex_llm.transformers.npu_models.phi3_v import phi3v_model_forward + model.apply(merge_qkv) + + from transformers.models.clip.modeling_clip import CLIPAttention + convert_forward(model, CLIPAttention, phi3v_encoder_attention_forward) + convert_forward(model, module.Phi3VModel, phi3v_model_forward) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/phi3_v.py b/python/llm/src/ipex_llm/transformers/npu_models/phi3_v.py new file mode 100644 index 00000000000..8fd8a38620d --- /dev/null +++ b/python/llm/src/ipex_llm/transformers/npu_models/phi3_v.py @@ -0,0 +1,190 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Some parts of this file is adapted from +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py +# which is licensed under Apache License 2.0: +# +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import torch +import importlib +from torch import nn +from typing import Optional, Tuple, List +from transformers.models.clip.modeling_clip import CLIPAttention +from ipex_llm.utils.common.log4Error import invalidInputError + + +def merge_qkv(module: torch.nn.Module): + if isinstance(module, CLIPAttention): + new_weight = torch.cat([ + module.q_proj.weight.data, + module.k_proj.weight.data, + module.v_proj.weight.data, + ], dim=0) + + if module.q_proj.bias is not None: + qkv_proj = torch.nn.Linear(0, 0, bias=True) + new_bias = torch.cat([ + module.q_proj.bias.data, + module.k_proj.bias.data, + module.v_proj.bias.data, + ], dim=0) + qkv_proj.bias = torch.nn.Parameter(new_bias, requires_grad=False) + else: + qkv_proj = torch.nn.Linear(0, 0, bias=False) + qkv_proj.weight = torch.nn.Parameter(new_weight, requires_grad=False) + qkv_proj.in_features = new_weight.size(1) + qkv_proj.out_features = new_weight.size(0) + module.qkv_proj = qkv_proj + + del module.q_proj, module.k_proj, module.v_proj + + +def phi3v_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +): + # ipex-llm changes start + from ipex_llm.transformers.kv import DynamicNormalCache + # IPEX-LLM OPT: kv cache and quantize kv cache + use_cache = use_cache if use_cache is not None else self.config.use_cache + if use_cache: + if not isinstance(past_key_values, DynamicNormalCache): + past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values) + modeling_module_name = self.__class__.__module__ + module = importlib.import_module(modeling_module_name) + return module.Phi3VModel.forward( + self=self, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_sizes=image_sizes, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +def phi3v_encoder_attention_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, tgt_len, embed_dim = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + qkv = qkv.view(bsz, tgt_len, self.num_heads * 3, self.head_dim) + qkv = qkv.transpose(1, 2) + query_states, key_states, value_states = qkv.split([self.num_heads, + self.num_heads, + self.num_heads], dim=1) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = query_states.reshape(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + invalidInputError( + False, + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}," + f" but is 
{attn_weights.size()}" + ) + + # apply the causal_attention_mask first + if causal_attention_mask is not None: + if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): + invalidInputError( + False, + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {causal_attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) \ + + causal_attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + invalidInputError( + False, + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}," + f" but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + invalidInputError( + False, + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}," + f" but is {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped
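
The `merge_qkv` pass in `phi3_v.py` above works because concatenating the q/k/v projection weights (and biases) along the output-feature dimension yields a single fused linear layer that is numerically equivalent to running the three projections separately; the fused output only needs to be split back into three chunks. Below is a minimal standalone sketch of that equivalence. It is not part of the patch, and the dimensions and tensor shapes are arbitrary values chosen only for illustration:

```python
# Standalone illustration (not part of this patch): fusing separate q/k/v
# projections into one Linear layer, mirroring what merge_qkv does above.
# embed_dim, batch and sequence sizes below are arbitrary example values.
import torch

embed_dim = 64
q_proj = torch.nn.Linear(embed_dim, embed_dim)
k_proj = torch.nn.Linear(embed_dim, embed_dim)
v_proj = torch.nn.Linear(embed_dim, embed_dim)

# Build the fused projection the same way merge_qkv does: concatenate the
# three weight matrices (and biases) along the output-feature dimension.
qkv_proj = torch.nn.Linear(embed_dim, 3 * embed_dim)
with torch.no_grad():
    qkv_proj.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))
    qkv_proj.bias.copy_(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0))

x = torch.randn(2, 10, embed_dim)  # (batch, seq_len, embed_dim)
q, k, v = qkv_proj(x).split(embed_dim, dim=-1)

# One fused matmul reproduces the three separate projections.
assert torch.allclose(q, q_proj(x), atol=1e-6)
assert torch.allclose(k, k_proj(x), atol=1e-6)
assert torch.allclose(v, v_proj(x), atol=1e-6)
```

In the patched `phi3v_encoder_attention_forward`, the same split is performed per attention head instead of along the last dimension: the fused output is reshaped to `(bsz, tgt_len, num_heads * 3, head_dim)` and then split into three `num_heads`-sized groups.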