Commit
* temp save
* meet review, update
* update
* meet review, add license
* typo
Showing 4 changed files with 339 additions and 0 deletions.

python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/CMakeLists.txt (39 additions, 0 deletions)
cmake_minimum_required(VERSION 3.10)

project(LLM_NPU_EXAMPLE VERSION 1.0.0 LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED True)

if(DEFINED ENV{CONDA_ENV_DIR})
    set(ENV_DIR $ENV{CONDA_ENV_DIR})
    set(LIBRARY_DIR ${ENV_DIR}/bigdl-core-npu)
    include_directories(${LIBRARY_DIR}/include)
    set(DLL_DIR ${ENV_DIR}/intel_npu_acceleration_library/lib/Release)
else()
    set(LIBRARY_DIR ${CMAKE_CURRENT_SOURCE_DIR})
    include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
endif()

add_library(npu_llm STATIC IMPORTED)
set_target_properties(npu_llm PROPERTIES IMPORTED_LOCATION ${LIBRARY_DIR}/npu_llm.lib)

set(TARGET llm-npu-cli)
add_executable(${TARGET} llm-npu-cli.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE npu_llm)
target_compile_features(${TARGET} PRIVATE cxx_std_17)

add_custom_command(TARGET llm-npu-cli POST_BUILD
    COMMAND ${CMAKE_COMMAND} -E copy_if_different
        ${LIBRARY_DIR}/npu_llm.dll
        ${CMAKE_BINARY_DIR}/Release/
    COMMENT "Copying npu_llm.dll to build/Release"
)

add_custom_command(TARGET llm-npu-cli POST_BUILD
    COMMAND ${CMAKE_COMMAND} -E copy_directory
        ${DLL_DIR}/
        ${CMAKE_BINARY_DIR}/Release/
    COMMENT "Copying dependencies to build/Release"
)
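
For orientation, the `CONDA_ENV_DIR` branch above is what lets the build pick up the prebuilt `npu_llm` library from an `ipex-llm[npu]` conda environment. Below is a minimal sketch of driving this CMakeLists.txt from a Windows Command Prompt, assuming such an environment exists (the path is a placeholder; the same flow is spelled out in the README added in this commit):

```cmd
:: run from the directory containing this CMakeLists.txt
:: CONDA_ENV_DIR should point at the environment's site-packages folder
set CONDA_ENV_DIR=C:\path\to\miniforge3\envs\llm\Lib\site-packages
mkdir build
cd build
cmake ..
cmake --build . --config Release -j
```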

python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/README.md (92 additions, 0 deletions)
# C++ Example of running LLM on Intel NPU using IPEX-LLM (Experimental)
In this directory, you will find a C++ example of how to run LLM models on Intel NPUs using IPEX-LLM (leveraging the *Intel NPU Acceleration Library*). See the table below for verified models.

## Verified Models

| Model | Model Link |
|------------|----------------------------------------------------------------|
| Qwen2 | [Qwen/Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct), [Qwen/Qwen2-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2-1.5B-Instruct) |
| Qwen2.5 | [Qwen/Qwen2.5-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) |

## 0. Requirements
To run this C++ example with IPEX-LLM on Intel NPUs, make sure to install the latest Intel NPU driver.
Go to https://www.intel.com/content/www/us/en/download/794734/intel-npu-driver-windows.html to download and unzip the driver.
Then go to **Device Manager** and find **Neural Processors** -> **Intel(R) AI Boost**.
Right-click and select **Update Driver** -> **Browse my computer for drivers**, then manually select the unzipped driver folder to install.

## 1. Install
### 1.1 Installation on Windows
We suggest using conda to manage the environment:
```cmd
conda create -n llm python=3.10
conda activate llm
:: install ipex-llm with 'npu' option
pip install --pre --upgrade ipex-llm[npu]
:: [optional] for Llama-3.2-1B-Instruct & Llama-3.2-3B-Instruct
pip install transformers==4.45.0 accelerate==0.33.0
```

## 2. Convert Model
We provide a [convert script](convert.py) in the current directory. Running it produces the weights and configuration files required to run the C++ example.

```cmd
:: to convert Qwen2.5-7B-Instruct
python convert.py --repo-id-or-model-path Qwen/Qwen2.5-7B-Instruct --save-directory <converted_model_path>
```

Arguments info:
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the Hugging Face repo id of the model to be downloaded (e.g. `Qwen/Qwen2.5-7B-Instruct`), or the path to the Hugging Face checkpoint folder.
- `--save-directory SAVE_DIRECTORY`: argument defining the path to save the converted model. If it is a non-existing path, the original pretrained model specified by `REPO_ID_OR_MODEL_PATH` will be loaded, and the converted model will be saved into `SAVE_DIRECTORY`.
- `--max-context-len MAX_CONTEXT_LEN`: argument defining the maximum sequence length for both input and output tokens. Defaults to `1024`.
- `--max-prompt-len MAX_PROMPT_LEN`: argument defining the maximum number of tokens that the input prompt can contain. Defaults to `960`.
- `--disable-transpose-value-cache`: argument to disable the optimization of transposing the value cache. A combined example using these optional arguments is shown below.
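
As an additional illustration, here is a sketch of a conversion run that also exercises the optional arguments; it uses the smaller verified `Qwen/Qwen2-1.5B-Instruct`, and `<converted_model_path>` is a placeholder for your own output directory:

```cmd
:: convert Qwen2-1.5B-Instruct with explicit context/prompt limits
:: and the transposed value cache optimization disabled
python convert.py --repo-id-or-model-path Qwen/Qwen2-1.5B-Instruct --save-directory <converted_model_path> --max-context-len 1024 --max-prompt-len 512 --disable-transpose-value-cache
```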

## 3. Build C++ Example `llm-npu-cli`

You can run the CMake commands below in a Command Prompt to build `llm-npu-cli`. Don't forget to replace the conda environment directory with your own path.

```cmd
:: under the current directory
:: please replace the conda environment directory below with your own path
set CONDA_ENV_DIR=C:\Users\arda\miniforge3\envs\llm\Lib\site-packages
mkdir build
cd build
cmake ..
cmake --build . --config Release -j
cd Release
```

## 4. Run `llm-npu-cli`

With the built `llm-npu-cli`, you can run the example with the specified parameters. For example:

```cmd
llm-npu-cli.exe -m <converted_model_path> -n 64 "AI是什么?"
```

Arguments info:
- `-m` : argument defining the path of the saved converted model.
- `-n` : argument defining how many tokens will be generated.
- The last argument is your input prompt; another example invocation is shown below.
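
For instance, a longer generation with an English prompt (assuming the same converted model directory; the prompt and token budget here are arbitrary) could look like:

```cmd
:: generate up to 128 tokens for an English prompt
llm-npu-cli.exe -m <converted_model_path> -n 128 "What is AI?"
```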

## 5. Sample Output
### [`Qwen/Qwen2.5-7B-Instruct`](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct)
```cmd
Input:
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
AI是什么?<|im_end|>
<|im_start|>assistant
Prefill 22 tokens cost xxxx ms.
Output:
AI是"人工智能"的缩写,是英文"Artificial Intelligence"的翻译。它是研究如何使计算机也具有智能的一种技术和理论。简而言之,人工智能就是让计算机能够模仿人智能行为的一项技术。
Decode 46 tokens cost xxxx ms (avg xx.xx ms each token).
```

python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/convert.py (74 additions, 0 deletions)
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
import argparse
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers.utils import logging

logger = logging.get_logger(__name__)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert LLM for C++ NPU inference and save"
    )
    parser.add_argument(
        "--repo-id-or-model-path",
        type=str,
        default="Qwen/Qwen2.5-7B-Instruct",  # Or Qwen2-7B-Instruct, Qwen2-1.5B-Instruct
        help="The huggingface repo id for the Qwen model to be downloaded"
        ", or the path to the huggingface checkpoint folder",
    )
    parser.add_argument("--save-directory", type=str,
                        required=True,
                        help="The path of the folder to save the converted model. "
                             "If the path does not exist, the low-bit model will be saved there; "
                             "otherwise, the program will raise an error.",
                        )
    parser.add_argument("--max-context-len", type=int, default=1024)
    parser.add_argument("--max-prompt-len", type=int, default=960)
    parser.add_argument("--quantization_group_size", type=int, default=0)
    parser.add_argument('--load_in_low_bit', type=str, default="sym_int4",
                        help='Low-bit format to load the model in')
    parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
    save_dir = args.save_directory

    # Load the model and convert/compile it for the NPU; the converted
    # artifacts are written to save_dir via save_directory=...
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 optimize_model=True,
                                                 pipeline=True,
                                                 load_in_low_bit=args.load_in_low_bit,
                                                 max_context_len=args.max_context_len,
                                                 max_prompt_len=args.max_prompt_len,
                                                 quantization_group_size=args.quantization_group_size,
                                                 torch_dtype=torch.float16,
                                                 attn_implementation="eager",
                                                 transpose_value_cache=not args.disable_transpose_value_cache,
                                                 mixed_precision=True,
                                                 trust_remote_code=True,
                                                 compile_full_model=True,
                                                 save_directory=save_dir)

    # Save the tokenizer alongside the converted model
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    tokenizer.save_pretrained(save_dir)

    print("-" * 80)
    print(f"Finished saving model to {save_dir}")
    print("Success, shutting down")
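
For reference, a minimal sketch of invoking this script directly with its own defaults made explicit; `--load_in_low_bit sym_int4` and `--quantization_group_size 0` are simply the default values declared above, and the model id and output directory are placeholders:

```cmd
:: convert with the script's default low-bit format and group size spelled out
python convert.py --repo-id-or-model-path Qwen/Qwen2-1.5B-Instruct --save-directory <converted_model_path> --load_in_low_bit sym_int4 --quantization_group_size 0
```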

python/llm/example/NPU/HF-Transformers-AutoModels/LLM/CPP_Examples/llm-npu-cli.cpp (134 additions, 0 deletions)
//
// Copyright 2016 The BigDL Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

#include <iostream>
#include <fstream>
#include <string>
#include <chrono>
#include <cstdio>
#include <cstring>
#include <vector>

#include "common.h"
#include "npu_llm.h"


static void print_usage(int, char ** argv) {
    printf("\nexample usage:\n");
    printf("\n    %s -m npu_model_dir [-n n_predict] [prompt]\n", argv[0]);
    printf("\n");
}


int main(int argc, char ** argv) {
    common_params params;

    // path to the npu model directory
    std::string model_dir;
    // prompt to generate text from
    std::string prompt = "AI是什么?";
    // number of tokens to predict
    int n_predict = 32;

    // parse command line arguments
    {
        int i = 1;
        for (; i < argc; i++) {
            if (strcmp(argv[i], "-m") == 0) {
                if (i + 1 < argc) {
                    model_dir = argv[++i];
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else if (strcmp(argv[i], "-n") == 0) {
                if (i + 1 < argc) {
                    try {
                        n_predict = std::stoi(argv[++i]);
                    } catch (...) {
                        print_usage(argc, argv);
                        return 1;
                    }
                } else {
                    print_usage(argc, argv);
                    return 1;
                }
            } else {
                // prompt starts here
                break;
            }
        }
        if (model_dir.empty()) {
            print_usage(argc, argv);
            return 1;
        }
        if (i < argc) {
            prompt = argv[i++];
            for (; i < argc; i++) {
                prompt += " ";
                prompt += argv[i];
            }
        }
    }

    params.n_predict = n_predict;
    params.model = model_dir;
    params.prompt = prompt;

    // load the converted NPU model and its tokenizer from the model directory
    npu_model_params model_params;
    NPUModel* model = load_model_from_file(model_params, params.model);

    tokenizer_params tok_params;
    load_tokenizer(tok_params, params.model);

    // wrap the user prompt in the model's chat template
    std::string full_prompt = add_chat_template(model_params, params.prompt);
    std::cout << "Input: " << std::endl;
    std::cout << full_prompt << std::endl;

    // tokenize input
    std::vector<int32_t> embd_inp = llm_tokenize(full_prompt, false);

    // prefill the prompt and sample the first generated token
    std::vector<int32_t> embd;  // output ids
    auto start = std::chrono::high_resolution_clock::now();
    float* logits = run_prefill(model, embd_inp);
    int32_t token = llm_sample_token(logits, true, model_params);
    auto end = std::chrono::high_resolution_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
    printf("\nPrefill %zu tokens cost %lld ms.\n", embd_inp.size(), (long long)duration.count());
    embd.push_back(token);

    // autoregressive decode loop: feed back the last token until EOS or n_predict
    int token_nums = 0;
    start = std::chrono::high_resolution_clock::now();
    for (int i = 1; i < params.n_predict; i++){
        auto logits = run_decode(model, embd[i-1]);
        int32_t token = llm_sample_token(logits, true, model_params);
        if (token != tok_params.eos_token_id) {
            embd.push_back(token);
            token_nums ++;
        } else {
            break;
        }
    }
    end = std::chrono::high_resolution_clock::now();
    duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);

    // detokenize the generated ids back to text
    std::string output = llm_decode(embd);

    std::cout << "Output: " << std::endl;
    std::cout << output << std::endl;

    printf("\nDecode %d tokens cost %lld ms (avg %.2f ms each token).\n", token_nums,
           (long long)duration.count(),
           token_nums > 0 ? (double)duration.count() / token_nums : 0.0);

    return 0;
}