diff --git a/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/README.md b/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/README.md
index 6d58ab8996c..b773065906e 100644
--- a/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/README.md
+++ b/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/README.md
@@ -18,7 +18,7 @@ conda activate llm
 pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu
 pip install transformers==4.34.0
 BUILD_CUDA_EXT=0 pip install git+https://github.com/PanQiWei/AutoGPTQ.git@1de9ab6
-pip install optimum==0.14.0
+pip install optimum==1.14.0
 ```
 
 On Windows:
@@ -30,7 +30,7 @@ pip install --pre --upgrade ipex-llm[all]
 pip install transformers==4.34.0
 set BUILD_CUDA_EXT=0
 pip install git+https://github.com/PanQiWei/AutoGPTQ.git@1de9ab6
-pip install optimum==0.14.0
+pip install optimum==1.14.0
 ```
 
 ### 2. Run
diff --git a/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/generate.py b/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/generate.py
index 2929194b7cd..f6d58bb5e78 100644
--- a/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/generate.py
+++ b/python/llm/example/CPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/generate.py
@@ -19,7 +19,7 @@
 import argparse
 
 from ipex_llm.transformers import AutoModelForCausalLM
-from transformers import LlamaTokenizer, GPTQConfig
+from transformers import LlamaTokenizer, AutoTokenizer
 
 # you could tune the prompt based on your own model,
 # here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
@@ -50,7 +50,10 @@
                                                  trust_remote_code=True,)
 
     # Load tokenizer
-    tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    if "qwen" in model_path.lower():
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    else:
+        tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     # Generate predicted tokens
     with torch.inference_mode():
diff --git a/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/generate.py b/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/generate.py
index c830d9106e0..c45963f59e7 100644
--- a/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/generate.py
+++ b/python/llm/example/GPU/HF-Transformers-AutoModels/Advanced-Quantizations/GPTQ/generate.py
@@ -18,7 +18,7 @@
 import time
 import argparse
 from ipex_llm.transformers import AutoModelForCausalLM
-from transformers import AutoTokenizer, GPTQConfig
+from transformers import AutoTokenizer, LlamaTokenizer
 
 # you could tune the prompt based on your own model,
 # here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
@@ -48,9 +48,11 @@
                                                  torch_dtype=torch.float,
                                                  trust_remote_code=True,).to("xpu")
-    print(model)
 
     # Load tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    if "qwen" in model_path.lower():
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    else:
+        tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     # Generate predicted tokens
     with torch.inference_mode():
diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index dc233942e79..9eafb6f1fad 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -732,10 +732,17 @@ def _optimize_pre(model):
         model.apply(split_mlp)
     # for qwen2
     if model.config.model_type == "qwen2":
-        from ipex_llm.transformers.models.qwen2 import merge_qkv
-        model.apply(merge_qkv)
-        from ipex_llm.transformers.models.qwen2 import padding_mlp
-        model.apply(padding_mlp)
+        # Skip merge_qkv and padding_mlp if quant_method is 'gptq'
+        should_apply_merge_qkv = (
+            not hasattr(model.config, "quantization_config") or
+            not hasattr(model.config.quantization_config, "quant_method") or
+            model.config.quantization_config.quant_method != "gptq"
+        )
+        if should_apply_merge_qkv:
+            from ipex_llm.transformers.models.qwen2 import merge_qkv
+            model.apply(merge_qkv)
+            from ipex_llm.transformers.models.qwen2 import padding_mlp
+            model.apply(padding_mlp)
     if model.config.model_type == "qwen2_moe":
         from ipex_llm.transformers.models.qwen2_moe import merge_qkv
         model.apply(merge_qkv)
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index afc12f4bfda..708c40337e2 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -405,12 +405,23 @@ def qwen2_attention_forward(
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
 
-    qkv = self.qkv_proj(hidden_states)
-    qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
-    qkv = qkv.transpose(1, 2)
-    query_states, key_states, value_states = qkv.split([self.num_heads,
-                                                        self.num_key_value_heads,
-                                                        self.num_key_value_heads], dim=1)
+    if hasattr(self, 'qkv_proj') and self.qkv_proj is not None:
+        qkv = self.qkv_proj(hidden_states)
+        qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+        qkv = qkv.transpose(1, 2)
+        query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                            self.num_key_value_heads,
+                                                            self.num_key_value_heads], dim=1)
+    else:
+        # when quant_method is 'gptq'
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) \
+            .transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) \
+            .transpose(1, 2)
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
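
The `convert.py` guard can be read in isolation. The sketch below restates it as a standalone helper and exercises it with stand-in config objects (the `SimpleNamespace` configs are illustrative only, not ipex-llm or transformers API): `merge_qkv`/`padding_mlp` are applied unless the checkpoint's `quantization_config.quant_method` is `"gptq"`.

```python
# Standalone restatement of the guard added in _optimize_pre (convert.py).
# The SimpleNamespace configs below are stand-ins for illustration only.
from types import SimpleNamespace


def should_apply_merge_qkv(config) -> bool:
    """Mirror of the convert.py check: apply merge_qkv/padding_mlp unless
    the model was quantized with GPTQ."""
    return (
        not hasattr(config, "quantization_config") or
        not hasattr(config.quantization_config, "quant_method") or
        config.quantization_config.quant_method != "gptq"
    )


# A plain (unquantized) Qwen2 config has no quantization_config -> merge applies.
plain = SimpleNamespace(model_type="qwen2")
assert should_apply_merge_qkv(plain)

# A GPTQ checkpoint reports quant_method == "gptq" -> merge is skipped,
# so q_proj/k_proj/v_proj stay as separate GPTQ-quantized linear layers.
gptq = SimpleNamespace(model_type="qwen2",
                       quantization_config=SimpleNamespace(quant_method="gptq"))
assert not should_apply_merge_qkv(gptq)
```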
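
The `qwen2.py` fallback only changes how query/key/value states are produced, not their layout, which is why the rest of `qwen2_attention_forward` (rotary embedding, KV cache, attention) is untouched. A rough shape check, using made-up dimensions rather than Qwen2's real config values, shows the separate-projection path yields the same `(bsz, heads, seq, head_dim)` tensors the merged `qkv_proj` path produces:

```python
# Shape check for the GPTQ fallback path in qwen2_attention_forward.
# Dimensions are made up for illustration; only the tensor layout matters here.
import torch

bsz, q_len, head_dim = 2, 8, 16
num_heads, num_key_value_heads = 4, 2
hidden_size = num_heads * head_dim

hidden_states = torch.randn(bsz, q_len, hidden_size)

# Separate projections, as a GPTQ checkpoint keeps them (no merged qkv_proj).
q_proj = torch.nn.Linear(hidden_size, num_heads * head_dim)
k_proj = torch.nn.Linear(hidden_size, num_key_value_heads * head_dim)
v_proj = torch.nn.Linear(hidden_size, num_key_value_heads * head_dim)

query_states = q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = k_proj(hidden_states).view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
value_states = v_proj(hidden_states).view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)

# Same layout the merged qkv_proj path reaches after its view/transpose/split.
assert query_states.shape == (bsz, num_heads, q_len, head_dim)
assert key_states.shape == (bsz, num_key_value_heads, q_len, head_dim)
assert value_states.shape == (bsz, num_key_value_heads, q_len, head_dim)
```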