FIX: Qwen1.5-GPTQ-Int4 inference error #11432

Merged · 4 commits · Jun 26, 2024
@@ -18,7 +18,7 @@ conda activate llm
pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu
pip install transformers==4.34.0
BUILD_CUDA_EXT=0 pip install git+https://github.com/PanQiWei/AutoGPTQ.git@1de9ab6
-pip install optimum==0.14.0
+pip install optimum==1.14.0
```

On Windows:
@@ -30,7 +30,7 @@ pip install --pre --upgrade ipex-llm[all]
pip install transformers==4.34.0
set BUILD_CUDA_EXT=0
pip install git+https://github.com/PanQiWei/AutoGPTQ.git@1de9ab6
-pip install optimum==0.14.0
+pip install optimum==1.14.0
```

### 2. Run
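As a quick sanity check after following either set of install steps (this snippet is editorial, not part of the diffed README; the package names are the PyPI distributions pinned above):

```python
# Verify the pinned versions actually installed; optimum is expected to be
# 1.14.0 after this PR's README change (the old pin was 0.14.0).
from importlib.metadata import version

for pkg, expected in [("transformers", "4.34.0"), ("optimum", "1.14.0")]:
    installed = version(pkg)
    note = "" if installed == expected else f" (expected {expected})"
    print(f"{pkg}: {installed}{note}")
```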
@@ -19,7 +19,7 @@
import argparse

from ipex_llm.transformers import AutoModelForCausalLM
-from transformers import LlamaTokenizer, GPTQConfig
+from transformers import LlamaTokenizer, AutoTokenizer

# you could tune the prompt based on your own model,
# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
@@ -50,7 +50,10 @@
trust_remote_code=True,)

# Load tokenizer
-tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
+if "qwen" in model_path.lower():
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+else:
+    tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Generate predicted tokens
with torch.inference_mode():
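For context, a condensed sketch of how the updated example is meant to be driven end to end. The model id, prompt, and the `load_in_4bit`/`torch_dtype` arguments are illustrative assumptions rather than values taken verbatim from the example; only the imports and the tokenizer selection mirror the diff above.

```python
import torch
from ipex_llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer, LlamaTokenizer

model_path = "Qwen/Qwen1.5-7B-Chat-GPTQ-Int4"  # assumed model id, for illustration only

# Load the GPTQ checkpoint through ipex-llm; the keyword arguments here are assumptions.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             torch_dtype=torch.float,
                                             trust_remote_code=True)

# Tokenizer selection matches the diff: Qwen checkpoints go through AutoTokenizer.
if "qwen" in model_path.lower():
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
else:
    tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)

prompt = "What is AI?"
input_ids = tokenizer.encode(prompt, return_tensors="pt")
with torch.inference_mode():
    output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```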
python/llm/src/ipex_llm/transformers/convert.py (8 additions, 1 deletion)
@@ -734,7 +734,14 @@ def _optimize_pre(model):
    # for qwen2
    if model.config.model_type == "qwen2":
        from ipex_llm.transformers.models.qwen2 import merge_qkv
-        model.apply(merge_qkv)
+        # Skip merge_qkv if quant_method is 'gptq'
+        should_apply_merge_qkv = (
+            not hasattr(model.config, "quantization_config") or
+            not hasattr(model.config.quantization_config, "quant_method") or
+            model.config.quantization_config.quant_method != "gptq"
+        )
+        if should_apply_merge_qkv:
+            model.apply(merge_qkv)
Contributor:
Is gptq loaded as q4_1? If so, why can't we use merge_qkv? @qiuxin2012

Contributor:
We need a new merge_qkv for gptq here. Just convert q k v to one LowBitLinear.

Contributor (@jason-dai, Jun 26, 2024):

> We need a new merge_qkv for gptq here. Just convert q k v to one LowBitLinear.

It's different from q4_1?

Contributor (author):
gptq is loaded as q4_1. However, gptq's weights have already been quantized. If we want to use merge_qkv, we need to dequantize from the gptq format to the normal format first, perform merge_qkv, and then quantize back into the LowBitLinear.

I will try this method.

Contributor:

> gptq is loaded as q4_1. However, gptq's weights have already been quantized. If we want to use merge_qkv, we need to dequantize from the gptq format to the normal format first, perform merge_qkv, and then quantize back into the LowBitLinear.
> I will try this method.

Instead of dequantization, I think we can just rearrange the quantized qkv tensors into a combined one? Anyway I think we may fix it later in a separate PR if needed.
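For illustration, a rough sketch of what that rearrangement could look like. It assumes the q/k/v projections are AutoGPTQ-style `QuantLinear` modules whose packed buffers are named `qweight`, `qzeros`, `scales`, `g_idx`, and `bias`, and it only produces merged buffers; it is not part of this PR.

```python
import torch

def concat_gptq_qkv(q_proj, k_proj, v_proj):
    """Rearrange three packed GPTQ linears into one combined projection.

    The packed buffers of an AutoGPTQ-style QuantLinear keep the output
    dimension as their last axis, so q/k/v can be concatenated along that
    axis without dequantizing, provided the input-feature grouping matches.
    """
    # With desc_act=True each layer may carry its own activation-order
    # permutation of the input features, so merging is only safe when
    # q/k/v share the same grouping.
    assert torch.equal(q_proj.g_idx, k_proj.g_idx)
    assert torch.equal(q_proj.g_idx, v_proj.g_idx)

    merged = {
        "qweight": torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1),
        "qzeros": torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1),
        "scales": torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1),
        "g_idx": q_proj.g_idx,
    }
    if getattr(q_proj, "bias", None) is not None:
        merged["bias"] = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0)
    return merged
```

Turning these merged buffers into an actual combined `LowBitLinear` is left out, since that depends on ipex-llm internals not shown in this diff.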

Contributor (author):
I'll merge this PR first so users can start using it, and then I'll submit another PR to address further fixes.

Contributor:
I guess we probably cannot merge qkv if quantization_config.desc_act==True.

        from ipex_llm.transformers.models.qwen2 import padding_mlp
        model.apply(padding_mlp)
    if model.config.model_type == "qwen2_moe":
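The net effect of the guard above is that GPTQ checkpoints keep their separate q/k/v projections while other Qwen2 models get a fused `qkv_proj`. A small sketch of how one could confirm which path a loaded model took; the attribute names `qkv_proj` and `q_proj` follow this PR's diff, everything else is illustrative.

```python
def describe_attention_layout(model):
    """Report whether attention modules carry a fused qkv_proj
    (merge_qkv applied) or separate q/k/v projections (GPTQ path)."""
    for name, module in model.named_modules():
        has_fused = getattr(module, "qkv_proj", None) is not None
        has_split = getattr(module, "q_proj", None) is not None
        if has_fused or has_split:
            layout = "fused qkv_proj" if has_fused else "separate q_proj/k_proj/v_proj"
            print(f"{name}: {layout}")
```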
python/llm/src/ipex_llm/transformers/models/qwen2.py (17 additions, 6 deletions)
@@ -405,12 +405,23 @@ def qwen2_attention_forward(
    bsz, q_len, _ = hidden_states.size()
    device = hidden_states.device

-    qkv = self.qkv_proj(hidden_states)
-    qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
-    qkv = qkv.transpose(1, 2)
-    query_states, key_states, value_states = qkv.split([self.num_heads,
-                                                        self.num_key_value_heads,
-                                                        self.num_key_value_heads], dim=1)
+    if hasattr(self, 'qkv_proj') and self.qkv_proj is not None:
+        qkv = self.qkv_proj(hidden_states)
+        qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+        qkv = qkv.transpose(1, 2)
+        query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                            self.num_key_value_heads,
+                                                            self.num_key_value_heads], dim=1)
+    else:
+        # when quant_method is 'gptq'
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) \
+            .transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) \
+            .transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
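Both branches are meant to hand identically shaped tensors to the rest of the attention code. A self-contained shape check with plain `nn.Linear` stand-ins; the dimensions are made up for illustration and nothing here touches ipex-llm.

```python
import torch
import torch.nn as nn

bsz, q_len, hidden = 2, 5, 64
num_heads, num_kv_heads, head_dim = 8, 4, 8  # hidden == num_heads * head_dim

x = torch.randn(bsz, q_len, hidden)

# Fused path: one projection, then split along the head axis.
qkv_proj = nn.Linear(hidden, (num_heads + 2 * num_kv_heads) * head_dim)
qkv = qkv_proj(x).view(bsz, q_len, num_heads + 2 * num_kv_heads, head_dim).transpose(1, 2)
q1, k1, v1 = qkv.split([num_heads, num_kv_heads, num_kv_heads], dim=1)

# Separate path (what the GPTQ fallback does): three projections, reshaped individually.
q_proj = nn.Linear(hidden, num_heads * head_dim)
k_proj = nn.Linear(hidden, num_kv_heads * head_dim)
v_proj = nn.Linear(hidden, num_kv_heads * head_dim)
q2 = q_proj(x).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
k2 = k_proj(x).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
v2 = v_proj(x).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)

assert q1.shape == q2.shape == (bsz, num_heads, q_len, head_dim)
assert k1.shape == k2.shape == v1.shape == v2.shape == (bsz, num_kv_heads, q_len, head_dim)
```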