[Bug]: The accuracy of vllm-Qwen2-VL-7B-Instruct is low. #8408

Open
xiangxinhello opened this issue Sep 12, 2024 · 20 comments
Labels
bug Something isn't working

Comments

xiangxinhello commented Sep 12, 2024

Your current environment

from PIL import Image
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info

MODEL_PATH = '/workspace/mnt/storage/trt-llama/Qwen2-VL-7B-Instruct'
IMAGE_PATH = '/workspace/mnt/storage/llm_storge/vllm/examples/demo.jpeg'

llm = LLM(
    model=MODEL_PATH,
    dtype='float32',
    limit_mm_per_prompt={'image': 10, 'video': 10},
)

sampling_params = SamplingParams(
    temperature=0.1, top_p=0.001, repetition_penalty=1.05, max_tokens=256,
    stop_token_ids=[],
)

messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': [
        {
            'type': 'image',
            'image': IMAGE_PATH,
            'max_pixels': 12845056,
        },
        {
            'type': 'text',
            'text': '输出击掌的检测框',  # "Output the detection box for the high-five"
        },
    ]},
]

processor = AutoProcessor.from_pretrained(MODEL_PATH)
prompt = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True,
)
image_inputs, video_inputs = process_vision_info(messages)

mm_data = {}
if image_inputs is not None:
    mm_data['image'] = image_inputs
if video_inputs is not None:
    mm_data['video'] = video_inputs

llm_inputs = {
    'prompt': prompt,
    'multi_modal_data': mm_data,
}

# Expected box: 击掌(529,516),(583,594)
outputs = llm.generate([llm_inputs], sampling_params=sampling_params)
generated_text = outputs[0].outputs[0].text

print(generated_text)

Model Input Dumps

No response

🐛 Describe the bug

Qwen2-VL-7B-Instruct: vllm-qwenvl-fp16 has a bug. The accuracy of vllm-qwenvl and transformers-qwenvl differs on a grounding prompt (击掌 means "high-five").
击掌(529,513),(584,605) vllm-fp16
击掌(531,516),(581,596) transformers-qwen2-vl-fp16
The coordinates from vllm are (529,513),(584,605).
The coordinates from transformers are (536,509),(588,602).
There is a significant difference between them.
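For a quick way to quantify the gap, a minimal sketch (an illustrative helper, not from the original report) that computes the IoU of the vllm-fp16 and transformers-fp16 boxes quoted above:

def iou(box_a, box_b):
    # Each box is ((x1, y1), (x2, y2)) with top-left and bottom-right corners.
    (ax1, ay1), (ax2, ay2) = box_a
    (bx1, by1), (bx2, by2) = box_b
    ix1, iy1 = max(ax1, bx1), max(ay1, by1)
    ix2, iy2 = min(ax2, bx2), min(ay2, by2)
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    return inter / union

# Overlap between the two boxes reported above.
print(iou(((529, 513), (584, 605)), ((531, 516), (581, 596))))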

(screenshot attached: qwenvl_v)

xiangxinhello added the bug label Sep 12, 2024
DarkLight1337 (Member)

@fyabc can you help look into this if you have time? Thanks!

fyabc (Contributor) commented Sep 13, 2024

@DarkLight1337 @xiangxinhello I will take a look at it.

fyabc (Contributor) commented Sep 13, 2024

@xiangxinhello Hi, you set dtype to 'float32' in your example code. I want to confirm: which dtype do you use in vllm and which in transformers?

fyabc (Contributor) commented Sep 13, 2024

> @xiangxinhello Hi, you set dtype to 'float32' in your example code. I want to confirm: which dtype do you use in vllm and which in transformers?

If vllm is using fp32 and transformers is using fp16, the difference may be acceptable... @ShuaiBai623 can you take a look at this diff?
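For reference, a minimal sketch of pinning both stacks to the same dtype (fp16 on both sides is assumed here; the model path is the one used in this issue):

import torch
from transformers import Qwen2VLForConditionalGeneration
from vllm import LLM

MODEL_PATH = '/workspace/mnt/storage/trt-llama/Qwen2-VL-7B-Instruct'

# vLLM: request half precision explicitly instead of relying on 'auto'.
llm = LLM(model=MODEL_PATH, dtype='float16', limit_mm_per_prompt={'image': 10, 'video': 10})

# HF transformers: load the weights in the same dtype instead of torch_dtype="auto".
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH, torch_dtype=torch.float16, device_map='auto'
)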

xiangxinhello (Author) commented Sep 13, 2024

> @xiangxinhello Hi, you set dtype to 'float32' in your example code. I want to confirm: which dtype do you use in vllm and which in transformers?
>
> If vllm is using fp32 and transformers is using fp16, the difference may be acceptable... @ShuaiBai623 can you take a look at this diff?

vllm float32 and float16 give the same output; both differ from the transformers result.

xiangxinhello (Author)

Hi @fyabc, do you support Qwen-VL-Chat?

fyabc (Contributor) commented Sep 13, 2024

> Hi @fyabc, do you support Qwen-VL-Chat?

@xiangxinhello #8029 already added support for Qwen-VL-Chat; you can try the latest vllm 0.6.1.

xiangxinhello (Author)

> @xiangxinhello Hi, you set dtype to 'float32' in your example code. I want to confirm: which dtype do you use in vllm and which in transformers?
>
> If vllm is using fp32 and transformers is using fp16, the difference may be acceptable... @ShuaiBai623 can you take a look at this diff?

击掌(531,516),(581,596) transformers-qwen2-vl-fp16
击掌(529,513),(584,605) vllm-qwen2-vl-fp16
击掌(529,516),(583,594) vllm-qwen2-vl-fp32

xiangxinhello (Author)

(example image attached: demo.jpeg)

xiangxinhello (Author) commented Sep 13, 2024

Hi @fyabc, the transformers-qwen-vl-float16 and vllm-qwen-vl-float16 outputs show discrepancies. Could you help me with this?
击掌(536,509),(588,602) transformers-qwen-vl-float16. https://github.com/QwenLM/Qwen-VL
击掌(539,513),(588,604)<|im_end|> vllm-qwen-vl-float16

xiangxinhello (Author) commented Sep 14, 2024

> @xiangxinhello Hi, you set dtype to 'float32' in your example code. I want to confirm: which dtype do you use in vllm and which in transformers?

Hi @DarkLight1337 @fyabc, https://github.com/QwenLM/Qwen-VL
This is the result for transformers-qwen-vl-float16: 击掌(536,509),(588,602)
This is the result for vllm-qwen-vl fp32: 击掌(539,513),(588,604)<|im_end|>
This is the result for vllm-qwen-vl fp16: 击掌(539,513),(588,604)<|im_end|>

fyabc (Contributor) commented Sep 14, 2024

@xiangxinhello Hi, I have tested Qwen2-VL-7B-Instruct fp16/fp32 on vllm and HF, and got the same output '击掌(529,516),(583,594)' with three different seeds. Can you provide your GPU & environment information and share your HuggingFace test script?

xiangxinhello (Author) commented Sep 14, 2024

> @xiangxinhello Hi, I have tested Qwen2-VL-7B-Instruct fp16/fp32 on vllm and HF, and got the same output '击掌(529,516),(583,594)' with three different seeds. Can you provide your GPU & environment information and share your HuggingFace test script?

Hi @fyabc,
My vllm-Qwen2-VL-7B-Instruct-fp16 output is also (529,516),(583,594).
The transformers-Qwen2-VL-7B-Instruct output is (531,516),(581,596).
The main issue is that there is a slight difference in the coordinate values between the two.

GPU: A100-PCIE-40GB

This is the transformers-Qwen2-VL-7B-Instruct environment information:
torch 2.4.0 pypi_0 pypi
torchvision 0.19.0 pypi_0 pypi
tqdm 4.66.5 pypi_0 pypi
transformers 4.45.0.dev0 pypi_0 pypi
python 3.10.14 h955ad1f_1

This is the transformers test script:
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "/workspace/mnt/storage/trt-llama/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
)

processor = AutoProcessor.from_pretrained("/workspace/mnt/storage/trt-llama/Qwen2-VL-7B-Instruct")

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "/workspace/mnt/storage/llm_storge/vllm/examples/demo.jpeg",
            },
            {
                "type": "text",
                "text": "框出图中击掌的位置"  # "Mark the position of the high-five in the image"
            },
        ],
    }
]

text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")

generated_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)

xiangxinhello (Author)

> @xiangxinhello Hi, I have tested Qwen2-VL-7B-Instruct fp16/fp32 on vllm and HF, and got the same output '击掌(529,516),(583,594)' with three different seeds. Can you provide your GPU & environment information and share your HuggingFace test script?

The main issue is that there is a slight difference in the coordinate values between the two.

fyabc (Contributor) commented Sep 14, 2024

@xiangxinhello Hi, can you add print(model.generation_config) into your HF example script and show the result? I want to confirm that HF & vllm are using the same generation hyperparameters.

xiangxinhello (Author)

> @xiangxinhello Hi, can you add print(model.generation_config) into your HF example script and show the result? I want to confirm that HF & vllm are using the same generation hyperparameters.

@fyabc
GenerationConfig {
  "bos_token_id": 151643,
  "do_sample": true,
  "eos_token_id": [
    151645,
    151643
  ],
  "pad_token_id": 151643,
  "temperature": 0.01,
  "top_k": 1,
  "top_p": 0.001
}
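For comparison, a sketch of the vLLM SamplingParams that would mirror this GenerationConfig (repetition_penalty is the one field the HF config above leaves at its default of 1.0, while the vllm script passes 1.05):

from vllm import SamplingParams

# Mirrors the HF GenerationConfig printed above.
sampling_params = SamplingParams(
    temperature=0.01,
    top_k=1,
    top_p=0.001,
    repetition_penalty=1.05,  # the vllm script's value; HF defaults to 1.0 unless set
    max_tokens=256,
)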

fyabc (Contributor) commented Sep 14, 2024

> @xiangxinhello Hi, can you add print(model.generation_config) into your HF example script and show the result? I want to confirm that HF & vllm are using the same generation hyperparameters.
>
> @fyabc GenerationConfig { "bos_token_id": 151643, "do_sample": true, "eos_token_id": [ 151645, 151643 ], "pad_token_id": 151643, "temperature": 0.01, "top_k": 1, "top_p": 0.001 }

@xiangxinhello Hi, can you add model.generation_config.repetition_penalty = 1.05 to align with the vllm setting and try again?
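A minimal sketch of that change, reusing the model and inputs objects from the HF script above:

# Align the HF run with the SamplingParams(repetition_penalty=1.05) used on the vllm side.
model.generation_config.repetition_penalty = 1.05
generated_ids = model.generate(**inputs, max_new_tokens=128)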

xiangxinhello (Author) commented Sep 14, 2024

> @xiangxinhello Hi, can you add print(model.generation_config) into your HF example script and show the result? I want to confirm that HF & vllm are using the same generation hyperparameters.
>
> @fyabc GenerationConfig { "bos_token_id": 151643, "do_sample": true, "eos_token_id": [ 151645, 151643 ], "pad_token_id": 151643, "temperature": 0.01, "top_k": 1, "top_p": 0.001 }
>
> @xiangxinhello Hi, can you add model.generation_config.repetition_penalty = 1.05 to align with the vllm setting and try again?

@fyabc, I set model.generation_config.repetition_penalty = 1.05 in transformers and the output is ['击掌(531,516),(587,594)'].
My vllm-Qwen2-VL-7B-Instruct-fp16 output is still (529,516),(583,594).

kq-chen commented Sep 14, 2024

Diving into the code step by step to check the differences between vllm and hf, I found the following:

1. The first difference lies in the inv_freq of the VisionRotaryEmbedding. In Hugging Face, it is initialized on the CPU and then moved to the GPU, while in vLLM it is initialized directly on the GPU. This results in an error on the order of 1e-8.
2. The second difference is in the implementation of QuickGELU. vLLM uses a custom CUDA operator, whereas Hugging Face does not.

In my test case, after aligning the implementations of these two components, HF produced the same output as vLLM. However, I have not performed a more comprehensive set of tests.

Personally, I feel these differences are quite small; errors of this magnitude are generally acceptable.
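A rough sketch of the kind of comparison described above (the inv_freq formula and QuickGELU(x) = x * sigmoid(1.702 * x) follow their standard definitions; the fused vLLM kernel is not reproduced here, only an fp32-intermediate variant to illustrate the sensitivity; requires a CUDA device):

import torch

# 1) Rotary inv_freq computed on the CPU and then moved to the GPU vs. computed
#    directly on the GPU; per the analysis above the gap is on the order of 1e-8.
dim, theta = 128, 10000.0
inv_freq_cpu = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
inv_freq_gpu = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device='cuda') / dim))
print('inv_freq max diff:', (inv_freq_cpu.cuda() - inv_freq_gpu).abs().max().item())

# 2) QuickGELU evaluated entirely in fp16 vs. with an fp32 intermediate; a fused CUDA
#    kernel and plain PyTorch ops can likewise round the low-order bits differently.
x = torch.randn(1 << 16, dtype=torch.float16, device='cuda')
gelu_fp16 = x * torch.sigmoid(1.702 * x)
gelu_fp32 = (x.float() * torch.sigmoid(1.702 * x.float())).to(torch.float16)
print('QuickGELU max diff:', (gelu_fp16 - gelu_fp32).abs().max().item())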

xiangxinhello (Author) commented Sep 18, 2024

> Diving into the code step by step to check the differences between vllm and hf, I found the following:
>
> The first difference lies in the inv_freq of the VisionRotaryEmbedding. In Hugging Face, it is initialized on the CPU and then moved to the GPU, while in vLLM it is initialized directly on the GPU. This results in an error on the order of 1e-8. The second difference is in the implementation of QuickGELU. vLLM uses a custom CUDA operator, whereas Hugging Face does not.
>
> In my test case, after aligning the implementations of these two components, HF produced the same output as vLLM. However, I have not performed a more comprehensive set of tests.
>
> Personally, I feel these differences are quite small; errors of this magnitude are generally acceptable.

Hi @kq-chen and @fyabc, thank you for your help.
But I don't think the first difference is the problem. I hope you can try my example script; thanks for your support!
This is the example picture, demo.jpeg (image attached).

This is the vllm-Qwen2-VL-7B-Instruct test script:

from PIL import Image
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info

MODEL_PATH = '/workspace/mnt/storage/xiangxin/trt-llama/Qwen2-VL-7B-Instruct'
IMAGE_PATH = '/workspace/mnt/storage/xiangxin/llm_storge/vllm/examples/demo.jpeg'

llm = LLM(
    model=MODEL_PATH,
    dtype = 'float16',
    limit_mm_per_prompt={'image': 10, 'video': 10},
)

sampling_params = SamplingParams(
    temperature=0.1, top_p=0.001, repetition_penalty=1.05, max_tokens=256,
    stop_token_ids=[],
)

messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': [
        {
            'type': 'image',
            'image': IMAGE_PATH,
            'max_pixels': 12845056,
        },
        {
            'type': 'text',
            'text': '输出击掌的检测框',  # "Output the detection box for the high-five"
        },
    ]},
]

processor = AutoProcessor.from_pretrained(MODEL_PATH)
prompt = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True,
)
image_inputs, video_inputs = process_vision_info(messages)

mm_data = {}
if image_inputs is not None:
    mm_data['image'] = image_inputs
if video_inputs is not None:
    mm_data['video'] = video_inputs

llm_inputs = {
    'prompt': prompt,
    'multi_modal_data': mm_data,
}

outputs = llm.generate([llm_inputs], sampling_params=sampling_params)
generated_text = outputs[0].outputs[0].text

print(generated_text)

This is the huggingface-Qwen2-VL-7B-Instruct test script:

from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "/workspace/mnt/storage/xiangxin/trt-llama/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
)

processor = AutoProcessor.from_pretrained("/workspace/mnt/storage/xiangxin/trt-llama/Qwen2-VL-7B-Instruct")


messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "/workspace/mnt/storage/xiangxin/llm_storge/vllm/examples/demo.jpeg",
            },
            {
                "type": "text", 
                "text": "框出图中击掌的位置"
            },
        ],
    }
]

text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")

model.generation_config.repetition_penalty = 1.05
generated_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)

print(model.generation_config)
