LLM: add GGUF-IQ2 examples #10207

Merged 6 commits on Feb 22, 2024
@@ -0,0 +1,81 @@
# GGUF-IQ2

This example shows how to run INT2 models using the IQ2 quantization mechanism (first implemented by llama.cpp) in BigDL-LLM on Intel GPUs.

## Verified Models

- [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf), using [llama-v2-7b.imatrix](https://huggingface.co/datasets/ikawrakow/imatrix-from-wiki-train/resolve/main/llama-v2-7b.imatrix)
- [Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf), using [llama-v2-7b.imatrix](https://huggingface.co/datasets/ikawrakow/imatrix-from-wiki-train/resolve/main/llama-v2-7b.imatrix)
- [Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2), using [mistral-7b-instruct-v0.2.imatrix](https://huggingface.co/datasets/ikawrakow/imatrix-from-wiki-train/resolve/main/mistral-7b-instruct-v0.2.imatrix)
- [Mixtral-8x7B-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1), using [mixtral-8x7b.imatrix](https://huggingface.co/datasets/ikawrakow/imatrix-from-wiki-train/resolve/main/mixtral-8x7b.imatrix)
- [Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1), using [mixtral-8x7b-instruct-v0.1.imatrix](https://huggingface.co/datasets/ikawrakow/imatrix-from-wiki-train/resolve/main/mixtral-8x7b-instruct-v0.1.imatrix)
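
Each model above is paired with an imatrix file from the [imatrix-from-wiki-train](https://huggingface.co/datasets/ikawrakow/imatrix-from-wiki-train) dataset. As a hedged sketch (assuming the `huggingface_hub` package is available; it is not listed as a requirement of this example), the file can also be fetched programmatically:

```python
# Download an imatrix file from the dataset repo linked above
# (the filename is a placeholder: pick the one matching your model).
from huggingface_hub import hf_hub_download

imatrix_path = hf_hub_download(repo_id="ikawrakow/imatrix-from-wiki-train",
                               filename="llama-v2-7b.imatrix",
                               repo_type="dataset")
print(imatrix_path)  # pass this path as the `imatrix` argument when loading the model
```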

## Requirements

To run these examples with BigDL-LLM, we have some recommended requirements for your machine; please refer to [here](../../../README.md#requirements) for more information.

## Example: Predict Tokens using `generate()` API

In the example [generate.py](./generate.py), we show a basic use case of a GGUF-IQ2 model predicting the next N tokens using the `generate()` API, with BigDL-LLM optimizations.

### 1. Install

We suggest using conda to manage the environment:

```bash
conda create -n llm python=3.9
conda activate llm
# the command below will install intel_extension_for_pytorch==2.1.10+xpu by default
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
pip install transformers==4.35.0
```
**Note: For Mixtral models, please use transformers 4.36.0:**
```bash
pip install transformers==4.36.0
```
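
As a quick sanity check (just a sketch, not part of the original example), you can confirm the installed packages are importable and the `transformers` version matches:

```python
# Sanity check: the bigdl-llm package and the pinned transformers version are importable.
import bigdl.llm          # installed via bigdl-llm[xpu]
import transformers

print(transformers.__version__)  # expect 4.35.0, or 4.36.0 for Mixtral models
```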

### 2. Configure OneAPI environment variables

```bash
source /opt/intel/oneapi/setvars.sh
```
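
After sourcing the oneAPI variables, you can optionally verify from Python that the Intel GPU is visible (a hedged sketch, assuming the `bigdl-llm[xpu]` install above pulled in `intel_extension_for_pytorch` with XPU support):

```python
# Optional check that the XPU backend is available before running the example.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  (registers the torch.xpu backend)

print(torch.xpu.is_available())   # expected: True on a working setup
print(torch.xpu.device_count())   # number of Intel GPUs detected
```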

### 3. Run

For optimal performance on Arc, it is recommended to set several environment variables.

```bash
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
```
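
If you prefer to set these from within Python rather than the shell, an equivalent sketch (under the assumption that they are set before any XPU work starts) is:

```python
# Equivalent to the exports above; must run before the first XPU operation.
import os

os.environ["USE_XETLA"] = "OFF"
os.environ["SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS"] = "1"
```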

```bash
python ./generate.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
```

Arguments info:

- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the model (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded, or the path to the huggingface checkpoint folder. It defaults to `'meta-llama/Llama-2-7b-chat-hf'`.
- `--prompt PROMPT`: argument defining the prompt to be inferred (with the integrated prompt format for chat). It defaults to `'What is AI?'`.
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.
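
Under the hood, the script boils down to a few API calls; here is a condensed sketch (the model path and imatrix filename are placeholders, see [generate.py](./generate.py) for the full script):

```python
# Condensed from generate.py: load a GGUF-IQ2 model and generate on an Intel GPU.
import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

MODEL_PATH = "meta-llama/Llama-2-7b-chat-hf"   # placeholder: any verified model above

model = AutoModelForCausalLM.from_pretrained(MODEL_PATH,
                                             load_in_low_bit='gguf_iq2_xxs',
                                             trust_remote_code=True,
                                             imatrix='llama-v2-7b.imatrix').to("xpu")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer.encode("What is AI?", return_tensors="pt").to("xpu")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```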

#### 3.1 Sample Output

#### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)

```log
Inference time: xxxx s
-------------------- Prompt --------------------
### HUMAN:
What is AI?

### RESPONSE:

-------------------- Output --------------------
### HUMAN:
What is AI?

### RESPONSE:

Artificial intelligence (AI) refers to the ability of machines to perform tasks that would typically require human intelligence, such as learning, problem-solving
```
@@ -0,0 +1,83 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
import time
import argparse
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
import warnings

# You could tune the prompt based on your own model;
# the prompt style here follows https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
PROMPT_FORMAT = """### HUMAN:
{prompt}

### RESPONSE:
"""

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for LLM model')
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help='The huggingface repo id'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--prompt', type=str, default="What is AI?",
                        help='Prompt to infer')
    parser.add_argument('--n-predict', type=int, default=32,
                        help='Max tokens to predict')

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path

    warnings.warn("IQ2 quantization may take several minutes, please wait a moment, "
                  "or have a cup of coffee now : )")

    # Load the model in 2 bit,
    # which converts the relevant layers in the model into gguf_iq2_xxs format.
    # GGUF-IQ2 quantization needs an imatrix file to assist in quantization
    # and improve generation quality, and different models may need different
    # imatrix files; you can find and download imatrix files from
    # https://huggingface.co/datasets/ikawrakow/imatrix-from-wiki-train/tree/main.
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_low_bit='gguf_iq2_xxs',
                                                 trust_remote_code=True,
                                                 imatrix='llama-v2-7b.imatrix').to("xpu")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Generate predicted tokens
    with torch.inference_mode():
        prompt = PROMPT_FORMAT.format(prompt=args.prompt)
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to("xpu")
        # ipex model needs a warmup run, then inference time can be measured accurately
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict)
        st = time.time()
        # if your selected model is capable of utilizing previous key/value attentions
        # to enhance decoding speed, but has `"use_cache": false` in its model config,
        # it is important to set `use_cache=True` explicitly in the `generate` function
        # to obtain optimal performance with BigDL-LLM Low Bit optimizations
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict,
                                repetition_penalty=1.1)
        end = time.time()
        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
        print(f'Inference time: {end-st} s')
        print('-'*20, 'Prompt', '-'*20)
        print(prompt)
        print('-'*20, 'Output', '-'*20)
        print(output_str)
4 changes: 2 additions & 2 deletions python/llm/src/bigdl/llm/ggml/quantize.py
@@ -40,8 +40,8 @@
"fp8_e5m2": 19, # fp8 in e5m2 format
"fp8": 19, # fp8 in e5m2 format
"bf16": 20,
"iq2_xxs": 21,
"iq2_xs": 22,
"gguf_iq2_xxs": 21,
"gguf_iq2_xs": 22,
"q2_k": 23}

_llama_quantize_type = {"q4_0": 2,
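With this rename, the IQ2 entries in `ggml_tensor_qtype` are now looked up under the `gguf_` prefix. A hypothetical usage sketch (the import path is assumed from this diff):

```python
# Sketch only: look up the renamed IQ2 qtype codes.
from bigdl.llm.ggml.quantize import ggml_tensor_qtype  # import path assumed from this diff

print(ggml_tensor_qtype["gguf_iq2_xxs"])  # 21
print(ggml_tensor_qtype["gguf_iq2_xs"])   # 22
```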
4 changes: 2 additions & 2 deletions python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -70,8 +70,8 @@
MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
FP8E5 = ggml_tensor_qtype["fp8_e5m2"]
-IQ2_XXS = ggml_tensor_qtype["iq2_xxs"]
-IQ2_XS = ggml_tensor_qtype["iq2_xs"]
+IQ2_XXS = ggml_tensor_qtype["gguf_iq2_xxs"]
+IQ2_XS = ggml_tensor_qtype["gguf_iq2_xs"]
Q2_K = ggml_tensor_qtype["q2_k"]


9 changes: 5 additions & 4 deletions python/llm/src/bigdl/llm/transformers/model.py
@@ -110,7 +110,7 @@ def from_pretrained(cls,
:param load_in_low_bit: str value, options are ``'sym_int4'``, ``'asym_int4'``,
``'sym_int5'``, ``'asym_int5'``, ``'sym_int8'``, ``'nf3'``,
``'nf4'``, ``'fp4'``, ``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``,
-``'iq2_xxs'``, ``'iq2_xs'``, ``'fp16'`` or ``'bf16'``,
+``'gguf_iq2_xxs'``, ``'gguf_iq2_xs'``, ``'fp16'`` or ``'bf16'``,
``'sym_int4'`` means symmetric int 4, ``'asym_int4'`` means
asymmetric int 4, ``'nf4'`` means 4-bit NormalFloat, etc.
Relevant low bit optimizations will be applied to the model.
@@ -278,12 +278,13 @@ def from_pretrained(cls,
kwargs["pretraining_tp"] = 1
q_k = load_in_low_bit if load_in_low_bit else "sym_int4"
imatrix_file = kwargs.pop("imatrix", None)
if q_k in ["iq2_xxs", "iq2_xs"]:
if q_k in ["gguf_iq2_xxs", "gguf_iq2_xs"]:
invalidInputError(imatrix_file is not None,
"For iq2_xxs and iq2_xs quantization, imatrix is needed.")
"For gguf_iq2_xxs and gguf_iq2_xs quantization,"
"imatrix is needed.")
cpu_embedding = kwargs.get("cpu_embedding", False)
# for 2bit, default use embedding_quantization
if q_k in ["iq2_xxs", "iq2_xs", "q2_k"] and not cpu_embedding and \
if q_k in ["gguf_iq2_xxs", "gguf_iq2_xs", "q2_k"] and not cpu_embedding and \
embedding_qtype is None:
embedding_qtype = "q2_k"
if imatrix_file is not None:
2 changes: 1 addition & 1 deletion python/llm/src/bigdl/llm/transformers/utils.py
@@ -269,7 +269,7 @@ def module_name_process(full_module_name):

def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_type=None):
cur_qtype = qtype
-if qtype in [ggml_tensor_qtype["iq2_xxs"], ggml_tensor_qtype["iq2_xs"]]:
+if qtype in [ggml_tensor_qtype["gguf_iq2_xxs"], ggml_tensor_qtype["gguf_iq2_xs"]]:
# For quantization which needs importance matrix
new_module_name, layer, cur_module = module_name_process(full_module_name)
# custom mixed quantization strategy