[VLM] Cleanup validation and update docs (vllm-project#6149)
DarkLight1337 authored Jul 5, 2024
1 parent a41357e commit ea4b570
Showing 3 changed files with 87 additions and 82 deletions.
48 changes: 30 additions & 18 deletions vllm/model_executor/models/llava.py
@@ -149,14 +149,16 @@ def __init__(self,
config.vocab_size, logit_scale)
self.sampler = Sampler()

def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
if list(data.shape)[1:] != [
3, self.config.vision_config.image_size,
self.config.vision_config.image_size
]:
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
actual_dims = tuple(data.shape[1:])

if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
"The expected image tensor shape is batch dimension plus "
"channel, height and width.")
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")

return data
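For illustration (not part of the commit): assuming the usual CLIP ViT-L/14-336 vision tower, so that `self.config.vision_config.image_size == 336`, the reworked validator accepts a plain `(batch_size, 3, 336, 336)` tensor and reports the mismatching shape otherwise. A minimal sketch:

import torch

image_size = 336  # assumed here; the real value is self.config.vision_config.image_size

ok = torch.empty(4, 3, image_size, image_size)   # shape[1:] == (3, 336, 336) -> returned unchanged
bad = torch.empty(4, 3, 224, 224)                # would raise ValueError:
# "The expected shape of pixel values is ('batch_size', '3', '336', '336').
#  You supplied (4, 3, 224, 224)."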

@@ -173,7 +175,7 @@ def _parse_and_validate_image_input(

return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_image_data(pixel_values),
data=self._validate_pixel_values(pixel_values),
)

def _select_image_features(self, image_features: torch.Tensor, *,
@@ -226,18 +228,25 @@ def forward(
One key thing to understand is the `input_ids` already accounts for the
positions of the to-be-inserted image embeddings.
Concretely, consider a text prompt:
"<image>\nUSER: What's the content of the image?\nASSISTANT:".
`"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.
Tokenizer outputs:
[1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
The to-be-inserted image has a size of 576 (24 * 24) along the context
length dimension.
`input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
9047, 13566, 29901].
There will be 576 `32000` in the `input_ids`.
(32000 is the token id for `<image>`.)
`[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.
To reserve space in KV cache, we have to insert placeholder tokens
before they are inputted to the model, so the input processor prepends
additional image tokens (denoted as `32000`), resulting in:
`[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
29901]`.
We insert 575 tokens so that including the original image token in the
input, there are a total of 576 (24 * 24) image tokens, which
corresponds to the number of image tokens inputted to the language
model, i.e. the number of image tokens outputted by the visual encoder.
This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`.
Expand All @@ -246,6 +255,9 @@ def forward(
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each input image.
See also:
:class:`LlavaImageInputs`
"""
image_input = self._parse_and_validate_image_input(**kwargs)

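To make the docstring above concrete, here is a minimal sketch of the placeholder expansion it describes (illustrative only; the actual work is done by vLLM's LLaVA input processor): the single `<image>` token (id 32000) is expanded so that 576 image tokens reach the language model.

IMAGE_TOKEN_ID = 32000      # <image>
IMAGE_FEATURE_SIZE = 576    # 24 * 24 patches from the vision encoder

def expand_image_tokens(input_ids: list[int]) -> list[int]:
    """Replace each <image> token with IMAGE_FEATURE_SIZE placeholder tokens."""
    expanded: list[int] = []
    for token_id in input_ids:
        if token_id == IMAGE_TOKEN_ID:
            expanded.extend([IMAGE_TOKEN_ID] * IMAGE_FEATURE_SIZE)
        else:
            expanded.append(token_id)
    return expanded

# Truncated prefix of the tokenized prompt quoted in the docstring.
prompt_ids = [1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618]
assert expand_image_tokens(prompt_ids).count(IMAGE_TOKEN_ID) == 576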
91 changes: 41 additions & 50 deletions vllm/model_executor/models/llava_next.py
@@ -47,7 +47,8 @@ class LlavaNextImagePixelInputs(TypedDict):
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different for each batch.
Note that `num_patches` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
"""

image_sizes: NotRequired[torch.Tensor]
@@ -255,40 +256,20 @@ def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:

def _validate_shape(data: torch.Tensor):

dim = data.dim()
height = width = self.config.vision_config.image_size
# All 4d image tensors have the same number of patches,
# so data is a 5d batch of these tensors
if dim == 5:
if list(data.shape)[2:] != [
3, self.config.vision_config.image_size,
self.config.vision_config.image_size
]:
raise ValueError(
"Expected pixel value tensor in shape of: (batch size, "
f"patch number, 3, {height}, {width}), got {data.shape}"
)

# 4d image tensors have different number of patches,
# so data is each individual tensor.
elif dim == 4:
if list(data.shape)[1:] != [
3, self.config.vision_config.image_size,
self.config.vision_config.image_size
]:
raise ValueError(
"Expected pixel value tensor in shape of: (patch "
f"number, 3, {height}, {width}), got {data.shape}")
else:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)

def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[1:])

if actual_dims != expected_dims:
expected_expr = ("num_patches", *map(str, expected_dims))
raise ValueError(
f"Invalid pixel value tensor of shape {data.shape}")
"The expected shape of pixel values in each batch element "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")

if isinstance(data, torch.Tensor):
_validate_shape(data)
else:
[_validate_shape(d) for d in data]
for d in data:
_validate_shape(d)

return data
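A brief usage sketch of the rewritten validator (illustrative only, assuming `image_size == 336`): every batch element must have shape `(num_patches, 3, 336, 336)`, and a list is used whenever `num_patches` differs between images.

import torch

image_size = 336  # assumed here; the real value is self.config.vision_config.image_size

# Same num_patches for every image: a single batched tensor works.
batched = torch.empty(2, 5, 3, image_size, image_size)

# Different num_patches per image: pass a list of per-image tensors instead.
ragged = [
    torch.empty(5, 3, image_size, image_size),
    torch.empty(9, 3, image_size, image_size),
]

# An element such as torch.empty(5, 3, 224, 224) would raise ValueError,
# naming the expected ('num_patches', '3', '336', '336') layout.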

@@ -464,18 +445,33 @@ def forward(
One key thing to understand is the `input_ids` already accounts for the
positions of the to-be-inserted image embeddings.
Concretely, consider a text prompt:
"<image>\nUSER: What's the content of the image?\nASSISTANT:".
`"A chat between a curious human and an artificial intelligence
assistant. The assistant gives helpful, detailed, and polite answers to
the human's questions.
USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.
Tokenizer outputs:
[1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
The to-be-inserted image has a size of 576 (24 * 24) along the context
length dimension.
`input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
9047, 13566, 29901].
There will be 576 `32000` in the `input_ids`.
(32000 is the token id for `<image>`.)
`[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
9047, 13566, 29901]`.
To reserve space in KV cache, we have to insert placeholder tokens
before they are inputted to the model, so the input processor prepends
additional image tokens (denoted as `32000`), resulting in:
`[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
319, 1799, 9047, 13566, 29901]`.
Unlike in LLaVA-1.5, the number of image tokens inputted to the language
model depends on the original size of the input image. Including the
original image token in the input, the required number of image tokens
is given by :func:`get_llava_next_image_feature_size`.
This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`.
@@ -484,15 +480,10 @@ def forward(
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each grid patch for each input image.
Expects a batch with shape `[1, num_patches, 3, h, w]`.
image_sizes: The original `(height, width)` for each input image.
Expects a batch with shape `[1, 2]`.
See also:
Each input maps to huggingface implementation, as follows:
- `pixel_values`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava_next/modeling_llava_next.py#L690
- `image_sizes`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava_next/modeling_llava_next.py#L691
:class:`LlavaNextImageInputs`
"""
image_input = self._parse_and_validate_image_input(**kwargs)

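As a rough sketch of the difference from LLaVA-1.5 described in the docstring (illustrative only; vLLM's input processor does this for real): the placeholder count is computed per image instead of being fixed at 576.

IMAGE_TOKEN_ID = 32000

def expand_image_tokens(input_ids: list[int], image_feature_size: int) -> list[int]:
    """Replace each <image> token with image_feature_size placeholder tokens."""
    expanded: list[int] = []
    for token_id in input_ids:
        if token_id == IMAGE_TOKEN_ID:
            expanded.extend([IMAGE_TOKEN_ID] * image_feature_size)
        else:
            expanded.append(token_id)
    return expanded

# For LLaVA-NeXT, image_feature_size comes from
# get_llava_next_image_feature_size (referenced in the docstring above),
# which depends on the original image resolution, so different images
# yield different numbers of image tokens.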
30 changes: 16 additions & 14 deletions vllm/model_executor/models/phi3v.py
@@ -263,7 +263,8 @@ class Phi3VImagePixelInputs(TypedDict):
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different for each batch.
Note that `num_patches` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
"""

image_sizes: torch.Tensor
@@ -466,28 +467,29 @@ def __init__(self,
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
if list(data.shape[1:]) != [2]:
raise ValueError(
f"The expected image sizes shape is batch dimension plus "
f"{[2]}. You supplied {data.shape}.")
f"The expected shape of image sizes is batch dimension plus "
f"{[2]}. You supplied {tuple(data.shape)}.")

return data

def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:

def _validate_shape(data: torch.Tensor):
if list(data.shape)[2:] != [
3, CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
]:
h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
expected_dims = (3, h, w)

def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[1:])

if actual_dims != expected_dims:
expected_expr = ("num_patches", *map(str, expected_dims))
raise ValueError(
"The expected pixel value tensor shape is batch dimension "
"plus patch number, channel, height and width.")
"The expected shape of pixel values in each batch element "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")

if isinstance(data, torch.Tensor):
_validate_shape(data)
else:
[_validate_shape(d) for d in data]
for d in data:
_validate_shape(d)

return data

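A short illustrative sketch of the two Phi-3-Vision validators above, using the fixed CLIP ViT-L/14-336 configuration (`image_size == 336`):

import torch

image_size = 336  # CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size

# _validate_image_sizes: one (height, width) pair per image -> shape (batch_size, 2).
image_sizes = torch.tensor([[480, 640], [768, 1024]])

# _validate_pixel_values: each batch element is (num_patches, 3, 336, 336);
# a list is used when num_patches differs across images.
pixel_values = [
    torch.empty(5, 3, image_size, image_size),
    torch.empty(17, 3, image_size, image_size),
]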
