[Model] Multi-input support for LLaVA and fix embedding inputs for multi-image models #8238
Changes from 3 commits
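For context, a minimal sketch (not part of this PR's diff) of the usage the change enables: passing several images alongside one LLaVA prompt through vLLM's multi-modal interface. The image file names are placeholders, and the options shown (limit_mm_per_prompt, the multi_modal_data field) mirror how the tests below exercise the feature; treat the snippet as illustrative rather than authoritative.

# Illustrative sketch: multi-image inference with LLaVA-1.5 through vLLM.
# The image paths are hypothetical; any PIL images work.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",
    # Allow up to 4 images per prompt, as _LIMIT_IMAGE_PER_PROMPT does in the
    # tests below.
    limit_mm_per_prompt={"image": 4},
)

images = [Image.open("stop_sign.jpg"), Image.open("cherry_blossom.jpg")]

outputs = llm.generate(
    {
        "prompt": "USER: <image><image>\nDescribe 2 images.\nASSISTANT:",
        "multi_modal_data": {"image": images},
    },
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)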
File 1: documentation update to the supported multimodal models table. In this table, :sup:`E` marks embedding-input support and :sup:`+` marks support for multiple items per prompt.
@@ -219,14 +219,19 @@ Multimodal Language Models
     -
   * - :code:`LlavaForConditionalGeneration`
     - LLaVA-1.5
-    - Image\ :sup:`E`
+    - Image\ :sup:`E+`
     - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
     -
   * - :code:`LlavaNextForConditionalGeneration`
     - LLaVA-NeXT
     - Image\ :sup:`E+`
     - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
     -
+  * - :code:`MiniCPMV`
+    - MiniCPM-V
+    - Image\ :sup:`+`
+    - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
+    -
   * - :code:`PaliGemmaForConditionalGeneration`
     - PaliGemma
     - Image\ :sup:`E`
@@ -237,14 +242,9 @@ Multimodal Language Models
     - Image\ :sup:`E+`
     - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc.
     -
-  * - :code:`MiniCPMV`
-    - MiniCPM-V
-    - Image\ :sup:`+`
-    - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
-    -
   * - :code:`QWenLMHeadModel`
-    - Qwen
-    - Image
+    - Qwen-VL
+    - Image\ :sup:`E`
     - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
     -
   * - :code:`UltravoxModel`

Comment on lines +246 to +247: This updates the Qwen-VL row to follow the new format.
File 2: test runner helpers; the generate methods now accept PromptImageInput / PromptAudioInput.
@@ -278,7 +278,7 @@ def __init__(
     def generate(
         self,
         prompts: List[str],
-        images: Optional[List[Image.Image]] = None,
+        images: Optional[PromptImageInput] = None,
         **kwargs: Any,
     ) -> List[Tuple[List[List[int]], List[str]]]:
         if images:

@@ -314,7 +314,7 @@ def generate_greedy(
         self,
         prompts: List[str],
         max_tokens: int,
-        images: Optional[List[Image.Image]] = None,
+        images: Optional[PromptImageInput] = None,
         **kwargs: Any,
     ) -> List[Tuple[List[int], str]]:
         outputs = self.generate(prompts,

@@ -351,7 +351,7 @@ def generate_greedy_logprobs(
         self,
         prompts: List[str],
         max_tokens: int,
-        images: Optional[List[Image.Image]] = None,
+        images: Optional[PromptImageInput] = None,
         **kwargs: Any,
     ) -> List[List[torch.Tensor]]:
         all_logprobs: List[List[torch.Tensor]] = []

@@ -433,8 +433,8 @@ def generate_greedy_logprobs_limit(
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-        images: Optional[List[Image.Image]] = None,
-        audios: Optional[List[Tuple[np.ndarray, int]]] = None,
+        images: Optional[PromptImageInput] = None,
+        audios: Optional[PromptAudioInput] = None,
         **kwargs: Any,
     ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
         all_logprobs: List[List[Dict[int, float]]] = []

@@ -671,7 +671,7 @@ def generate_greedy(
         self,
         prompts: List[str],
         max_tokens: int,
-        images: Optional[List[Image.Image]] = None,
+        images: Optional[PromptImageInput] = None,
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
         outputs = self.generate(prompts, greedy_params, images=images)

Comment: The changes in this file are to resolve type errors when passing multi-input.
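The changes above reference PromptImageInput and PromptAudioInput without showing their definitions. As a rough sketch only (the real aliases live in the test conftest and may differ), they presumably admit either one item per prompt or a list of items per prompt across a batch:

# Hypothetical sketch of the aliases referenced above; not taken from the PR.
from typing import List, Tuple, Union

import numpy as np
from PIL import Image

# One image per prompt, or several images per prompt, across a batch.
PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
# Audio is a (samples, sample_rate) pair; same single-vs-multi shape.
PromptAudioInput = Union[List[Tuple[np.ndarray, int]],
                         List[List[Tuple[np.ndarray, int]]]]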
File 3: LLaVA model tests (single- and multi-image).
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type, overload

 import pytest
 from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,

@@ -8,11 +8,14 @@
 from vllm.sequence import SampleLogprobs
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

-from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
+                        _ImageAssets)
 from .utils import check_logprobs_close

 pytestmark = pytest.mark.vlm

+_LIMIT_IMAGE_PER_PROMPT = 4
+
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "stop_sign":
     "USER: <image>\nWhat's the content of the image?\nASSISTANT:",

@@ -52,6 +55,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
     return hf_output_ids, hf_output_str, out_logprobs


+@overload
 def run_test(
     hf_runner: Type[HfRunner],
     vllm_runner: Type[VllmRunner],
@@ -64,6 +68,78 @@ def run_test(
     num_logprobs: int,
     tensor_parallel_size: int,
     distributed_executor_backend: Optional[str] = None,
 ):
+    ...
+
+
+@overload
+def run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    sizes: List[Tuple[int, int]],
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    ...
+
+
+def run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    size_factors: Optional[List[float]] = None,
+    sizes: Optional[List[Tuple[int, int]]] = None,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    images = [asset.pil_image for asset in image_assets]
+
+    if size_factors is not None:
+        inputs_per_image = [(
+            [prompt for _ in size_factors],
+            [rescale_image_size(image, factor) for factor in size_factors],
+        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+    elif sizes is not None:
+        inputs_per_image = [(
+            [prompt for _ in sizes],
+            [image.resize(size) for size in sizes],
+        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+    else:
+        raise ValueError("You must provide either `size_factors` or `sizes`")
+
+    _run_test(hf_runner,
+              vllm_runner,
+              inputs_per_image,
+              model,
+              dtype=dtype,
+              max_tokens=max_tokens,
+              num_logprobs=num_logprobs,
+              tensor_parallel_size=tensor_parallel_size,
+              distributed_executor_backend=distributed_executor_backend)
+
+
+def _run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    inputs: List[Tuple[List[str], PromptImageInput]],
+    model: str,
+    *,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
     """Inference result should be the same between hf and vllm.
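As a reading aid (not from the diff), the run_test dispatcher above accepts exactly one of size_factors or sizes. A hypothetical test in the same module, following the layout of the existing tests, could exercise the sizes variant like this:

# Hypothetical test (not in this PR) calling the `sizes` overload of run_test;
# the fixture names match the existing tests in this module.
import pytest


@pytest.mark.parametrize("model", ["llava-hf/llava-1.5-7b-hf"])
@pytest.mark.parametrize("dtype", ["half"])
def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model,
                            dtype) -> None:
    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model,
        # Exact target sizes instead of size_factors; only one may be given.
        sizes=[(1024, 1024), (183, 488)],
        dtype=dtype,
        max_tokens=128,
        num_logprobs=5,
        tensor_parallel_size=1,
    )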
@@ -85,13 +161,6 @@ def run_test(
     else:
         mantis_processor = None

-    images = [asset.pil_image for asset in image_assets]
-
-    inputs_per_image = [(
-        [prompt for _ in size_factors],
-        [rescale_image_size(image, factor) for factor in size_factors],
-    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
-
     # NOTE: take care of the order. run vLLM first, and then run HF.
     # vLLM needs a fresh new process without cuda initialization.
     # if we run HF first, the cuda initialization will be done and it
@@ -100,15 +169,18 @@
     # max_model_len should be greater than image_feature_size
     with vllm_runner(model,
                      dtype=dtype,
                      max_model_len=4096,
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,
-                     enforce_eager=True) as vllm_model:
+                     enforce_eager=True,
+                     limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
+                                          }) as vllm_model:
         vllm_outputs_per_image = [
             vllm_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,
                                                 num_logprobs=num_logprobs,
                                                 images=images)
-            for prompts, images in inputs_per_image
+            for prompts, images in inputs
         ]

         if mantis_processor is not None:
@@ -131,7 +203,7 @@ def process(hf_inputs: BatchEncoding):
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images)
-            for prompts, images in inputs_per_image
+            for prompts, images in inputs
         ]

     for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
@@ -181,6 +253,51 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
     )


+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
+                                      model, dtype, max_tokens,
+                                      num_logprobs) -> None:
+    stop_sign = image_assets[0].pil_image
+    cherry_blossom = image_assets[1].pil_image
+
+    inputs = [(
+        [
+            "USER: <image><image>\nDescribe 2 images.\nASSISTANT:",
+            "USER: <image><image>\nDescribe 2 images.\nASSISTANT:",
+            "USER: <image><image><image><image>\nDescribe 4 images.\nASSISTANT:",  # noqa: E501
+            "USER: <image>\nWhat is the season?\nASSISTANT:",
+        ],
+        [
+            [stop_sign, cherry_blossom],
+            # Images with different sizes and aspect-ratios
+            [
+                rescale_image_size(stop_sign, 0.1),
+                stop_sign,
+            ],
+            [
+                stop_sign,
+                rescale_image_size(stop_sign, 0.25),
+                cherry_blossom.resize((183, 488)),
+                cherry_blossom.resize((488, 183))
+            ],
+            cherry_blossom,
+        ])]
+
+    _run_test(
+        hf_runner,
+        vllm_runner,
+        inputs,
+        model,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
+
+
Comment: Copying the LLaVA-NeXT test layout for now. The logic will be consolidated in a future PR.

 @pytest.mark.parametrize("model", models)
 def test_context_length_too_short(vllm_runner, image_assets, model):
     images = [asset.pil_image for asset in image_assets]
Comment: This fixes an issue where the list isn't in alphabetical order (the MiniCPM-V row was moved to its alphabetical position in the supported models table above).