From b7e716271fde66d2af5a37f40daa8fed7555f4ee Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Tue, 27 Aug 2024 18:53:56 -0700 Subject: [PATCH] [Core][VLM] Stack multimodal tensors to represent multiple images within each prompt (#7902) --- .../dev/multimodal/multimodal_index.rst | 2 - tests/multimodal/test_base.py | 83 +++++++++++++++++++ vllm/model_executor/models/blip2.py | 7 ++ vllm/model_executor/models/chameleon.py | 3 + vllm/model_executor/models/fuyu.py | 3 + vllm/model_executor/models/internvl.py | 9 ++ vllm/model_executor/models/llava.py | 8 ++ vllm/model_executor/models/llava_next.py | 11 +++ vllm/model_executor/models/minicpmv.py | 11 ++- vllm/model_executor/models/paligemma.py | 8 ++ vllm/model_executor/models/phi3v.py | 8 ++ vllm/model_executor/models/ultravox.py | 9 ++ vllm/model_executor/models/utils.py | 60 +++++++++----- vllm/multimodal/__init__.py | 3 +- vllm/multimodal/base.py | 49 +++++------ 15 files changed, 214 insertions(+), 60 deletions(-) create mode 100644 tests/multimodal/test_base.py diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index a45bc885dc122..241b2ccd0991e 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -45,8 +45,6 @@ Base Classes .. autodata:: vllm.multimodal.NestedTensors -.. autodata:: vllm.multimodal.BatchedTensors - .. autodata:: vllm.multimodal.BatchedTensorInputs .. autoclass:: vllm.multimodal.MultiModalDataBuiltins diff --git a/tests/multimodal/test_base.py b/tests/multimodal/test_base.py new file mode 100644 index 0000000000000..f19a0f33fe067 --- /dev/null +++ b/tests/multimodal/test_base.py @@ -0,0 +1,83 @@ +import torch + +from vllm.multimodal.base import MultiModalInputs, NestedTensors + + +def assert_nested_tensors_equal(expected: NestedTensors, + actual: NestedTensors): + assert type(expected) == type(actual) + if isinstance(expected, torch.Tensor): + assert torch.equal(expected, actual) + else: + for expected_item, actual_item in zip(expected, actual): + assert_nested_tensors_equal(expected_item, actual_item) + + +def assert_multimodal_inputs_equal(expected: MultiModalInputs, + actual: MultiModalInputs): + assert set(expected.keys()) == set(actual.keys()) + for key in expected: + assert_nested_tensors_equal(expected[key], actual[key]) + + +def test_multimodal_input_batch_single_tensor(): + t = torch.rand([1, 2]) + result = MultiModalInputs.batch([{"image": t}]) + assert_multimodal_inputs_equal(result, {"image": t.unsqueeze(0)}) + + +def test_multimodal_input_batch_multiple_tensors(): + a = torch.rand([1, 1, 2]) + b = torch.rand([1, 1, 2]) + c = torch.rand([1, 1, 2]) + result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}]) + assert_multimodal_inputs_equal(result, {"image": torch.stack([a, b, c])}) + + +def test_multimodal_input_batch_multiple_heterogeneous_tensors(): + a = torch.rand([1, 2, 2]) + b = torch.rand([1, 3, 2]) + c = torch.rand([1, 4, 2]) + result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}]) + assert_multimodal_inputs_equal(result, {"image": [a, b, c]}) + + +def test_multimodal_input_batch_nested_tensors(): + a = torch.rand([2, 3]) + b = torch.rand([2, 3]) + c = torch.rand([2, 3]) + result = MultiModalInputs.batch([{ + "image": [a] + }, { + "image": [b] + }, { + "image": [c] + }]) + assert_multimodal_inputs_equal(result, { + "image": + torch.stack([a.unsqueeze(0), + b.unsqueeze(0), + c.unsqueeze(0)]) + }) + + +def test_multimodal_input_batch_heterogeneous_lists(): + a = torch.rand([1, 2, 3]) + b = torch.rand([1, 2, 3]) + c = torch.rand([1, 2, 3]) + result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}]) + assert_multimodal_inputs_equal( + result, + {"image": [torch.stack([a, b]), c.unsqueeze(0)]}) + + +def test_multimodal_input_batch_multiple_batchable_lists(): + a = torch.rand([1, 2, 3]) + b = torch.rand([1, 2, 3]) + c = torch.rand([1, 2, 3]) + d = torch.rand([1, 2, 3]) + result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c, d]}]) + assert_multimodal_inputs_equal( + result, + {"image": torch.stack([torch.stack([a, b]), + torch.stack([c, d])])}) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 20dda2a67820d..7c9123079c44f 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -555,6 +555,9 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + # Remove the N dimension until multiple images are supported. + pixel_values = pixel_values.squeeze(1) + return Blip2ImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), @@ -564,6 +567,10 @@ def _parse_and_validate_image_input( if not isinstance(image_embeds, torch.Tensor): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") + + # Remove the N dimension until multiple images are supported. + image_embeds = image_embeds.squeeze(1) + return Blip2ImageEmbeddingInputs( type="image_embeds", data=image_embeds, diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index a335e1766b2a9..2d4f172ce0be6 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -946,6 +946,9 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + # Remove the N dimension until multiple images are supported. + pixel_values = pixel_values.squeeze(1) + return ChameleonImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index cfc2a5288a37b..6cdf331fed8b7 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -249,6 +249,9 @@ def _parse_and_validate_image_input( image_patches = kwargs.pop("image_patches", None) if isinstance(image_patches, torch.Tensor): + # Remove the N dimension until multiple images are supported. + image_patches = image_patches.squeeze(1) + expected_feature_size = self.image_feature_size if image_patches.size(-1) != expected_feature_size: raise ValueError( diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index c996f0b73f293..7f213287f33b4 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -244,6 +244,8 @@ def input_mapper_for_internvl(ctx: InputContext, data: object): min_num, max_num, use_thumbnail=use_thumbnail) + # Add an N dimension for number of images per prompt (currently 1). + data = data.unsqueeze(0) model_config = ctx.model_config tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) @@ -410,6 +412,10 @@ def _parse_and_validate_image_input( if not isinstance(image_embeds, torch.Tensor): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") + + # Flatten the B and N dimensions + image_embeds = image_embeds.flatten(0, 2) + return InternVLImageEmbeddingInputs( type="image_embeds", data=image_embeds, @@ -422,6 +428,9 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + # Flatten the B and N dimensions + pixel_values = pixel_values.flatten(0, 2) + return InternVLImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 6433ea380cbfe..03a0abf1db481 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -232,6 +232,10 @@ def _parse_and_validate_image_input( if not isinstance(pixel_values, torch.Tensor): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + + # Remove the N dimension until multiple images are supported. + pixel_values = pixel_values.squeeze(1) + return LlavaImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), @@ -241,6 +245,10 @@ def _parse_and_validate_image_input( if not isinstance(image_embeds, torch.Tensor): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") + + # Remove the N dimension until multiple images are supported. + image_embeds = image_embeds.squeeze(1) + return LlavaImageEmbeddingInputs( type="image_embeds", data=image_embeds, diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 7c096a3794638..3a87242954114 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -361,6 +361,14 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of image sizes. " f"Got type: {type(image_sizes)}") + # Remove the N dimension until multiple images are supported. + if isinstance(pixel_values, torch.Tensor): + pixel_values = pixel_values.squeeze(1) + else: + pixel_values = [t.squeeze(0) for t in pixel_values] + + image_sizes = image_sizes.squeeze(1) + return LlavaNextImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), @@ -372,6 +380,9 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of image embeds. " f"Got type: {type(image_embeds)}") + # Remove the N dimension until multiple images are supported. + image_embeds = image_embeds.squeeze(1) + return LlavaNextImageEmbeddingInputs( type="image_embeds", data=image_embeds, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 29f3640e2458b..6a3d5422e0ce4 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -594,9 +594,14 @@ def _parse_and_validate_inputs( pixel_values_flat: List[torch.Tensor] = [] tgt_sizes_flat: List[torch.Tensor] = [] - for b in range(len(pixel_values)): - pixel_values_flat += pixel_values[b] - tgt_sizes_flat += tgt_sizes[b] + for pixel_b, tgt_b in zip(pixel_values, tgt_sizes): + if len(pixel_b) != len(tgt_b): + raise ValueError("Inconsistent N lengths, found: " + f"{len(pixel_b)} vs {len(tgt_b)}") + + for pixel_n, tgt_n in zip(pixel_b, tgt_b): + pixel_values_flat += pixel_n + tgt_sizes_flat += tgt_n # NOTE: Input IDs does not contain image tokens during memory profiling, # so we allow it to be empty diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 8cb5065ed79ec..0700f0c29d708 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -185,6 +185,10 @@ def _parse_and_validate_image_input( if not isinstance(pixel_values, torch.Tensor): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + + # Remove the N dimension until multiple images are supported. + pixel_values = pixel_values.squeeze(1) + return PaliGemmaImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), @@ -194,6 +198,10 @@ def _parse_and_validate_image_input( if not isinstance(image_embeds, torch.Tensor): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") + + # Remove the N dimension until multiple images are supported. + image_embeds = image_embeds.squeeze(1) + return PaliGemmaImageEmbeddingInputs( type="image_embeds", data=image_embeds, diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index e55a0ce137ed6..61f1d73976379 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -560,6 +560,14 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of image sizes. " f"Got type: {type(image_sizes)}") + # Merge the B and N dimensions. + if isinstance(pixel_values, torch.Tensor): + pixel_values = pixel_values.flatten(0, 1) + else: + pixel_values = torch.cat(pixel_values) + + image_sizes = image_sizes.flatten(0, 1) + return Phi3VImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 842264f765866..c81c2fd114eb8 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -333,6 +333,12 @@ def _parse_and_validate_audio_input( raise ValueError("Incorrect type of audio features. " f"Got type: {type(audio_features)}") + # Remove the N dimension until multiple audios are supported. + if isinstance(audio_features, torch.Tensor): + audio_features = audio_features.squeeze(1) + else: + audio_features = [t.squeeze(0) for t in audio_features] + return UltravoxAudioFeatureInputs(type="audio_features", data=audio_features) @@ -341,6 +347,9 @@ def _parse_and_validate_audio_input( raise ValueError("Incorrect type of audio embeds. " f"Got type: {type(audio_embeds)}") + # Remove the N dimension until multiple audios are supported. + audio_embeds = audio_embeds.squeeze(1) + return UltravoxAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 91b414b1fd91a..00026b7ebe2e1 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,5 +1,6 @@ from typing import Dict, Iterable, List, Optional, Protocol, Tuple +import numpy as np import torch import torch.nn as nn from torch.func import functional_call @@ -10,7 +11,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.loader import build_model from vllm.model_executor.models import ModelRegistry -from vllm.multimodal import BatchedTensors +from vllm.multimodal.base import NestedTensors from vllm.utils import is_pin_memory_available @@ -54,9 +55,34 @@ def init_vllm_registered_model( ) +def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor: + """ + Recursively concatenates NestedTensors along any heterogeneously sized + dimensions. + """ + + if isinstance(embeddings, torch.Tensor): + return embeddings + + return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings)) + + +def _embedding_count_expression(embeddings: NestedTensors) -> str: + """ + Constructs a debugging representation of the number of embeddings in the + NestedTensors. + """ + + if isinstance(embeddings, torch.Tensor): + return " x ".join([str(dim) for dim in embeddings.shape[:-1]]) + + return " + ".join( + _embedding_count_expression(inner) for inner in embeddings) + + def merge_multimodal_embeddings(input_ids: torch.Tensor, inputs_embeds: torch.Tensor, - multimodal_embeddings: BatchedTensors, + multimodal_embeddings: NestedTensors, placeholder_token_id: int) -> torch.Tensor: """ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the @@ -69,28 +95,16 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor, mask = (input_ids == placeholder_token_id) num_expected_tokens = mask.sum() - if isinstance(multimodal_embeddings, torch.Tensor): - batch_size, batch_tokens, *_, embed_dim = multimodal_embeddings.shape - total_tokens = batch_size * batch_tokens - if num_expected_tokens != total_tokens: - expr = f"{batch_size} x {batch_tokens}" - raise ValueError( - f"Attempted to assign {expr} = {total_tokens} " - f"multimodal tokens to {num_expected_tokens} placeholders") - - inputs_embeds[mask] = multimodal_embeddings.view( - total_tokens, embed_dim) - else: - size_per_batch = [t.shape[0] for t in multimodal_embeddings] - total_tokens = sum(size_per_batch) - if num_expected_tokens != total_tokens: - expr = ' + '.join(map(str, size_per_batch)) - raise ValueError( - f"Attempted to assign {expr} = {total_tokens} " - f"multimodal tokens to {num_expected_tokens} placeholders") - - inputs_embeds[mask] = torch.cat(multimodal_embeddings) + flattened = _flatten_embeddings(multimodal_embeddings) + *dims, embed_dim = flattened.shape + num_multimodal_embeddings = np.prod(dims) + if num_multimodal_embeddings != num_expected_tokens: + expr = _embedding_count_expression(multimodal_embeddings) + raise ValueError( + f"Attempted to assign {expr} = {num_multimodal_embeddings} " + f"multimodal tokens to {num_expected_tokens} placeholders") + inputs_embeds[mask] = flattened.view(num_expected_tokens, embed_dim) return inputs_embeds diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 456e41ebfad03..489e1e51f05cb 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,4 +1,4 @@ -from .base import (BatchedTensorInputs, BatchedTensors, MultiModalDataBuiltins, +from .base import (BatchedTensorInputs, MultiModalDataBuiltins, MultiModalDataDict, MultiModalInputs, MultiModalPlugin, NestedTensors) from .registry import MultiModalRegistry @@ -14,7 +14,6 @@ __all__ = [ "BatchedTensorInputs", - "BatchedTensors", "MultiModalDataBuiltins", "MultiModalDataDict", "MultiModalInputs", diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 8ada60c8fd6ae..5b00117c64e53 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -1,9 +1,8 @@ import sys from abc import ABC, abstractmethod from collections import UserDict, defaultdict -from typing import Callable, Dict, List, Mapping, Optional -from typing import Sequence as GenericSequence -from typing import Tuple, Type, TypedDict, TypeVar, Union, cast, final +from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type, + TypedDict, TypeVar, Union, cast, final) import numpy as np import torch @@ -15,23 +14,16 @@ from vllm.config import ModelConfig from vllm.inputs import InputContext from vllm.logger import init_logger -from vllm.utils import JSONTree, json_map_leaves +from vllm.utils import json_map_leaves logger = init_logger(__name__) -NestedTensors = Union[GenericSequence[torch.Tensor], torch.Tensor] +NestedTensors = Union[List["NestedTensors"], torch.Tensor] """ -Use a list instead of a tensor if the dimensions of each element do not match. -Currently only supports up to singly nested list of tensors. +Uses a list instead of a tensor if the dimensions of each element do not match. """ -BatchedTensors: TypeAlias = JSONTree[torch.Tensor] -""" -A nested JSON structure of tensors which have been batched via -:meth:`MultiModalInputs.batch`. -""" - -BatchedTensorInputs: TypeAlias = Dict[str, JSONTree[torch.Tensor]] +BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors] """ A dictionary containing nested tensors which have been batched via :meth:`MultiModalInputs.batch`. @@ -54,26 +46,23 @@ class MultiModalInputs(_MultiModalInputsBase): """ @staticmethod - def _try_concat(tensors: List[NestedTensors]) -> BatchedTensors: + def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: """ - If each input tensor in the batch has the same shape, return a single - batched tensor; otherwise, return a list of :class:`NestedTensors` with - one element per item in the batch. + Recursively stacks lists of tensors when they all have the same shape. """ - # may be list rather than tensors - if isinstance(tensors[0], list): - return [[t for t in tensor[0]] - for tensor in cast(List[List[torch.Tensor]], tensors)] - - tensors_ = cast(List[torch.Tensor], tensors) + if isinstance(nested_tensors, torch.Tensor): + return nested_tensors - unbatched_shape = tensors_[0].shape[1:] + stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors] + if any(isinstance(t, list) for t in stacked): + return stacked - for tensor in tensors_: - if tensor.shape[1:] != unbatched_shape: - return [tensor.squeeze(0) for tensor in tensors_] + tensors_ = cast(List[torch.Tensor], stacked) + if any(t.shape != tensors_[0].shape for t in tensors_): + # The tensors have incompatible shapes and can't be stacked. + return tensors_ - return torch.cat(tensors_, dim=0) + return torch.stack(tensors_) @staticmethod def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs: @@ -102,7 +91,7 @@ def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs: item_lists[k].append(v) return { - k: MultiModalInputs._try_concat(item_list) + k: MultiModalInputs._try_stack(item_list) for k, item_list in item_lists.items() }