[Usage]: ValueError: Unexpected weight for Qwen2-VL GPTQ 4-bit custom model. #9832

Closed
bhavyajoshi-mahindra opened this issue Oct 30, 2024 · 16 comments · Fixed by #10169
Labels
usage How to use vllm

@bhavyajoshi-mahindra

Your current environment

The output of `python collect_env.py`

WARNING 10-30 12:11:37 _custom_ops.py:19] Failed to import from vllm._C with ModuleNotFoundError("No module named 'vllm._C'")
PyTorch version: 2.4.0+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11 Home Single Language
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.10.0 (tags/v3.10.0:b494f59, Oct  4 2021, 19:00:18) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.22631-SP0
Is CUDA available: False
CUDA runtime version: 12.1.66
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2060
Nvidia driver version: 565.90
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture=9
CurrentClockSpeed=2900
DeviceID=CPU0
Family=107
L2CacheSize=4096
L2CacheSpeed=
Manufacturer=AuthenticAMD
MaxClockSpeed=2900
Name=AMD Ryzen 7 4800H with Radeon Graphics
ProcessorType=3
Revision=24577

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-ml-py==12.560.30
[pip3] pyzmq==26.2.0
[pip3] torch==2.4.0
[pip3] torchvision==0.19.0
[pip3] transformers==4.46.1
[pip3] triton==3.0.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.3
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
Could not collect

I tried to run inference with my custom Qwen2-VL GPTQ 4-bit model using the code below:

from transformers import AutoProcessor
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info

MODEL_PATH = "Qwen2-VL"

llm = LLM(
    model=MODEL_PATH,
    limit_mm_per_prompt={"image": 10, "video": 10},
)

sampling_params = SamplingParams(
    temperature=0.1,
    top_p=0.001,
    repetition_penalty=1.05,
    max_tokens=256,
    stop_token_ids=[],
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "/content/drive/MyDrive/LLM/test/Vin_2023-12-22_14-47-37.jpg",
                "min_pixels": 224 * 224,
                "max_pixels": 1280 * 28 * 28,
            },
            {"type": "text", "text":
                                    '''
                                    Please extract the Vehicle Sr No, Engine No, and Model from this image.
                                    Response only json format nothing else.
                                    Analyze the font and double check for similar letters such as "V":"U", "8":"S":"0", "R":"P".
                                    '''
             },
        ],
    },
]

processor = AutoProcessor.from_pretrained(MODEL_PATH)
prompt = processor.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
image_inputs, video_inputs = process_vision_info(messages)

mm_data = {}
if image_inputs is not None:
    mm_data["image"] = image_inputs
if video_inputs is not None:
    mm_data["video"] = video_inputs

llm_inputs = {
    "prompt": prompt,
    "multi_modal_data": mm_data,
}

outputs = llm.generate([llm_inputs], sampling_params=sampling_params)
generated_text = outputs[0].outputs[0].text

print(generated_text)

I got this error:

WARNING 10-30 12:06:32 _custom_ops.py:19] Failed to import from vllm._C with ModuleNotFoundError("No module named 'vllm._C'")
You are using a model of type qwen2_vl to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
ERROR 10-30 12:06:37 registry.py:264] Error in inspecting model architecture 'Qwen2VLForConditionalGeneration'
ERROR 10-30 12:06:37 registry.py:264] Traceback (most recent call last):
ERROR 10-30 12:06:37 registry.py:264]   File "F:\Mahindra\LLM\vllm\vllm\model_executor\models\registry.py", line 426, in _run_in_subprocess
ERROR 10-30 12:06:37 registry.py:264]     returned.check_returncode()
ERROR 10-30 12:06:37 registry.py:264]   File "C:\Users\bhavy\AppData\Local\Programs\Python\Python310\lib\subprocess.py", line 456, in check_returncode
ERROR 10-30 12:06:37 registry.py:264]     raise CalledProcessError(self.returncode, self.args, self.stdout,
ERROR 10-30 12:06:37 registry.py:264] subprocess.CalledProcessError: Command '['F:\\Mahindra\\LLM\\myenv\\Scripts\\python.exe', '-m', 'vllm.model_executor.models.registry']' returned non-zero exit status 1.
ERROR 10-30 12:06:37 registry.py:264] 
ERROR 10-30 12:06:37 registry.py:264] The above exception was the direct cause of the following exception:
ERROR 10-30 12:06:37 registry.py:264]
ERROR 10-30 12:06:37 registry.py:264] Traceback (most recent call last):
ERROR 10-30 12:06:37 registry.py:264]   File "F:\Mahindra\LLM\vllm\vllm\model_executor\models\registry.py", line 262, in _try_inspect_model_cls        
ERROR 10-30 12:06:37 registry.py:264]     return model.inspect_model_cls()
ERROR 10-30 12:06:37 registry.py:264]   File "F:\Mahindra\LLM\vllm\vllm\model_executor\models\registry.py", line 224, in inspect_model_cls
ERROR 10-30 12:06:37 registry.py:264]     return _run_in_subprocess(
ERROR 10-30 12:06:37 registry.py:264]   File "F:\Mahindra\LLM\vllm\vllm\model_executor\models\registry.py", line 429, in _run_in_subprocess
ERROR 10-30 12:06:37 registry.py:264]     raise RuntimeError(f"Error raised in subprocess:\n"
ERROR 10-30 12:06:37 registry.py:264] RuntimeError: Error raised in subprocess:
ERROR 10-30 12:06:37 registry.py:264] C:\Users\bhavy\AppData\Local\Programs\Python\Python310\lib\runpy.py:126: RuntimeWarning: 'vllm.model_executor.models.registry' found in sys.modules after import of package 'vllm.model_executor.models', but prior to execution of 'vllm.model_executor.models.registry'; this may result in unpredictable behaviour
ERROR 10-30 12:06:37 registry.py:264]   warn(RuntimeWarning(msg))
ERROR 10-30 12:06:37 registry.py:264] Traceback (most recent call last):
ERROR 10-30 12:06:37 registry.py:264]   File "C:\Users\bhavy\AppData\Local\Programs\Python\Python310\lib\runpy.py", line 196, in _run_module_as_main   
ERROR 10-30 12:06:37 registry.py:264]     return _run_code(code, main_globals, None,
ERROR 10-30 12:06:37 registry.py:264]   File "C:\Users\bhavy\AppData\Local\Programs\Python\Python310\lib\runpy.py", line 86, in _run_code
ERROR 10-30 12:06:37 registry.py:264]     exec(code, run_globals)
ERROR 10-30 12:06:37 registry.py:264]   File "F:\Mahindra\LLM\vllm\vllm\model_executor\models\registry.py", line 450, in <module>
ERROR 10-30 12:06:37 registry.py:264]     _run()
ERROR 10-30 12:06:37 registry.py:264]   File "F:\Mahindra\LLM\vllm\vllm\model_executor\models\registry.py", line 445, in _run
ERROR 10-30 12:06:37 registry.py:264]     with open(output_file, "wb") as f:
ERROR 10-30 12:06:37 registry.py:264] PermissionError: [Errno 13] Permission denied: 'C:\\Users\\bhavy\\AppData\\Local\\Temp\\tmpjxi5mk75'
ERROR 10-30 12:06:37 registry.py:264]
Traceback (most recent call last):
  File "F:\Mahindra\LLM\vllm\qwen2-vl-vllm-infer.py", line 7, in <module>
    llm = LLM(
  File "F:\Mahindra\LLM\vllm\vllm\entrypoints\llm.py", line 177, in __init__
    self.llm_engine = LLMEngine.from_engine_args(
  File "F:\Mahindra\LLM\vllm\vllm\engine\llm_engine.py", line 571, in from_engine_args
    engine_config = engine_args.create_engine_config()
  File "F:\Mahindra\LLM\vllm\vllm\engine\arg_utils.py", line 900, in create_engine_config
    model_config = self.create_model_config()
  File "F:\Mahindra\LLM\vllm\vllm\engine\arg_utils.py", line 837, in create_model_config
    return ModelConfig(
  File "F:\Mahindra\LLM\vllm\vllm\config.py", line 194, in __init__
    self.multimodal_config = self._init_multimodal_config(
  File "F:\Mahindra\LLM\vllm\vllm\config.py", line 213, in _init_multimodal_config
    if ModelRegistry.is_multimodal_model(architectures):
  File "F:\Mahindra\LLM\vllm\vllm\model_executor\models\registry.py", line 384, in is_multimodal_model
    return self.inspect_model_cls(architectures).supports_multimodal
  File "F:\Mahindra\LLM\vllm\vllm\model_executor\models\registry.py", line 353, in inspect_model_cls
    return self._raise_for_unsupported(architectures)
  File "F:\Mahindra\LLM\vllm\vllm\model_executor\models\registry.py", line 314, in _raise_for_unsupported
    raise ValueError(
ValueError: Model architectures ['Qwen2VLForConditionalGeneration'] are not supported for now. Supported architectures: ['AquilaModel', 'AquilaForCausalLM', 'ArcticForCausalLM', 'BaiChuanForCausalLM', 'BaichuanForCausalLM', 'BloomForCausalLM', 'CohereForCausalLM', 'DbrxForCausalLM', 'DeciLMForCausalLM', 'DeepseekForCausalLM', 'DeepseekV2ForCausalLM', 'ExaoneForCausalLM', 'FalconForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTJForCausalLM', 'GPTNeoXForCausalLM', 'GraniteForCausalLM', 'GraniteMoeForCausalLM', 'InternLMForCausalLM', 'InternLM2ForCausalLM', 'JAISLMHeadModel', 'JambaForCausalLM', 'LlamaForCausalLM', 'LLaMAForCausalLM', 'MambaForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'QuantMixtralForCausalLM', 'MptForCausalLM', 'MPTForCausalLM', 'MiniCPMForCausalLM', 'MiniCPM3ForCausalLM', 'NemotronForCausalLM', 'OlmoForCausalLM', 'OlmoeForCausalLM', 'OPTForCausalLM', 'OrionForCausalLM', 'PersimmonForCausalLM', 'PhiForCausalLM', 'Phi3ForCausalLM', 'Phi3SmallForCausalLM', 'PhiMoEForCausalLM', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'RWForCausalLM', 'StableLMEpochForCausalLM', 'StableLmForCausalLM', 'Starcoder2ForCausalLM', 'SolarForCausalLM', 'XverseForCausalLM', 'BartModel', 'BartForConditionalGeneration', 'MistralModel', 'Qwen2ForRewardModel', 'Gemma2Model', 'Blip2ForConditionalGeneration', 'ChameleonForConditionalGeneration', 'ChatGLMModel', 'ChatGLMForConditionalGeneration', 'FuyuForCausalLM', 'InternVLChatModel', 'LlavaForConditionalGeneration', 'LlavaNextForConditionalGeneration', 'LlavaNextVideoForConditionalGeneration', 'LlavaOnevisionForConditionalGeneration', 'MiniCPMV', 'MolmoForCausalLM', 'NVLM_D', 'PaliGemmaForConditionalGeneration', 'Phi3VForCausalLM', 'PixtralForConditionalGeneration', 'QWenLMHeadModel', 'Qwen2VLForConditionalGeneration', 'UltravoxModel', 'MllamaForConditionalGeneration', 'EAGLEModel', 'MedusaModel', 'MLPSpeculatorPreTrainedModel']

Note:

  1. "Qwen2VLForConditionalGeneration" is in the list of supported models but still I got the error.
  2. collect_env.py says "Is CUDA available: False" but nvcc --version mentions :
    "nvcc: NVIDIA (R) Cuda compiler driver
    Copyright (c) 2005-2023 NVIDIA Corporation
    Built on Wed_Feb__8_05:53:42_Coordinated_Universal_Time_2023
    Cuda compilation tools, release 12.1, V12.1.66
    Build cuda_12.1.r12.1/compiler.32415258_0"

Can anyone help me with this?

How would you like to use vllm

I want to run inference of a [specific model](put link here). I don't know how to integrate it with vllm.

@DarkLight1337
Member

DarkLight1337 commented Oct 30, 2024

This particular error should have been fixed by #9721. Note that vLLM doesn't officially support Windows installations. Please also see #9701

@bhavyajoshi-mahindra
Author

I have switched to Linux (Colab).
I have fine-tuned Qwen2-VL (LoRA) using Llama-factory and merged it into the original weights. Then I quantized (GPTQ) the merged weights using AutoGPTQ. Now I want to run inference on the quantized weights using vLLM.
But I got this error :

ValueError                                Traceback (most recent call last)
<ipython-input-5-06a68f93c72a> in <cell line: 7>()
      5 MODEL_PATH = "/content/drive/MyDrive/LLM/vinplate2-gwen2-vl-gptq-4bit"
      6 
----> 7 llm = LLM(
      8     model=MODEL_PATH,
      9     limit_mm_per_prompt={"image": 10, "video": 10},

10 frames
/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/qwen2_vl.py in load_weights(self, weights)
   1201                     param = params_dict[name]
   1202                 except KeyError:
-> 1203                     raise ValueError(f"Unexpected weight: {name}") from None
   1204 
   1205                 weight_loader = getattr(param, "weight_loader",

ValueError: Unexpected weight: model.layers.0.mlp.down_proj.g_idx

Here is my environment.
transformers : 4.46.1
vllm : 0.6.3.post2.dev165+g33d25773
torch : 2.4.1+cu121
flash_attn : 2.6.3
CUDA : 12.2
python : 3.10.12
OS : Linux

@bhavyajoshi-mahindra bhavyajoshi-mahindra changed the title [Usage]: ValueError: Model architectures ['Qwen2VLForConditionalGeneration'] are not supported for now [Usage]: ValueError: Unexpected weight for Qwen2-VL GPTQ 4-bit custom model. Oct 30, 2024
@DarkLight1337
Member

Can you try using the latest main branch of vLLM? #9772 might already have fixed this issue.

@DarkLight1337
Member

cc @mgoin

@bhavyajoshi-mahindra
Author

Still getting the same error after installing vLLM from the main branch.
transformers : 4.46.1
vllm : 0.6.3.post2.dev174+g5608e611.d20241031
torch : 2.5.0+cu124
flash_attn : 2.6.3

@mgoin
Collaborator

mgoin commented Oct 31, 2024

Using main from before #9817 landed, I was able to load Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4 just fine in vLLM. However, since we just landed support for quantizing the vision transformer, this broke GPTQ checkpoints for this model (and many other VLMs using GPTQ are likely broken as well).

@DarkLight1337 this gets into the larger issue we have with enabling quantization for more modules in vLLM: many quantization methods/configurations do not have proper "ignored" lists of modules.

As an example, if you look at Qwen's official GPTQ checkpoint for Qwen2-VL, you can see that all of the "model." submodules are quantized but none of the "visual." ones are: https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4?show_file_info=model.safetensors.index.json
[screenshot: model.safetensors.index.json showing GPTQ tensors (qweight/qzeros/scales) under "model.*" but only plain weights under "visual.*"]

However, within that model's GPTQ quantization_config there is nothing specifying that those modules were ignored - it looks like the config should be applied everywhere: https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4/blob/main/config.json#L20-L30

  "quantization_config": {
    "bits": 4,
    "damp_percent": 0.1,
    "dataset": null,
    "desc_act": false,
    "group_size": 128,
    "modules_in_block_to_quantize": null,
    "quant_method": "gptq",
    "sym": true,
    "true_sequential": true
  },

Luckily not all quant configs have this issue - obviously compressed-tensors has an ignore list, and AWQ has a "modules_to_not_convert" list
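
For anyone who wants to check this on their own checkpoint, here is a small sketch (the path is an assumption; point it at a locally downloaded copy of the checkpoint's model.safetensors.index.json) that counts GPTQ-specific tensors (qweight/qzeros/scales/g_idx) versus plain tensors per top-level module:

import json
from collections import defaultdict

# Sketch only: path to a locally downloaded copy of the checkpoint's index file.
INDEX_PATH = "Qwen2-VL-7B-Instruct-GPTQ-Int4/model.safetensors.index.json"

GPTQ_SUFFIXES = (".qweight", ".qzeros", ".scales", ".g_idx")

with open(INDEX_PATH) as f:
    weight_map = json.load(f)["weight_map"]

quantized, plain = defaultdict(int), defaultdict(int)
for name in weight_map:
    top = name.split(".")[0]  # e.g. "model", "visual", "lm_head"
    if name.endswith(GPTQ_SUFFIXES):
        quantized[top] += 1
    else:
        plain[top] += 1

print("GPTQ tensors per top-level module:", dict(quantized))
print("plain tensors per top-level module:", dict(plain))

For the official checkpoint above, "model" ends up with GPTQ tensors while "visual" has only plain ones, even though quantization_config declares no ignored modules.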

@DarkLight1337
Member

Is it feasible to change the model initialization code to switch between the regular and the quantized version based on whether the corresponding weight is available from the model file?

@mgoin
Collaborator

mgoin commented Oct 31, 2024

Not easily at all. We commonly rely on the assumption that we can allocate and distribute the model parameters by looking at the model config. Model loading from the weights is a separate step

@bhavyajoshi-mahindra
Author

Just to understand: is something wrong with how I quantized the model, or is something wrong with how vLLM loads it?

@DarkLight1337
Member

Not easily at all. We commonly rely on the assumption that we can allocate and distribute the model parameters by looking at the model config. Model loading from the weights is a separate step

Hmm, a more practical way might be to let the user specify additional config arguments via CLI then...

@mgoin
Collaborator

mgoin commented Oct 31, 2024

@bhavyajoshi-mahindra the issue is that AutoGPTQ will not quantize the visual section of qwen2-vl, but it does not leave anything in the config to signify that those linear layers are skipped

@DarkLight1337 I think we should simply add a special case for GPTQ models, as was done here for AWQ:

def _patch_quant_config(self, config: PretrainedConfig,
                        quant_config: QuantizationConfig):
    # the awq models from OpenGVLab missing `modules_to_not_convert`
    # patch the quant_config to add `modules_to_not_convert` back
    if isinstance(quant_config, AWQConfig):
        text_config = config.text_config
        llm_quant_config = getattr(text_config, "quantization_config",
                                   None)
        if (not quant_config.modules_to_not_convert) and \
                (llm_quant_config is not None):
            quant_config.modules_to_not_convert.append("vision_model")
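
For illustration only, a rough sketch of what an analogous special case for GPTQ Qwen2-VL checkpoints could look like - this is not the actual fix, and the modules_to_not_convert attribute on GPTQConfig is assumed/hypothetical here (mirroring AWQConfig):

from transformers import PretrainedConfig

from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig


def _patch_gptq_quant_config(config: PretrainedConfig,
                             quant_config: QuantizationConfig):
    # Qwen2-VL GPTQ checkpoints quantize the language model ("model.*") but
    # leave the vision tower ("visual.*") in full precision, without recording
    # that anywhere in quantization_config. Patch the quant config so the
    # "visual" modules are skipped when building quantized linear layers.
    # NOTE: `modules_to_not_convert` is a hypothetical attribute on GPTQConfig,
    # used here only to mirror the AWQ special case above.
    if isinstance(quant_config, GPTQConfig):
        if not getattr(quant_config, "modules_to_not_convert", None):
            quant_config.modules_to_not_convert = ["visual"]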

@DarkLight1337
Member

DarkLight1337 commented Oct 31, 2024

@bhavyajoshi-mahindra the issue is that AutoGPTQ will not quantize the visual section of qwen2-vl, but it does not leave anything in the config to signify that those linear layers are skipped

@DarkLight1337 I think we should simply add a special case for GPTQ models, as was done here for AWQ:

def _patch_quant_config(self, config: PretrainedConfig,
                        quant_config: QuantizationConfig):
    # the awq models from OpenGVLab missing `modules_to_not_convert`
    # patch the quant_config to add `modules_to_not_convert` back
    if isinstance(quant_config, AWQConfig):
        text_config = config.text_config
        llm_quant_config = getattr(text_config, "quantization_config",
                                   None)
        if (not quant_config.modules_to_not_convert) and \
                (llm_quant_config is not None):
            quant_config.modules_to_not_convert.append("vision_model")

That may work for now. Does AWQ have an implicit list of modules that it quantized? What if this changes in the future?

@cedonley

The thread here seems to indicate that AWQ should work, but I get the same issue with the AWQ version.

    raise ValueError(f"Unexpected weight: {name}") from None
ValueError: Unexpected weight: visual.blocks.0.attn.proj.weight

Yet the layer is specified as unconverted in the config file:

  "quantization_config": {
    "bits": 4,
    "group_size": 128,
    "modules_to_not_convert": [
      "visual"
    ],
    "quant_method": "awq",
    "version": "gemm",
    "zero_point": true
  },

I'm trying with the latest main.

@mgoin
Collaborator

mgoin commented Oct 31, 2024

Thanks for testing @cedonley. It seems that if you run vllm serve Qwen/Qwen2-VL-2B-Instruct-AWQ it will fail with your error because awq_marlin isn't obeying the ignore list. However, if you force the vanilla awq backend with vllm serve Qwen/Qwen2-VL-2B-Instruct-AWQ --quantization awq, I am able to load the model fine. I will put up a fix for this!
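
For anyone hitting this from the Python API instead of vllm serve, a sketch of the same workaround (model name taken from the checkpoint above) is to pass quantization="awq" so the vanilla AWQ backend is used instead of awq_marlin:

from vllm import LLM

# Workaround sketch: force the vanilla "awq" backend, which respects the
# modules_to_not_convert list, instead of the default awq_marlin backend.
llm = LLM(
    model="Qwen/Qwen2-VL-2B-Instruct-AWQ",
    quantization="awq",
    limit_mm_per_prompt={"image": 10, "video": 10},
)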

@bhavyajoshi-mahindra
Author

Any update on the issue?

@mgoin
Collaborator

mgoin commented Nov 8, 2024

Hi @bhavyajoshi-mahindra, I only have a workaround for Qwen2-VL locally, so I've been sitting on it while I think about a more general solution. I will work on a PR using just the workaround for now.
