[Model] Multi-input support for LLaVA and fix embedding inputs for multi-image models #8238

Merged: 4 commits, merged on Sep 7, 2024 (diff shows changes from 3 commits)
16 changes: 8 additions & 8 deletions docs/source/models/supported_models.rst
@@ -219,14 +219,19 @@ Multimodal Language Models
-
* - :code:`LlavaForConditionalGeneration`
- LLaVA-1.5
- Image\ :sup:`E`
- Image\ :sup:`E+`
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
-
* - :code:`LlavaNextForConditionalGeneration`
- LLaVA-NeXT
- Image\ :sup:`E+`
- :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
-
* - :code:`MiniCPMV`
- MiniCPM-V
- Image\ :sup:`+`
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
-
Comment on lines +230 to +234 (Member Author):
This fixes an issue where the list isn't in alphabetical order.

* - :code:`PaliGemmaForConditionalGeneration`
- PaliGemma
- Image\ :sup:`E`
@@ -237,14 +242,9 @@ Multimodal Language Models
- Image\ :sup:`E+`
- :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc.
-
* - :code:`MiniCPMV`
- MiniCPM-V
- Image\ :sup:`+`
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
-
* - :code:`QWenLMHeadModel`
- Qwen
- Image
- Qwen-VL
- Image\ :sup:`E`
Comment on lines +246 to +247 (Member Author):
This updates the Qwen-VL row to follow the new format.

- :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
-
* - :code:`UltravoxModel`
12 changes: 6 additions & 6 deletions tests/conftest.py
@@ -278,7 +278,7 @@ def __init__(
def generate(
self,
prompts: List[str],
images: Optional[List[Image.Image]] = None,
images: Optional[PromptImageInput] = None,
**kwargs: Any,
) -> List[Tuple[List[List[int]], List[str]]]:
if images:
@@ -314,7 +314,7 @@ def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
images: Optional[List[Image.Image]] = None,
images: Optional[PromptImageInput] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str]]:
outputs = self.generate(prompts,
@@ -351,7 +351,7 @@ def generate_greedy_logprobs(
self,
prompts: List[str],
max_tokens: int,
images: Optional[List[Image.Image]] = None,
images: Optional[PromptImageInput] = None,
**kwargs: Any,
) -> List[List[torch.Tensor]]:
all_logprobs: List[List[torch.Tensor]] = []
@@ -433,8 +433,8 @@ def generate_greedy_logprobs_limit(
prompts: List[str],
max_tokens: int,
num_logprobs: int,
images: Optional[List[Image.Image]] = None,
audios: Optional[List[Tuple[np.ndarray, int]]] = None,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
all_logprobs: List[List[Dict[int, float]]] = []
@@ -671,7 +671,7 @@ def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
images: Optional[List[Image.Image]] = None,
images: Optional[PromptImageInput] = None,
Comment (Member Author):
The changes in this file resolve type errors when passing multiple inputs per prompt.

) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts, greedy_params, images=images)
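The comment above about resolving type errors refers to the switch from `List[Image.Image]` to the `PromptImageInput` / `PromptAudioInput` aliases. As a rough sketch of what those aliases are assumed to express (the authoritative definitions live in `tests/conftest.py`; the union members below are an assumption based on how the tests pass single vs. multiple images per prompt):

```python
from typing import List, Tuple, Union

import numpy as np
from PIL import Image

# Assumed shape of the aliases referenced in the diff: each prompt may carry
# either one image/audio clip or a list of them (the new multi-input case).
PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
PromptAudioInput = Union[
    List[Tuple[np.ndarray, int]],        # one (waveform, sample_rate) per prompt
    List[List[Tuple[np.ndarray, int]]],  # several clips per prompt
]
```

With a union like this, a call such as `generate_greedy(prompts, max_tokens, images=[[img_a, img_b], [img_c]])` type-checks for multi-image prompts without changing existing single-image call sites.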
6 changes: 4 additions & 2 deletions tests/distributed/test_multimodal_broadcast.py
@@ -35,9 +35,11 @@ def test_models(hf_runner, vllm_runner, image_assets, model: str,
if model.startswith("llava-hf/llava-1.5"):
from ..models.test_llava import models, run_test
elif model.startswith("llava-hf/llava-v1.6"):
from ..models.test_llava_next import models, run_test
from ..models.test_llava_next import run_test # type: ignore[no-redef]
from ..models.test_llava_next import models
elif model.startswith("facebook/chameleon"):
from ..models.test_chameleon import models, run_test
from ..models.test_chameleon import run_test # type: ignore[no-redef]
from ..models.test_chameleon import models
else:
raise NotImplementedError(f"Unsupported model: {model}")

141 changes: 129 additions & 12 deletions tests/models/test_llava.py
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple, Type
from typing import List, Optional, Tuple, Type, overload

import pytest
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
@@ -8,11 +8,14 @@
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

_LIMIT_IMAGE_PER_PROMPT = 4

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"USER: <image>\nWhat's the content of the image?\nASSISTANT:",
@@ -52,6 +55,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
return hf_output_ids, hf_output_str, out_logprobs


@overload
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
@@ -64,6 +68,78 @@ def run_test(
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...


@overload
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
sizes: List[Tuple[int, int]],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...


def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: Optional[List[float]] = None,
sizes: Optional[List[Tuple[int, int]]] = None,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
images = [asset.pil_image for asset in image_assets]

if size_factors is not None:
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
elif sizes is not None:
inputs_per_image = [(
[prompt for _ in sizes],
[image.resize(size) for size in sizes],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
else:
raise ValueError("You must provide either `size_factors` or `sizes`")

_run_test(hf_runner,
vllm_runner,
inputs_per_image,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend)


def _run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
inputs: List[Tuple[List[str], PromptImageInput]],
model: str,
*,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.

@@ -85,13 +161,6 @@ def run_test(
else:
mantis_processor = None

images = [asset.pil_image for asset in image_assets]

inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
@@ -100,15 +169,18 @@
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
max_model_len=4096,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
enforce_eager=True,
limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
}) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
for prompts, images in inputs
]

if mantis_processor is not None:
Expand All @@ -131,7 +203,7 @@ def process(hf_inputs: BatchEncoding):
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
for prompts, images in inputs
]

for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
@@ -181,6 +253,51 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
)


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
model, dtype, max_tokens,
num_logprobs) -> None:
stop_sign = image_assets[0].pil_image
cherry_blossom = image_assets[1].pil_image

inputs = [(
[
"USER: <image><image>\nDescribe 2 images.\nASSISTANT:",
"USER: <image><image>\nDescribe 2 images.\nASSISTANT:",
"USER: <image><image><image><image>\nDescribe 4 images.\nASSISTANT:", # noqa: E501
"USER: <image>\nWhat is the season?\nASSISTANT:",
],
[
[stop_sign, cherry_blossom],
# Images with different sizes and aspect-ratios
[
rescale_image_size(stop_sign, 0.1),
stop_sign,
],
[
stop_sign,
rescale_image_size(stop_sign, 0.25),
cherry_blossom.resize((183, 488)),
cherry_blossom.resize((488, 183))
],
cherry_blossom,
])]

_run_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)


@pytest.mark.parametrize("model", models)
Comment (Member Author):
Copying the LLaVA-NeXT test layout for now. The logic will be consolidated in a future PR.

def test_context_length_too_short(vllm_runner, image_assets, model):
images = [asset.pil_image for asset in image_assets]
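The new `test_models_multiple_image_inputs` test above drives the multi-image path through the test runner wrappers. For orientation, a hedged sketch of how the same capability might be exercised through vLLM's offline `LLM` API, assuming `limit_mm_per_prompt` and the `multi_modal_data` prompt field behave as they do in the test setup (the model name and image files are illustrative):

```python
from PIL import Image
from vllm import LLM, SamplingParams

# Allow up to 4 images per prompt, mirroring _LIMIT_IMAGE_PER_PROMPT in the test.
llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",
    limit_mm_per_prompt={"image": 4},
    enforce_eager=True,
)

prompt = "USER: <image><image>\nDescribe 2 images.\nASSISTANT:"
images = [Image.open("stop_sign.jpg"), Image.open("cherry_blossom.jpg")]

outputs = llm.generate(
    {
        "prompt": prompt,
        # A list of images maps onto the repeated <image> placeholders above.
        "multi_modal_data": {"image": images},
    },
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)
```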
30 changes: 18 additions & 12 deletions vllm/model_executor/models/llava.py
@@ -4,6 +4,7 @@

import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig

from vllm.attention import AttentionMetadata
@@ -16,6 +17,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_max_clip_image_tokens,
@@ -24,7 +26,7 @@
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model,
from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings)


@@ -133,7 +135,16 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config

image_feature_size = get_max_llava_image_tokens(ctx)
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_feature_size = get_max_llava_image_tokens(ctx)
elif is_list_of(image_data, Image.Image):
image_feature_size = [get_max_llava_image_tokens(ctx)
] * len(image_data)
elif isinstance(image_data, torch.Tensor):
image_feature_size = image_data.shape[0]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")

if isinstance(vision_config, CLIPVisionConfig):
return input_processor_for_clip(
@@ -230,29 +241,24 @@ def _parse_and_validate_image_input(
return None

if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor):
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

# Remove the N dimension until multiple images are supported.
pixel_values = pixel_values.squeeze(1)

return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
data=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
)

if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")

# Remove the N dimension until multiple images are supported.
image_embeds = image_embeds.squeeze(1)

return LlavaImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
data=flatten_bn(image_embeds, concat=True),
)

raise AssertionError("This line should be unreachable.")
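
The `_parse_and_validate_image_input` changes above replace the old `squeeze(1)` (which assumed exactly one image per prompt) with `flatten_bn(..., concat=True)` from `.utils`. A minimal sketch of the behaviour being assumed of that helper, not the actual vLLM implementation:

```python
from typing import List, Union

import torch


def flatten_bn_sketch(
    x: Union[torch.Tensor, List[torch.Tensor]],
    *,
    concat: bool = False,
) -> torch.Tensor:
    """Collapse the batch (B) and per-prompt image count (N) dimensions.

    Only the cases exercised by the diff are sketched here.
    """
    if isinstance(x, torch.Tensor):
        # Batched input of shape (B, N, ...) -> (B * N, ...).
        return x.flatten(0, 1)
    if concat:
        # List of per-prompt tensors of shape (N_i, ...) -> (sum_i N_i, ...),
        # so prompts with different image counts still yield one flat tensor.
        return torch.cat(x)
    raise NotImplementedError("Sketch covers only tensor input or concat=True.")
```

This is what lets `pixel_values` and `image_embeds` arrive either as a single batched tensor or as a ragged list of per-prompt tensors when prompts carry different numbers of images.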