
support qwen2-vl #32318

Merged
merged 103 commits on Aug 26, 2024

Changes from all commits

Commits (103)
a8c38a8
support-qwen2-vl
simonJJJ Jul 30, 2024
3caad96
Merge branch 'main' of https://github.com/simonJJJ/transformers into …
simonJJJ Jul 31, 2024
779d9da
tidy
simonJJJ Jul 31, 2024
382b0bc
tidy
simonJJJ Jul 31, 2024
d6bd095
tidy
simonJJJ Jul 31, 2024
a74ee73
tidy
simonJJJ Jul 31, 2024
d585b01
tidy
simonJJJ Jul 31, 2024
9b1d485
tidy
simonJJJ Aug 1, 2024
774a5bc
tidy
simonJJJ Aug 1, 2024
b7a6567
hyphen->underscore
simonJJJ Aug 5, 2024
8b8f37e
make style
simonJJJ Aug 5, 2024
ec66f42
add-flash2-tipd
simonJJJ Aug 5, 2024
b8fae77
delete-tokenize=False
simonJJJ Aug 5, 2024
9262416
remove-image_processor-in-init-file
simonJJJ Aug 5, 2024
283b03c
add-qwen2_vl-in-MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
simonJJJ Aug 5, 2024
7160f7f
format-doct
simonJJJ Aug 5, 2024
5c8d171
support-Qwen2VLVisionConfig
simonJJJ Aug 5, 2024
7e57d56
remove-standardize_cache_format
simonJJJ Aug 5, 2024
009f637
fix-letter-varaibles
simonJJJ Aug 5, 2024
a7a1f6f
remove-torch-in-image-processor
simonJJJ Aug 5, 2024
6fa58f5
remove-useless-docstring
simonJJJ Aug 5, 2024
6d3d580
Merge branch 'main' of https://github.com/simonJJJ/transformers into …
simonJJJ Aug 5, 2024
849621d
fix-one-letter-varaible-name
simonJJJ Aug 5, 2024
bd806ff
change-block-name
simonJJJ Aug 5, 2024
d7778e2
default-quick-gelu-in-vision
simonJJJ Aug 5, 2024
fca6fee
remove-useless-doc
simonJJJ Aug 5, 2024
c3dd8df
use-preimplemented-flash-forward
simonJJJ Aug 5, 2024
8ba3b08
fix-doc
simonJJJ Aug 5, 2024
20983ff
fix-image-processing-doc
simonJJJ Aug 5, 2024
254c657
fix-apply-rotary-embed
simonJJJ Aug 6, 2024
32072c5
fix-flash-attn-sliding-window
simonJJJ Aug 6, 2024
49ba7b8
refactor
simonJJJ Aug 8, 2024
3d292cd
remove-default_template
simonJJJ Aug 8, 2024
1b75471
remove-reorder_cache
simonJJJ Aug 8, 2024
37a2672
simple-get-rope_deltas
simonJJJ Aug 8, 2024
2cfa97c
update-prepare_inputs_for_generation
simonJJJ Aug 8, 2024
27bc926
update-attention-mask
simonJJJ Aug 8, 2024
fdb4afd
update-rotary_seq_len
simonJJJ Aug 8, 2024
ea8b03e
remove-state
simonJJJ Aug 8, 2024
e5a63a3
Merge branch 'main' of https://github.com/simonJJJ/transformers into …
simonJJJ Aug 8, 2024
48cccfd
kv_seq_length
simonJJJ Aug 8, 2024
8754d6f
remove-warning
simonJJJ Aug 8, 2024
55be524
_supports_static_cache
simonJJJ Aug 8, 2024
f564db7
remove-legacy-cache
simonJJJ Aug 8, 2024
5aef1d8
refactor
simonJJJ Aug 8, 2024
7f402eb
fix-replace
simonJJJ Aug 9, 2024
77273da
mrope-section-doc
simonJJJ Aug 9, 2024
d8dbb1f
code-quality
simonJJJ Aug 9, 2024
f145949
Merge branch 'main' into qwen2_vl
simonJJJ Aug 9, 2024
1a6ecad
code-quality
simonJJJ Aug 9, 2024
aba0a2f
polish-doc
simonJJJ Aug 9, 2024
3895071
fix-image-processing-test
simonJJJ Aug 9, 2024
3482568
update readme
Aug 9, 2024
49106ff
Merge branch 'qwen2_vl' of https://github.com/simonJJJ/transformers i…
Aug 9, 2024
abb6811
Update qwen2_vl.md
ShuaiBai623 Aug 9, 2024
8da4fcc
fix-test
simonJJJ Aug 9, 2024
4dd4c9b
Update qwen2_vl.md
ShuaiBai623 Aug 11, 2024
a91b2e9
nit
simonJJJ Aug 12, 2024
110914f
processor-kwargs
simonJJJ Aug 12, 2024
f97ee89
hard-code-norm_layer
simonJJJ Aug 12, 2024
468d698
code-quality
simonJJJ Aug 12, 2024
b4fbbea
Merge branch 'huggingface:main' into qwen2_vl
simonJJJ Aug 12, 2024
1639991
discard-pixel-values-in-gen
simonJJJ Aug 12, 2024
d98e9a9
fix-inconsistent-error-msg
simonJJJ Aug 12, 2024
e6cebdb
unify-image-video
simonJJJ Aug 12, 2024
0805a65
hidden_act
simonJJJ Aug 13, 2024
a92f0ae
add-docstring
simonJJJ Aug 13, 2024
c8d69f0
vision-encode-as-PreTrainedModel
simonJJJ Aug 15, 2024
1ce5837
pixel-to-target-dtype
simonJJJ Aug 15, 2024
27c29d6
update doc and low memoryvit
Aug 15, 2024
a4af0b5
Merge pull request #2 from simonJJJ/updatedoc
simonJJJ Aug 15, 2024
eade3b6
format
simonJJJ Aug 15, 2024
77ab7a5
format
simonJJJ Aug 15, 2024
c4b2add
channel-foramt
simonJJJ Aug 15, 2024
80969ef
fix vit_flashatt
Aug 15, 2024
7b5e785
Merge pull request #3 from simonJJJ/updatedoc
ShuaiBai623 Aug 16, 2024
ce37f64
format
Aug 16, 2024
db4ceb0
inherit-Qwen2VLPreTrainedModel
simonJJJ Aug 16, 2024
5f43897
simplify
simonJJJ Aug 21, 2024
9b14cf0
format-test
simonJJJ Aug 21, 2024
4795317
remove-one-line-func-in-image-processing
simonJJJ Aug 22, 2024
620e35d
avoid-one-line-reshape
simonJJJ Aug 22, 2024
f7adae9
simplify-rotary_seq_len
simonJJJ Aug 22, 2024
fe009cf
avoid-single-letter-variable
simonJJJ Aug 22, 2024
db7b5c3
no-for-loop-sdpa
simonJJJ Aug 22, 2024
61cf241
avoid-single-letter-variable
simonJJJ Aug 22, 2024
f2fb132
remove-one-line-reshape
simonJJJ Aug 22, 2024
c42cedf
remove-one-line-reshape
simonJJJ Aug 22, 2024
c185ffb
remove-no-rope-in-vit-logic
simonJJJ Aug 22, 2024
cafbe43
default-mrope
simonJJJ Aug 22, 2024
7553723
add-copied-from
simonJJJ Aug 22, 2024
43f9685
Merge branch 'main' of https://github.com/simonJJJ/transformers into …
simonJJJ Aug 22, 2024
704e3f2
more-docs-for-mrope
simonJJJ Aug 22, 2024
eefb67a
polish-doc
simonJJJ Aug 22, 2024
1fe8570
comment-and-link
simonJJJ Aug 22, 2024
827e5e9
polish-doc
simonJJJ Aug 22, 2024
f32ac01
single-letter-variables
simonJJJ Aug 26, 2024
e65e7f8
simplify-image-processing
simonJJJ Aug 26, 2024
5d37d76
video->images
simonJJJ Aug 26, 2024
4752328
kv_seq_len-update
simonJJJ Aug 26, 2024
36f2d43
vision-rope-on-the-fly
simonJJJ Aug 26, 2024
3ef1657
vision-eager-attention
simonJJJ Aug 26, 2024
e28cc19
change-processor-order
simonJJJ Aug 26, 2024
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -514,6 +514,8 @@
title: Qwen2Audio
- local: model_doc/qwen2_moe
title: Qwen2MoE
- local: model_doc/qwen2_vl
title: Qwen2VL
- local: model_doc/rag
title: RAG
- local: model_doc/realm
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -260,6 +260,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Qwen2](model_doc/qwen2) ||||
| [Qwen2Audio](model_doc/qwen2_audio) ||||
| [Qwen2MoE](model_doc/qwen2_moe) ||||
| [Qwen2VL](model_doc/qwen2_vl) ||||
| [RAG](model_doc/rag) ||||
| [REALM](model_doc/realm) ||||
| [RecurrentGemma](model_doc/recurrent_gemma) ||||
329 changes: 329 additions & 0 deletions docs/source/en/model_doc/qwen2_vl.md
@@ -0,0 +1,329 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Qwen2_VL


## Overview

[Qwen2_VL](https://qwenlm.github.io/blog/qwen2-vl/) is a major update to the [Qwen-VL](https://arxiv.org/pdf/2308.12966) model from the Qwen team.

The abstract from the blog is the following:

*This blog introduces Qwen2-VL, an advanced version of the Qwen-VL model that has undergone significant enhancements over the past year. Key improvements include enhanced image comprehension, advanced video understanding, integrated visual agent functionality, and expanded multilingual support. The model architecture has been optimized for handling arbitrary image resolutions through Naive Dynamic Resolution support and utilizes Multimodal Rotary Position Embedding (M-ROPE) to effectively process both 1D textual and multi-dimensional visual data. This updated model demonstrates competitive performance against leading AI systems like GPT-4o and Claude 3.5 Sonnet in vision-related tasks and ranks highly among open-source models in text capabilities. These advancements make Qwen2-VL a versatile tool for various applications requiring robust multimodal processing and reasoning abilities.*
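The M-ROPE scheme mentioned in the abstract can be illustrated with a toy example: instead of a single 1D position index per token, every token carries a (temporal, height, width) triple. Text tokens repeat the same increasing index on all three axes, while vision tokens take their coordinates from the image or video grid. The snippet below is a simplified illustration of that indexing idea only, not the implementation merged in this PR.

```python
# Toy illustration of M-ROPE position indexing (not the merged implementation).
import torch

def toy_mrope_positions(text_before: int, grid_t: int, grid_h: int, grid_w: int, text_after: int) -> torch.Tensor:
    positions = []
    # Leading text tokens: the same index on all three axes.
    for i in range(text_before):
        positions.append((i, i, i))
    # Vision tokens: (temporal, height, width) grid coordinates, offset past the text seen so far.
    offset = text_before
    for t in range(grid_t):
        for h in range(grid_h):
            for w in range(grid_w):
                positions.append((offset + t, offset + h, offset + w))
    # Trailing text continues from the largest index used so far.
    start = max(max(p) for p in positions) + 1
    for i in range(text_after):
        positions.append((start + i, start + i, start + i))
    return torch.tensor(positions).T  # shape: (3, sequence_length)

print(toy_mrope_positions(text_before=3, grid_t=1, grid_h=2, grid_w=2, text_after=2))
```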


## Usage example

### Single Media inference

The model can accept both images and videos as input. Here's an example of how to run inference.

```python
from PIL import Image
import requests
import torch
from torchvision import io
from typing import Dict
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

# Load the model in half-precision on the available device(s)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

# Image
url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
image = Image.open(requests.get(url, stream=True).raw)

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]


# Preprocess the inputs
text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
# Expected output: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe this image.<|im_end|>\n<|im_start|>assistant\n'

inputs = processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
inputs = inputs.to('cuda')

# Inference: Generation of the output
output_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
print(output_text)



# Video
def fetch_video(ele: Dict, nframe_factor=2):
    if isinstance(ele["video"], str):
        def round_by_factor(number: int, factor: int) -> int:
            return round(number / factor) * factor

        video = ele["video"]
        if video.startswith("file://"):
            video = video[7:]

        video, _, info = io.read_video(
            video,
            start_pts=ele.get("video_start", 0.0),
            end_pts=ele.get("video_end", None),
            pts_unit="sec",
            output_format="TCHW",
        )
        assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
        if "nframes" in ele:
            nframes = round_by_factor(ele["nframes"], nframe_factor)
        else:
            fps = ele.get("fps", 1.0)
            nframes = round_by_factor(video.size(0) / info["video_fps"] * fps, nframe_factor)
        idx = torch.linspace(0, video.size(0) - 1, nframes, dtype=torch.int64)
        return video[idx]
Comment on lines +84 to +107

Collaborator:
cc @zucchini-nlp I think we'd want to have this in transformers utils for people to easily use it, no?

Member:
Not sure we want exactly this one. AFAIR there was some discussion on what video decoder to use, and the final decision was to use av. Yet, I didn't see torchvision among the options, so this is something I will work on: standardizing video-related processors and adding util functions somewhere (e.g. we also have make_batched_videos).

I think for docs it's okay to use torchvision, so I didn't insist on av.


video_info = {"type": "video", "video": "/path/to/video.mp4", "fps": 1.0}
video = fetch_video(video_info)
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "video"},
            {"type": "text", "text": "What happened in the video?"},
        ],
    }
]

# Preprocess the inputs
text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
# Expected output: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>What happened in the video?<|im_end|>\n<|im_start|>assistant\n'

inputs = processor(text=[text_prompt], videos=[video], padding=True, return_tensors="pt")
inputs = inputs.to('cuda')

# Inference: Generation of the output
output_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
print(output_text)

```
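As a follow-up to the review discussion above: if you prefer the `av` (PyAV) decoder that the maintainers mention, a helper along the following lines could stand in for `fetch_video`. This is a minimal sketch under stated assumptions — `fetch_video_av`, its fixed `num_frames` argument, and the uniform frame-sampling strategy are illustrative choices, not part of this PR.

```python
# Illustrative sketch only (not part of this PR): decode a video with PyAV
# instead of torchvision and sample `num_frames` evenly spaced frames.
import av
import numpy as np
import torch


def fetch_video_av(path: str, num_frames: int = 16) -> torch.Tensor:
    container = av.open(path)
    stream = container.streams.video[0]
    frames = [frame.to_ndarray(format="rgb24") for frame in container.decode(stream)]
    container.close()
    # Pick `num_frames` evenly spaced indices across the decoded frames.
    idx = np.linspace(0, len(frames) - 1, num_frames, dtype=np.int64)
    video = np.stack([frames[i] for i in idx])           # (T, H, W, C), uint8
    return torch.from_numpy(video).permute(0, 3, 1, 2)   # (T, C, H, W), matching the torchvision helper above


# video = fetch_video_av("/path/to/video.mp4")
# inputs = processor(text=[text_prompt], videos=[video], padding=True, return_tensors="pt")
```

Note that this sketch decodes the whole clip into memory before sampling, which is fine for short videos but wasteful for long ones.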


### Batch Mixed Media Inference

The model can batch inputs composed of mixed samples of various types, such as images, videos, and text. Here is an example.

```python
image1 = Image.open("/path/to/image1.jpg")
image2 = Image.open("/path/to/image2.jpg")
image3 = Image.open("/path/to/image3.jpg")
image4 = Image.open("/path/to/image4.jpg")
image5 = Image.open("/path/to/image5.jpg")
video = fetch_video({
    "type": "video",
    "video": "/path/to/video.mp4",
    "fps": 1.0,
})

# Conversation for the first image
conversation1 = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

# Conversation with two images
conversation2 = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "image"},
            {"type": "text", "text": "What is written in the pictures?"},
        ],
    }
]

# Conversation with pure text
conversation3 = [
    {
        "role": "user",
        "content": "who are you?",
    }
]

# Conversation with mixed media
conversation4 = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "image"},
            {"type": "video"},
            {"type": "text", "text": "What are the common elements in these media?"},
        ],
    }
]

conversations = [conversation1, conversation2, conversation3, conversation4]
# Preparation for batch inference
texts = [processor.apply_chat_template(msg, add_generation_prompt=True) for msg in conversations]
inputs = processor(
    text=texts,
    images=[image1, image2, image3, image4, image5],
    videos=[video],
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to('cuda')

# Batch Inference
output_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
print(output_text)
```

### Usage Tips

#### Image Resolution for performance boost

The model supports a wide range of resolution inputs. By default, it uses the native resolution for input, but higher resolutions can enhance performance at the cost of more computation. Users can set the minimum and maximum number of pixels to achieve an optimal configuration for their needs.

```python
min_pixels = 224 * 224
max_pixels = 2048 * 2048
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
```
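To get a feel for what a given pixel budget costs, you can compare the processor output under two settings. The sketch below reuses `image` and `text_prompt` from the single-media example above; the exact shapes printed depend on the source image and the processor version, so treat it as illustrative only.

```python
# Illustrative: a larger max_pixels budget allows a larger image grid,
# which means more visual tokens and a longer input sequence.
for max_pixels in (512 * 512, 2048 * 2048):
    proc = AutoProcessor.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct", min_pixels=224 * 224, max_pixels=max_pixels
    )
    batch = proc(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
    print(max_pixels, batch["input_ids"].shape, batch["pixel_values"].shape)
```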



#### Multiple Image Inputs

By default, images and video content are directly included in the conversation. When handling multiple images, it's helpful to add labels to the images and videos for better reference. Users can control this behavior with the following settings:



```python
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Hello, how are you?"},
        ],
    },
    {
        "role": "assistant",
        "content": "I'm doing well, thank you for asking. How can I assist you today?",
    },
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Can you describe these images and video?"},
            {"type": "image"},
            {"type": "image"},
            {"type": "video"},
            {"type": "text", "text": "These are from my vacation."},
        ],
    },
    {
        "role": "assistant",
        "content": "I'd be happy to describe the images and video for you. Could you please provide more context about your vacation?",
    },
    {
        "role": "user",
        "content": "It was a trip to the mountains. Can you see the details in the images and video?",
    },
]

# default:
prompt_without_id = processor.apply_chat_template(conversation, add_generation_prompt=True)
# Expected output: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Hello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing well, thank you for asking. How can I assist you today?<|im_end|>\n<|im_start|>user\nCan you describe these images and video?<|vision_start|><|image_pad|><|vision_end|><|vision_start|><|image_pad|><|vision_end|><|vision_start|><|video_pad|><|vision_end|>These are from my vacation.<|im_end|>\n<|im_start|>assistant\nI'd be happy to describe the images and video for you. Could you please provide more context about your vacation?<|im_end|>\n<|im_start|>user\nIt was a trip to the mountains. Can you see the details in the images and video?<|im_end|>\n<|im_start|>assistant\n'

# add ids
prompt_with_id = processor.apply_chat_template(conversation, add_generation_prompt=True, add_vision_id=True)
# Expected output: '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nPicture 1: <|vision_start|><|image_pad|><|vision_end|>Hello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing well, thank you for asking. How can I assist you today?<|im_end|>\n<|im_start|>user\nCan you describe these images and video?Picture 2: <|vision_start|><|image_pad|><|vision_end|>Picture 3: <|vision_start|><|image_pad|><|vision_end|>Video 1: <|vision_start|><|video_pad|><|vision_end|>These are from my vacation.<|im_end|>\n<|im_start|>assistant\nI'd be happy to describe the images and video for you. Could you please provide more context about your vacation?<|im_end|>\n<|im_start|>user\nIt was a trip to the mountains. Can you see the details in the images and video?<|im_end|>\n<|im_start|>assistant\n'

```

#### Flash-Attention 2 to speed up generation

First, make sure to install the latest version of Flash Attention 2:

```bash
pip install -U flash-attn --no-build-isolation
```

Also, your hardware must be compatible with FlashAttention-2. Read more about it in the official documentation of the [flash attention repository](https://github.com/Dao-AILab/flash-attention). FlashAttention-2 can only be used when a model is loaded in `torch.float16` or `torch.bfloat16`.

To load and run a model using FlashAttention-2, add `attn_implementation="flash_attention_2"` when loading the model, as follows:

```python
import torch
from transformers import Qwen2VLForConditionalGeneration

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
```


## Qwen2VLConfig

[[autodoc]] Qwen2VLConfig

## Qwen2VLImageProcessor

[[autodoc]] Qwen2VLImageProcessor
- preprocess

## Qwen2VLProcessor

[[autodoc]] Qwen2VLProcessor

## Qwen2VLModel

[[autodoc]] Qwen2VLModel
- forward

## Qwen2VLForConditionalGeneration

[[autodoc]] Qwen2VLForConditionalGeneration
- forward
2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -79,6 +79,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
* [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder)
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
* [Qwen2VL](https://huggingface.co/docs/transformers/model_doc/qwen2_vl#transformers.Qwen2VLModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
* [Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel)
@@ -230,6 +231,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
* [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder)
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
* [Qwen2VL](https://huggingface.co/docs/transformers/model_doc/qwen2_vl#transformers.Qwen2VLModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron)