Skip to content

Commit

Permalink
[Model] support input embeddings for qwen2vl (vllm-project#8856)
Browse files Browse the repository at this point in the history
  • Loading branch information
whyiug authored Sep 30, 2024
1 parent d06be6d commit 990c15c
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 71 deletions.
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ Multimodal Language Models
-
* - :code:`Qwen2VLForConditionalGeneration`
- Qwen2-VL
- Image\ :sup:`+` / Video\ :sup:`+`
- Image\ :sup:`E+` / Video\ :sup:`+`
- :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
-
* - :code:`UltravoxModel`
Expand Down
17 changes: 17 additions & 0 deletions docs/source/models/vlm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,24 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptT
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
# Inference with image embeddings as input with additional parameters
# Specifically, we are conducting a trial run of Qwen2VL with the new input format, as the model utilizes additional parameters for calculating positional encoding.
image_embeds = torch.load(...) # torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
image_grid_thw = torch.load(...) # torch.Tensor of shape (1, 3)
mm_data['image'] = {
"image_embeds": image_embeds,
"image_grid_thw": image_grid_thw,
}
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": mm_data,
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
# Batch inference
image_1 = PIL.Image.open(...)
image_2 = PIL.Image.open(...)
Expand Down
188 changes: 118 additions & 70 deletions vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import lru_cache, partial
from typing import (Iterable, List, Mapping, Optional, Tuple, Type, TypedDict,
Union)
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
Tuple, Type, TypedDict, Union)

import torch
import torch.nn as nn
Expand Down Expand Up @@ -76,19 +76,31 @@
# === Vision Inputs === #


class Qwen2VLImageInputs(TypedDict):
pixel_values: torch.Tensor
class Qwen2VLImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape:
`(num_patches, num_channels * patch_size * patch_size)`
"""

image_grid_thw: torch.Tensor
"""Shape: `(num_images, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
"""


class Qwen2VLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""


Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs,
Qwen2VLImageEmbeddingInputs]


class Qwen2VLVideoInputs(TypedDict):
pixel_values_videos: torch.Tensor
"""Shape:
Expand Down Expand Up @@ -567,6 +579,11 @@ def mm_input_mapper_for_qwen2_vl(
data_type_key: str,
) -> MultiModalInputs:
"""Input mapper for Qwen2-VL."""
if data_type_key == "image" and isinstance(data, dict):
return MultiModalInputs({
"image_embeds": data.get("image_embeds"),
"image_grid_thw": data.get("image_grid_thw"),
})
model_config = ctx.model_config
image_processor = cached_get_image_processor(
model_config.model, trust_remote_code=model_config.trust_remote_code)
Expand Down Expand Up @@ -739,6 +756,48 @@ def _get_llm_num_vision_tokens(
return llm_num_vision_tokens


def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
data_type_key: str, image_processor: Any,
prompt_token_ids: List[int]) -> List[int]:
"""
Expand pad tokens for multi-modal inputs (e.g., images or videos).
Args:
inputs (list): The multi-modal inputs (e.g., images or videos).
token_id (int): The token ID used to represent the multi-modal input.
make_batched_fn (Callable): A function to batch the inputs.
data_type_key (str): The type of the multi-modal input.
image_processor (Any): The image processor used to process the inputs.
prompt_token_ids (List[int]): The list of token IDs in the prompt.
Returns:
List[int]: The list of token IDs for the multi-modal inputs.
"""
indices = [
idx for idx, token in enumerate(prompt_token_ids) if token == token_id
]
inputs = make_batched_fn(inputs)
assert len(indices) == len(inputs)

prompt_token_ids_with_data = []
for cnt, data in enumerate(inputs):
num_tokens = _get_llm_num_vision_tokens(
[data] if data_type_key == "image" else data,
data_type_key=data_type_key,
image_processor=image_processor,
)
if cnt == 0:
end_idx = indices[cnt]
non_data_tokens = prompt_token_ids[:end_idx]
else:
non_data_tokens = prompt_token_ids[indices[cnt - 1] +
1:indices[cnt]]
prompt_token_ids_with_data.extend(non_data_tokens)
prompt_token_ids_with_data.extend(token_id for _ in range(num_tokens))
prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:])
return prompt_token_ids_with_data


def input_processor_for_qwen2_vl(ctx: InputContext,
llm_inputs: LLMInputs) -> LLMInputs:
multi_modal_data = llm_inputs.get("multi_modal_data", None)
Expand Down Expand Up @@ -775,62 +834,38 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
)["input_ids"]

# Expand image pad tokens.

if image_inputs is not None:
image_indices = [
idx for idx, token in enumerate(prompt_token_ids)
if token == hf_config.image_token_id
]
image_inputs = make_batched_images(image_inputs)
assert len(image_indices) == len(image_inputs)

prompt_token_ids_with_image = []
for image_cnt, image in enumerate(image_inputs):
num_image_tokens = _get_llm_num_vision_tokens(
[image],
data_type_key="image",
image_processor=image_processor,
)
if image_cnt == 0:
non_image_tokens = prompt_token_ids[:image_indices[image_cnt]]
else:
non_image_tokens = prompt_token_ids[image_indices[image_cnt -
1] +
1:image_indices[image_cnt]]
prompt_token_ids_with_image.extend(non_image_tokens)
prompt_token_ids_with_image.extend(
hf_config.image_token_id for _ in range(num_image_tokens))
prompt_token_ids_with_image.extend(prompt_token_ids[image_indices[-1] +
1:])
prompt_token_ids = prompt_token_ids_with_image

# Expand video pad tokens.
if isinstance(image_inputs, dict):
prompt_token_ids_with_image = []
image_indices = [
idx for idx, token in enumerate(prompt_token_ids)
if token == hf_config.image_token_id
]
image_cnt = len(image_indices)
embed_dim = image_inputs.get('image_embeds').size(0)
assert embed_dim % image_cnt == 0
num_pad_tokens = embed_dim // image_cnt
for idx, token in enumerate(prompt_token_ids):
if idx in image_indices:
prompt_token_ids_with_image.extend([token] *
num_pad_tokens)
else:
prompt_token_ids_with_image.append(token)
prompt_token_ids = prompt_token_ids_with_image
else:
prompt_token_ids = _expand_pad_tokens(image_inputs,
hf_config.image_token_id,
make_batched_images, "image",
image_processor,
prompt_token_ids)

if video_inputs is not None:
video_indices = [
idx for idx, token in enumerate(prompt_token_ids)
if token == hf_config.video_token_id
]
video_inputs = make_batched_videos(video_inputs)
assert len(video_indices) == len(video_inputs)

prompt_token_ids_with_video = []
for video_cnt, video in enumerate(video_inputs):
num_video_tokens = _get_llm_num_vision_tokens(
video,
data_type_key="video",
image_processor=image_processor,
)
if video_cnt == 0:
non_video_tokens = prompt_token_ids[:video_indices[video_cnt]]
else:
non_video_tokens = prompt_token_ids[video_indices[video_cnt -
1] +
1:video_indices[video_cnt]]
prompt_token_ids_with_video.extend(non_video_tokens)
prompt_token_ids_with_video.extend(
hf_config.video_token_id for _ in range(num_video_tokens))
prompt_token_ids_with_video.extend(prompt_token_ids[video_indices[-1] +
1:])
prompt_token_ids = prompt_token_ids_with_video
prompt_token_ids = _expand_pad_tokens(video_inputs,
hf_config.video_token_id,
make_batched_videos, "video",
image_processor,
prompt_token_ids)

return LLMInputs(
prompt_token_ids=prompt_token_ids,
Expand Down Expand Up @@ -910,22 +945,32 @@ def _validate_and_reshape_mm_tensor(self,
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Qwen2VLImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
image_grid_thw = kwargs.pop("image_grid_thw", None)

if pixel_values is None:
if pixel_values is None and image_embeds is None:
return None

pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, "image pixel values")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
if pixel_values is not None:
pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, "image pixel values")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")

if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of image pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of image pixel values. "
f"Got type: {type(pixel_values)}")

return Qwen2VLImageInputs(pixel_values=pixel_values,
image_grid_thw=image_grid_thw)
return Qwen2VLImagePixelInputs(type="pixel_values",
data=pixel_values,
image_grid_thw=image_grid_thw)

if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Qwen2VLImageEmbeddingInputs(type="image_embeds",
data=image_embeds)

def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]:
Expand All @@ -947,7 +992,10 @@ def _parse_and_validate_video_input(

def _process_image_input(self,
image_input: Qwen2VLImageInputs) -> torch.Tensor:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
if image_input["type"] == "image_embeds":
return image_input["data"].type(self.visual.dtype)

pixel_values = image_input["data"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values,
grid_thw=image_input["image_grid_thw"])
return image_embeds
Expand Down

0 comments on commit 990c15c

Please sign in to comment.