diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index ed847a7e3696b..32eed1a771718 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -338,7 +338,10 @@ steps:
   - tests/models/decoder_only/vision_language
   commands:
     - pytest -v -s models/decoder_only/audio_language
-    - pytest -v -s models/decoder_only/vision_language
+    # HACK - run phi3v tests separately to sidestep this transformers bug
+    # https://github.com/huggingface/transformers/issues/34307
+    - pytest -v -s models/decoder_only/vision_language/test_phi3v.py
+    - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language

- label: Other Models Test # 6min
  #mirror_hardwares: [amd]
@@ -413,7 +416,7 @@ steps:
     # Avoid importing model tests that cause CUDA reinitialization error
     - pytest models/encoder_decoder/language/test_bart.py -v -s -m distributed_2_gpus
     - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m distributed_2_gpus
-    - pytest models/decoder_only/vision_language/test_broadcast.py -v -s -m distributed_2_gpus
+    - pytest models/decoder_only/vision_language/test_models.py -v -s -m distributed_2_gpus
     - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
     - pip install -e ./plugins/vllm_add_dummy_model
     - pytest -v -s distributed/test_distributed_oot.py
diff --git a/tests/conftest.py b/tests/conftest.py
index 2fce2d772c6ed..bdc6ffb148602 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -259,8 +259,7 @@ def __init__(
         is_sentence_transformer: bool = False,
         skip_tokenizer_init: bool = False,
         auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM,
-        postprocess_inputs: Callable[[BatchEncoding],
-                                     BatchEncoding] = identity,
+        postprocess_inputs: Callable[..., BatchEncoding] = identity,
     ) -> None:
         torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

@@ -303,6 +302,7 @@ def __init__(
         if skip_tokenizer_init:
             self.tokenizer = self.processor.tokenizer

+        self.dtype = dtype
         self.postprocess_inputs = postprocess_inputs

     def get_inputs(
@@ -337,7 +337,7 @@ def get_inputs(
                 processor_kwargs["sampling_rate"] = sr

             inputs = self.processor(**processor_kwargs)
-            inputs = self.postprocess_inputs(inputs)
+            inputs = self.postprocess_inputs(inputs, dtype=self.dtype)

             all_inputs.append(inputs)
diff --git a/tests/engine/test_short_mm_context.py b/tests/engine/test_short_mm_context.py
new file mode 100644
index 0000000000000..a6ba7a131c506
--- /dev/null
+++ b/tests/engine/test_short_mm_context.py
@@ -0,0 +1,29 @@
+import pytest
+
+from ..conftest import IMAGE_ASSETS
+
+HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
+    "stop_sign":
+    "USER: <image>\nWhat's the content of the image?\nASSISTANT:",
+    "cherry_blossom":
+    "USER: <image>\nWhat is the season?\nASSISTANT:",
+})
+
+models = ["llava-hf/llava-1.5-7b-hf"]
+
+
+@pytest.mark.parametrize("model", models)
+def test_context_length_too_short(vllm_runner, image_assets, model):
+    images = [asset.pil_image for asset in image_assets]
+
+    with pytest.raises(ValueError, match="too long to fit into the model"):
+        vllm_model = vllm_runner(
+            model,
+            max_model_len=128,  # LLaVA has a feature size of 576
+            enforce_eager=True,
+        )
+
+        with vllm_model:
+            vllm_model.generate_greedy([HF_IMAGE_PROMPTS[0]],
+                                       max_tokens=1,
+                                       images=[images[0]])
diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py
index bfffd34d1142c..ad6c2d854d1f0 100644
--- a/tests/models/decoder_only/audio_language/test_ultravox.py
+++ b/tests/models/decoder_only/audio_language/test_ultravox.py
@@ -92,7 +92,7 @@ def run_test(
         for vllm_prompt, _, audio in prompts_and_audios
     ]

-    def process(hf_inputs: BatchEncoding):
+    def process(hf_inputs: BatchEncoding, **kwargs):
         hf_inputs["audio_values"] = hf_inputs["audio_values"] \
             .to(torch_dtype)  # type: ignore
         return hf_inputs
diff --git a/tests/models/decoder_only/language/test_qwen.py b/tests/models/decoder_only/language/test_qwen.py
new file mode 100644
index 0000000000000..128fe65afbb84
--- /dev/null
+++ b/tests/models/decoder_only/language/test_qwen.py
@@ -0,0 +1,34 @@
+"""Ensure that a text-only Qwen model can be run without throwing an error.
+We explicitly test this because Qwen is implemented as a multimodal model and
+supports a visual encoder for models like Qwen-VL.
+"""
+from typing import List, Type
+
+import pytest
+
+from ....conftest import VllmRunner
+
+models = [
+    "Qwen/Qwen-7B-Chat"  # Has no visual encoder
+]
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_text_only_qwen_model_can_be_loaded_and_run(
+    vllm_runner: Type[VllmRunner],
+    example_prompts: List[str],
+    model: str,
+    *,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+):
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        vllm_model.generate_greedy_logprobs(
+            example_prompts,
+            max_tokens,
+            num_logprobs=num_logprobs,
+        )
diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/__init__.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_llava_next.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_llava_next.py
new file mode 100644
index 0000000000000..c2d3fda6994f6
--- /dev/null
+++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_llava_next.py
@@ -0,0 +1,68 @@
+import pytest
+
+from vllm.inputs import InputContext
+
+from ....utils import build_model_context
+
+
+@pytest.fixture()
+def get_max_llava_next_image_tokens():
+    from vllm.model_executor.models.llava_next import (
+        get_max_llava_next_image_tokens)
+    return get_max_llava_next_image_tokens
+
+
+@pytest.fixture()
+def dummy_data_for_llava_next():
+    from vllm.model_executor.models.llava_next import dummy_data_for_llava_next
+    return dummy_data_for_llava_next
+
+
+@pytest.mark.parametrize("gridpoints,expected_max_tokens", [
+    ([[336, 336]], 1176),
+    ([[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]], 2928),
+])
+def test_get_max_llava_next_image_tokens(gridpoints, expected_max_tokens,
+                                         get_max_llava_next_image_tokens):
+    ctx = build_model_context(model_name="llava-hf/llava-v1.6-mistral-7b-hf")
+
+    # Update the config image_grid_pinpoints
+    # and calculate the resulting max tokens
+    ctx.model_config.hf_config.image_grid_pinpoints = gridpoints
+
+    actual_max_tokens = get_max_llava_next_image_tokens(
+        InputContext(ctx.model_config))
+
+    assert expected_max_tokens == actual_max_tokens
+
+
+@pytest.mark.parametrize(
+    "gridpoints,expected_size",
+    [
+        # One point; it has to be the largest
+        ([[336, 336]], (336, 336)),
+        # Default for most llava next models; the 2x2 tile is the largest
+        ([[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]],
+         (672, 672)),
+        # If two rectangular gridpoints are the same, the more vertical
+        # one has the higher feature count due to newline features
+        ([[336, 672], [672, 336]], (672, 336))
+    ])
+def test_dummy_data_for_llava_next_feature_size(dummy_data_for_llava_next,
+                                                gridpoints, expected_size):
+    ctx = build_model_context(model_name="llava-hf/llava-v1.6-mistral-7b-hf")
+
+    # Update the config image_grid_pinpoints
+    ctx.model_config.hf_config.image_grid_pinpoints = gridpoints
+    seq_len = 5000  # bigger than the max feature size for any image
+
+    seq_data, mm_data = dummy_data_for_llava_next(
+        ctx,
+        seq_len=seq_len,
+        mm_counts={"image": 1},
+    )
+
+    # The dummy data dims should match the gridpoint with the biggest feat size
+    assert mm_data["image"].height == expected_size[0]
+    assert mm_data["image"].width == expected_size[1]
+    assert len(seq_data.get_token_ids()) >= seq_len
diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py
new file mode 100644
index 0000000000000..d6a7b34fdde9f
--- /dev/null
+++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py
@@ -0,0 +1,181 @@
+"""Tests for phi3v's multimodal preprocessing kwargs."""
+from typing import Optional
+
+import pytest
+import torch
+from transformers import AutoImageProcessor, AutoTokenizer
+
+from vllm.inputs import InputContext, token_inputs
+from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
+from vllm.multimodal import MultiModalRegistry
+
+from .....conftest import _ImageAssets
+from ....utils import build_model_context
+
+models = ["microsoft/Phi-3.5-vision-instruct"]
+
+
+# Wrap lazy imports to avoid initializing CUDA during test collection
+@pytest.fixture()
+def input_processor_for_phi3v():
+    from vllm.model_executor.models.phi3v import input_processor_for_phi3v
+    return input_processor_for_phi3v
+
+
+@pytest.fixture()
+def dummy_data_for_phi3v():
+    from vllm.model_executor.models.phi3v import dummy_data_for_phi3v
+    return dummy_data_for_phi3v
+
+
+@pytest.fixture()
+def get_max_phi3v_image_tokens():
+    from vllm.model_executor.models.phi3v import get_max_phi3v_image_tokens
+    return get_max_phi3v_image_tokens
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("num_crops", [4, 16, None])
+def test_input_mapper_override(model: str, image_assets: _ImageAssets,
+                               num_crops: Optional[int]):
+    """Ensure that the [default] input mapper handles num_crops properly."""
+    # We pass the processor kwargs here since for this model, we fall back to
+    # the default mapper; this will fall back to the HF mapper and forward
+    # mm_processor_kwargs to it.
+    mm_processor_kwargs = {
+        "num_crops": num_crops
+    } if num_crops is not None else {}
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        trust_remote_code=True,
+        mm_processor_kwargs=mm_processor_kwargs,
+    )
+
+    hf_processor = AutoImageProcessor.from_pretrained(model,
+                                                      trust_remote_code=True,
+                                                      **mm_processor_kwargs)
+
+    mm_registry = MultiModalRegistry()
+    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
+
+    image = image_assets[0].pil_image
+    hf_result = hf_processor.preprocess(
+        image,
+        return_tensors="pt",
+    )
+
+    vllm_result = mm_registry.map_input(
+        ctx.model_config,
+        {"image": image},
+    )
+
+    assert torch.all(hf_result["image_sizes"] == vllm_result["image_sizes"])
+    assert torch.all(
+        hf_result["num_img_tokens"] == vllm_result["num_img_tokens"])
+
+    # For pixel values, the second axis should be the num_crops + 1
+    # for the rescaled original image. The default value in VLLM falls
+    # back to the HF config, which is why we compare to the processor num_crops
+    assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"])
+    assert vllm_result["pixel_values"].shape[1] == hf_processor.num_crops + 1
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("num_crops,expected_max_tokens", [
+    (4, 781),
+    (16, 2653),
+])
+def test_max_tokens_override(get_max_phi3v_image_tokens, model: str,
+                             num_crops: int, expected_max_tokens: int):
+    """Ensure get_max_phi3v_image_tokens handles num_crops properly."""
+    # NOTE: mm_processor_kwargs on the context in this test is unused, since
+    # this is testing the mapper directly. In practice, the processor kwargs
+    # are wrapped in a closure when calling the max tokens func. We explicitly
+    # do NOT use the mm_processor_kwargs in the model context here to ensure
+    # that the max image tokens implementation is referencing a mix of the
+    # kwargs to the function and the original mm_processor_kwargs in case
+    # values are somehow updated and end up in a bad state.
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        trust_remote_code=True,
+        mm_processor_kwargs=None,
+    )
+
+    actual_max_tokens = get_max_phi3v_image_tokens(
+        InputContext(ctx.model_config),
+        num_crops=num_crops,
+    )
+
+    assert expected_max_tokens == actual_max_tokens
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("num_crops,toks_per_img,num_imgs", [
+    (4, 781, 1),
+    (4, 781, 2),
+    (16, 2653, 1),
+    (16, 2653, 2),
+])
+def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int,
+                             toks_per_img: int, num_imgs: int):
+    """Ensure dummy_data_for_phi3v handles num_crops properly."""
+    # Same as the previous test - don't initialize mm_processor_kwargs
+    # in this test and assume that the kwargs will be correctly expanded by
+    # the partial when calling the dummy data func.
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        trust_remote_code=True,
+        mm_processor_kwargs=None,
+    )
+
+    sequence_data, _, = dummy_data_for_phi3v(
+        ctx=ctx,
+        seq_len=8192,  # Should be bigger than num_imgs * toks_per_img
+        mm_counts={"image": num_imgs},
+        num_crops=num_crops,
+    )
+    # Ensure we have the right number of placeholders per num_crops size
+    img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID)
+    assert img_tok_count == toks_per_img * num_imgs
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("num_crops,expected_toks_per_img,num_imgs", [
+    (4, 757, 1),
+    (4, 757, 2),
+    (16, 1921, 1),
+    (16, 1921, 2),
+])
+def test_input_processor_override(input_processor_for_phi3v,
+                                  image_assets: _ImageAssets, model: str,
+                                  num_crops: int, expected_toks_per_img: int,
+                                  num_imgs: int):
+    """Ensure input_processor_for_phi3v handles num_crops properly."""
+    # Same as the previous test - don't initialize mm_processor_kwargs
+    # in this test and assume that the kwargs will be correctly expanded by
+    # the partial when calling the custom input processor.
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        trust_remote_code=True,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    # Build the image str / prompt based on the number of images we pass
+    img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
+    prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
+    images = [image_assets[0].pil_image] * num_imgs
+
+    inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
+                          prompt=prompt,
+                          multi_modal_data={"image": images})
+
+    processed_inputs = input_processor_for_phi3v(ctx,
+                                                 inputs,
+                                                 num_crops=num_crops)
+
+    # Ensure we have the right number of placeholders per num_crops size
+    img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
+    assert img_tok_count == expected_toks_per_img * num_imgs
diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen.py
new file mode 100644
index 0000000000000..a01651b171d60
--- /dev/null
+++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen.py
@@ -0,0 +1,144 @@
+"""Tests for Qwen's multimodal preprocessing kwargs."""
+from typing import Dict, List, Union
+
+import pytest
+import torch
+from PIL.Image import Image
+
+from vllm.inputs import InputContext, token_inputs
+from vllm.multimodal.base import MultiModalInputs
+from vllm.multimodal.utils import cached_get_tokenizer
+
+from .....conftest import IMAGE_ASSETS
+from ....utils import build_model_context
+
+### Multimodal preprocessing tests
+SAMPLE_IMAGE = IMAGE_ASSETS[0].pil_image
+# These values are specific to Qwen-VL/Chat; we can get these from the model
+# config also, but they are hardcoded here to keep the parameterize/fixtures
+# easy to read.
+IMG_START_ID = 151857
+IMG_END_ID = 151858
+IMG_PAD_ID = 151859
+TOKS_PER_IMG = 256
+VIS_ENC_DIM = 4096
+IMG_SIZE = 448
+
+
+@pytest.fixture()
+def input_mapper_for_qwen():
+    # Lazy import to avoid initializing CUDA during test collection
+    from vllm.model_executor.models.qwen import input_mapper_for_qwen
+    return input_mapper_for_qwen
+
+
+@pytest.fixture()
+def input_processor_for_qwen():
+    # Lazy import to avoid initializing CUDA during test collection
+    from vllm.model_executor.models.qwen import input_processor_for_qwen
+    return input_processor_for_qwen
+
+
+@pytest.fixture()
+def qwen_vl_context() -> InputContext:
+    """Get an InputContext for Qwen-VL."""
+    return build_model_context(model_name="Qwen/Qwen-VL",
+                               trust_remote_code=True)
+
+
+# Happy path tests for single/multi-image scenarios for the multimodal
+# input processor and mapper, respectively
+@pytest.mark.parametrize("num_images", [1, 2])
+def test_input_processor_valid_mm_data(input_processor_for_qwen,
+                                       qwen_vl_context: InputContext,
+                                       num_images: int):
+    """Happy cases for image inputs to Qwen's multimodal input processor."""
+    prompt = "".join(
+        [f"Picture {num}: <img></img>\n" for num in range(1, num_images + 1)])
+    inputs = token_inputs(
+        prompt=prompt,
+        # When processing multimodal data for a multimodal model, the qwen
+        # input processor will overwrite the provided prompt_token_ids with
+        # the image prompts
+        prompt_token_ids=[],
+        multi_modal_data={"image": torch.rand(num_images, TOKS_PER_IMG, 4096)},
+    )
+    proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs)
+    assert isinstance(proc_inputs, dict)
+
+    # Each image should have one start / stop and a fixed context of 256
+    proc_tokens = proc_inputs["prompt_token_ids"]
+    assert proc_tokens.count(IMG_START_ID) == num_images
+    assert proc_tokens.count(IMG_END_ID) == num_images
+    assert proc_tokens.count(IMG_PAD_ID) == num_images * TOKS_PER_IMG
+
+
+@pytest.mark.parametrize(
+    "img_data,expected_shape",
+    [
+        # single / multi-image
+        (SAMPLE_IMAGE, (1, 3, IMG_SIZE, IMG_SIZE)),
+        (2 * [SAMPLE_IMAGE], (2, 3, IMG_SIZE, IMG_SIZE)),
+        # single / multi-image embeddings
+        (torch.rand(
+            (TOKS_PER_IMG, VIS_ENC_DIM)), (1, TOKS_PER_IMG, VIS_ENC_DIM)),
+        (torch.rand(
+            (1, TOKS_PER_IMG, VIS_ENC_DIM)), (1, TOKS_PER_IMG, VIS_ENC_DIM)),
+        (torch.rand(
+            (2, TOKS_PER_IMG, VIS_ENC_DIM)), (2, TOKS_PER_IMG, VIS_ENC_DIM)),
+    ])
+def test_input_mapper_valid_mm_data(input_mapper_for_qwen,
+                                    qwen_vl_context: InputContext,
+                                    img_data: Union[torch.Tensor, List[Image],
+                                                    Image],
+                                    expected_shape: List[int]):
+    """Happy cases for image inputs to Qwen's multimodal input mapper."""
+    mapped_img_data = input_mapper_for_qwen(qwen_vl_context, img_data)
+    # Ensure that we get the appropriately shaped pixel_values
+    # for images and image embeddings, respectively.
+    assert isinstance(mapped_img_data, MultiModalInputs)
+    assert "pixel_values" in mapped_img_data
+    assert mapped_img_data["pixel_values"].shape == expected_shape
+
+
+# Sad path tests for the multimodal input processor and mapper, respectively
+@pytest.mark.parametrize("mm_data", [
+    {
+        "image": torch.rand((5))
+    },
+    {
+        "image": torch.rand((5, 5, 5, 5, 5))
+    },
+])
+def test_input_processor_invalid_mm_data(input_processor_for_qwen,
+                                         qwen_vl_context: InputContext,
+                                         mm_data: Dict[str, torch.Tensor]):
+    """Test sad cases validated in Qwen's multimodal input processor."""
+    tokenizer = cached_get_tokenizer(qwen_vl_context.model_config.tokenizer,
+                                     trust_remote_code=True)
+    prompt = "Picture 1: <img></img>\n"
+    prompt_token_ids = tokenizer.encode(prompt)
+    inputs = token_inputs(prompt=prompt,
+                          prompt_token_ids=prompt_token_ids,
+                          multi_modal_data=mm_data)
+    # Should fail since we have too many or too few dimensions for embeddings
+    with pytest.raises(ValueError):
+        input_processor_for_qwen(qwen_vl_context, inputs)
+
+
+@pytest.mark.parametrize(
+    "img_data",
+    [
+        # Wrong context length
+        torch.rand((1, TOKS_PER_IMG + 10, VIS_ENC_DIM)),
+        # Wrong visual encoder output size
+        torch.rand((1, TOKS_PER_IMG, VIS_ENC_DIM + 10)),
+    ])
+def test_input_mapper_invalid_mm_data(
+    input_mapper_for_qwen,
+    qwen_vl_context: InputContext,
+    img_data: Union[torch.Tensor, List[Image], Image],
+):
+    """Sad cases validated in Qwen VL's multimodal input mapper."""
+    with pytest.raises(ValueError):
+        input_mapper_for_qwen(qwen_vl_context, img_data)
diff --git a/tests/models/decoder_only/vision_language/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py
similarity index 98%
rename from tests/models/decoder_only/vision_language/test_qwen2_vl.py
rename to tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py
index d3de5fb26d4b8..5c90e7f7a267c 100644
--- a/tests/models/decoder_only/vision_language/test_qwen2_vl.py
+++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py
@@ -8,8 +8,8 @@
 from vllm.inputs import InputContext, token_inputs
 from vllm.multimodal import MultiModalRegistry

-from ....conftest import _ImageAssets
-from ...utils import build_model_context
+from .....conftest import _ImageAssets
+from ....utils import build_model_context

 MODEL = "Qwen/Qwen2-VL-2B-Instruct"
 MIN_PIXELS = "min_pixels"
diff --git a/tests/models/decoder_only/vision_language/test_blip2.py b/tests/models/decoder_only/vision_language/test_blip2.py
deleted file mode 100644
index e1e32b96d89ac..0000000000000
--- a/tests/models/decoder_only/vision_language/test_blip2.py
+++ /dev/null
@@ -1,101 +0,0 @@
-from typing import List, Optional, Tuple
-
-import pytest
-from transformers import AutoModelForVision2Seq, AutoTokenizer
-
-from vllm.multimodal.utils import rescale_image_size
-from vllm.sequence import SampleLogprobs
-
-from ....conftest import IMAGE_ASSETS
-from ...utils import check_logprobs_close
-
-HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
-    "stop_sign":
-    "Question: What's the content of the image? Answer:",
-    "cherry_blossom":
-    "Question: What is the season?
Answer:", -}) - - -def vllm_to_hf_output(vllm_output: Tuple[List[int], str, - Optional[SampleLogprobs]], - model: str): - """Sanitize vllm output to be comparable with hf output.""" - _, output_str, out_logprobs = vllm_output - - hf_output_str = output_str + "\n" - - tokenizer = AutoTokenizer.from_pretrained(model) - hf_output_ids = tokenizer.encode(hf_output_str) - assert hf_output_ids[0] == tokenizer.bos_token_id - hf_output_ids = hf_output_ids[1:] - - return hf_output_ids, hf_output_str, out_logprobs - - -@pytest.mark.parametrize("model", ["Salesforce/blip2-opt-2.7b"]) -@pytest.mark.parametrize( - "size_factors", - [ - # No image - [], - # Single-scale - [1.0], - # Single-scale, batched - [1.0, 1.0, 1.0], - # Multi-scale - [0.25, 0.5, 1.0], - ], -) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, - dtype: str, max_tokens: int, num_logprobs: int) -> None: - """Inference result should be the same between hf and vllm. - - All the image fixtures for the test are from IMAGE_ASSETS. - For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalData objects and corresponding - MultiModalConfig as input. - Note, the text input is also adjusted to abide by vllm contract. - The text output is sanitized to be able to compare with hf. - """ - images = [asset.pil_image for asset in image_assets] - - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] - - # max_model_len should be greater than image_feature_size - with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: - vllm_outputs_per_image = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images) - for prompts, images in inputs_per_image - ] - - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForVision2Seq) as hf_model: - hf_outputs_per_image = [ - hf_model.generate_greedy_logprobs_limit(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images) - for prompts, images in inputs_per_image - ] - - for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, - vllm_outputs_per_image): - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, model) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - ) diff --git a/tests/models/decoder_only/vision_language/test_broadcast.py b/tests/models/decoder_only/vision_language/test_broadcast.py deleted file mode 100644 index 38c4a95de16f4..0000000000000 --- a/tests/models/decoder_only/vision_language/test_broadcast.py +++ /dev/null @@ -1,46 +0,0 @@ -import pytest -import transformers - -from ....utils import multi_gpu_test - - -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) -@pytest.mark.parametrize("model", [ - "llava-hf/llava-1.5-7b-hf", - "llava-hf/llava-v1.6-mistral-7b-hf", - "facebook/chameleon-7b", -]) -def test_models(hf_runner, vllm_runner, image_assets, - distributed_executor_backend, model) -> None: - - dtype = "half" - max_tokens = 5 - num_logprobs = 5 - tensor_parallel_size = 2 - - if model.startswith("llava-hf/llava-1.5"): - from .test_llava import models, run_test - elif model.startswith("llava-hf/llava-v1.6"): - from .test_llava_next import models, run_test # type: 
ignore[no-redef] - elif model.startswith("facebook/chameleon"): - if transformers.__version__.startswith("4.46"): - pytest.skip("Model broken in HF, " - "see huggingface/transformers#34379") - from .test_chameleon import models, run_test # type: ignore[no-redef] - else: - raise NotImplementedError(f"Unsupported model: {model}") - - run_test( - hf_runner, - vllm_runner, - image_assets, - model=models[0], - # So that LLaVA-NeXT processor may return nested list - size_factors=[0.25, 0.5, 1.0], - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - ) diff --git a/tests/models/decoder_only/vision_language/test_chameleon.py b/tests/models/decoder_only/vision_language/test_chameleon.py deleted file mode 100644 index 4bd678b9f21c4..0000000000000 --- a/tests/models/decoder_only/vision_language/test_chameleon.py +++ /dev/null @@ -1,130 +0,0 @@ -from typing import List, Optional, Type - -import pytest -import transformers -from transformers import AutoModelForVision2Seq, BatchEncoding - -from vllm.multimodal.utils import rescale_image_size -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE - -from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets -from ...utils import check_outputs_equal - -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "USER: \nWhat's the content of the image?\nASSISTANT:", - "cherry_blossom": - "USER: \nWhat is the season?\nASSISTANT:", -}) - -models = ["facebook/chameleon-7b"] - - -def run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - image_assets: _ImageAssets, - model: str, - *, - size_factors: List[float], - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - """Inference result should be the same between hf and vllm. - - All the image fixtures for the test are from IMAGE_ASSETS. - For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects - and corresponding vision language config as input. - Note, the text input is also adjusted to abide by vllm contract. - The text output is sanitized to be able to compare with hf. 
- """ - torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] - images = [asset.pil_image for asset in image_assets] - - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] - - with vllm_runner(model, - max_model_len=4096, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: - - vllm_outputs_per_image = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images) - for prompts, images in inputs_per_image - ] - - def process(hf_inputs: BatchEncoding): - hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \ - .to(torch_dtype) # type: ignore - return hf_inputs - - with hf_runner(model, - dtype=dtype, - postprocess_inputs=process, - auto_cls=AutoModelForVision2Seq) as hf_model: - hf_outputs_per_image = [ - hf_model.generate_greedy_logprobs_limit(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images) - for prompts, images in inputs_per_image - ] - - for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, - vllm_outputs_per_image): - # HF Logprobs include image tokens, unlike vLLM, so we don't directly - # compare them - check_outputs_equal( - outputs_0_lst=[outputs[:2] for outputs in hf_outputs], - outputs_1_lst=[outputs[:2] for outputs in vllm_outputs], - name_0="hf", - name_1="vllm", - ) - - -@pytest.mark.skipif( - transformers.__version__.startswith("4.46.0"), - reason="Model broken in HF, see huggingface/transformers#34379", -) -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize( - "size_factors", - [ - # No image - [], - # Single-scale - [1.0], - # Single-scale, batched - [1.0, 1.0, 1.0], - # Multi-scale - [0.25, 0.5, 1.0], - ], -) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [8]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, - dtype, max_tokens, num_logprobs) -> None: - run_test( - hf_runner, - vllm_runner, - image_assets, - model, - size_factors=size_factors, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - ) diff --git a/tests/models/decoder_only/vision_language/test_fuyu.py b/tests/models/decoder_only/vision_language/test_fuyu.py deleted file mode 100644 index 1affcd10ee72d..0000000000000 --- a/tests/models/decoder_only/vision_language/test_fuyu.py +++ /dev/null @@ -1,139 +0,0 @@ -from typing import List, Optional, Tuple, Type - -import pytest - -from vllm.multimodal.utils import rescale_image_size -from vllm.platforms import current_platform -from vllm.sequence import SampleLogprobs - -from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets -from ...utils import check_logprobs_close - -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "What's the content of the image?\n", - "cherry_blossom": - "What is the season?\n", -}) - -models = ["adept/fuyu-8b"] - - -def vllm_to_hf_output(vllm_output: Tuple[List[int], str, - Optional[SampleLogprobs]]): - """Sanitize vllm output to be comparable with hf output.""" - output_ids, output_str, out_logprobs = vllm_output - - hf_output_str = output_str.lstrip() + "|ENDOFTEXT|" - - return output_ids, hf_output_str, out_logprobs - - -def run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - image_assets: _ImageAssets, - model: str, - *, - 
size_factors: List[float], - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - """Inference result should be the same between hf and vllm. - - All the image fixtures for the test are from IMAGE_ASSETS. - For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects - and corresponding MultiModalConfig as input. - Note, the text input is also adjusted to abide by vllm contract. - The text output is sanitized to be able to compare with hf. - """ - images = [asset.pil_image for asset in image_assets] - - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default method). - - # max_model_len should be greater than image_feature_size - with vllm_runner(model, - max_model_len=2048, - max_num_seqs=2, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: - vllm_outputs_per_image = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images) - for prompts, images in inputs_per_image - ] - - with hf_runner(model, dtype=dtype) as hf_model: - eos_token_id = hf_model.processor.tokenizer.eos_token_id - hf_outputs_per_image = [ - hf_model.generate_greedy_logprobs_limit(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - eos_token_id=eos_token_id) - for prompts, images in inputs_per_image - ] - - for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, - vllm_outputs_per_image): - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output) for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - ) - - -target_dtype = "half" -if current_platform.is_cpu(): - target_dtype = "bfloat16" - - -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize( - "size_factors", - [ - # No image - [], - # Single-scale - [0.25], - # Single-scale, batched - [0.25, 0.25, 0.25], - # Multi-scale - [0.25, 0.2, 0.15], - ], -) -@pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [10]) -def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, - dtype: str, max_tokens: int, num_logprobs: int) -> None: - run_test( - hf_runner, - vllm_runner, - image_assets, - model, - size_factors=size_factors, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - ) diff --git a/tests/models/decoder_only/vision_language/test_glm4.py b/tests/models/decoder_only/vision_language/test_glm4.py deleted file mode 100644 index 47922a57f680b..0000000000000 --- a/tests/models/decoder_only/vision_language/test_glm4.py +++ /dev/null @@ -1,133 +0,0 @@ -from typing import List, Optional, Tuple, Type - -import pytest - -from vllm.multimodal.utils import rescale_image_size -from vllm.transformers_utils.tokenizer import patch_padding_side - -from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner -from ....utils import large_gpu_test -from 
...utils import check_logprobs_close - -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "What's the content of the image?", - "cherry_blossom": - "What is the season?", -}) - -models = ["THUDM/glm-4v-9b"] -target_dtype = "bfloat16" - - -def run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - inputs: List[Tuple[List[str], PromptImageInput]], - model: str, - *, - dtype: str, - max_tokens: int, - num_logprobs: int, - mm_limit: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - # max_model_len should be greater than image_feature_size - with vllm_runner(model, - max_model_len=2048, - max_num_seqs=2, - dtype=dtype, - limit_mm_per_prompt={"image": mm_limit}, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: - stop_token_ids = [151329, 151336, 151338] - vllm_outputs_per_image = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - stop_token_ids=stop_token_ids) - for prompts, images in inputs - ] - - with hf_runner(model, dtype=dtype) as hf_model: - hf_processor = hf_model.processor - patch_padding_side(hf_processor) - - def processor(*args, text="", images=None, **kwargs): - if images is None: - return hf_processor(*args, **kwargs) - - return hf_processor.apply_chat_template( - [{ - "role": "user", - "image": images, - "content": text - }], - add_generation_prompt=True, - tokenize=True, - return_dict=True, - **kwargs, - ) - - hf_model.processor = processor - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.transformer.output_layer - hf_outputs_per_image = [ - hf_model.generate_greedy_logprobs_limit( - prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - ) for prompts, images in inputs - ] - - for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, - vllm_outputs_per_image): - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) - - -@large_gpu_test(min_gb=48) -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize( - "size_factors", - [ - # No image - [], - # Single-scale - [1.0], - # Single-scale, batched - [1.0, 1.0, 1.0], - # Multi-scale - [0.25, 0.5, 1.0], - ], -) -@pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, - dtype: str, max_tokens: int, num_logprobs: int) -> None: - images = [asset.pil_image for asset in image_assets] - - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] - run_test( - hf_runner, - vllm_runner, - inputs_per_image, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - mm_limit=1, - tensor_parallel_size=1, - ) diff --git a/tests/models/decoder_only/vision_language/test_internvl.py b/tests/models/decoder_only/vision_language/test_internvl.py index fc842ec4a6171..2fd1ac4bb08f7 100644 --- a/tests/models/decoder_only/vision_language/test_internvl.py +++ b/tests/models/decoder_only/vision_language/test_internvl.py @@ -1,15 +1,11 @@ -import types -from typing import List, Optional, Tuple, Type, Union +from typing import List, Optional, Tuple, Type import pytest import torch -from PIL.Image import Image -from transformers 
import AutoConfig from vllm.multimodal.utils import rescale_image_size -from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, - _ImageAssets) +from ....conftest import IMAGE_ASSETS, VllmRunner, _ImageAssets from ...utils import check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ @@ -18,171 +14,6 @@ "cherry_blossom": "<|im_start|>User\n\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 }) -HF_MULTIIMAGE_IMAGE_PROMPT = "<|im_start|>User\nImage-1: \nImage-2: \nDescribe the two images in short.<|im_end|>\n<|im_start|>Assistant\n" # noqa: E501 - -models = [ - "OpenGVLab/InternVL2-1B", - "OpenGVLab/InternVL2-2B", - # NOTE: Mono-InternVL-2B doesn't work with fp16, - # it will result NaN during inference. - # See: https://huggingface.co/OpenGVLab/Mono-InternVL-2B/discussions/9 - "OpenGVLab/Mono-InternVL-2B", - # Broken due to outdated implementation of Phi-3 - # See: https://huggingface.co/OpenGVLab/InternVL2-4B/discussions/3 - # "OpenGVLab/InternVL2-4B", -] -target_dtype = "bfloat16" - - -# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py -def generate( - self, - pixel_values: torch.FloatTensor, - input_ids: torch.FloatTensor, - attention_mask: Optional[torch.LongTensor] = None, - **generate_kwargs, -) -> torch.LongTensor: - """Generate method for InternVL2 model without fixed use_cache.""" - assert self.img_context_token_id is not None - vit_embeds = self.extract_feature(pixel_values) - input_embeds = self.language_model.get_input_embeddings()(input_ids) - B, N, C = input_embeds.shape - input_embeds = input_embeds.reshape(B * N, C) - - input_ids = input_ids.reshape(B * N) - selected = (input_ids == self.img_context_token_id) - assert selected.sum() != 0 - input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device) - - input_embeds = input_embeds.reshape(B, N, C) - - forward_kwargs = dict( - inputs_embeds=input_embeds, - attention_mask=attention_mask, - ) - if getattr(self, "use_visual_token_mask", False): - visual_token_mask = selected.reshape(B, N, 1).to(input_embeds.dtype) - forward_kwargs["visual_token_mask"] = visual_token_mask - outputs = self.language_model.generate( - **forward_kwargs, - **generate_kwargs, - ) - - return outputs - - -def run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - inputs: List[Tuple[List[str], PromptImageInput]], - model: str, - *, - dtype: str, - max_tokens: int, - num_logprobs: int, - mm_limit: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - """Inference result should be the same between hf and vllm. - - All the image fixtures for the test are from IMAGE_ASSETS. - For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects - and corresponding MultiModalConfig as input. - Note, the text input is also adjusted to abide by vllm contract. - The text output is sanitized to be able to compare with hf. - """ - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default method). 
- - class InternVLProcessor: - """A simple processor for InternVL2 which misses a processor.""" - - def __init__(self, hf_runner: HfRunner): - self.num_image_token = hf_runner.model.num_image_token - self.tokenizer = hf_runner.tokenizer - self.dtype = hf_runner.model.dtype - - self.config = AutoConfig.from_pretrained(hf_runner.model_name, - trust_remote_code=True) - self.vision_config = self.config.vision_config - self.use_thumbnail = self.config.use_thumbnail - self.min_num = self.config.min_dynamic_patch - self.max_num = self.config.max_dynamic_patch - self.image_size = self.vision_config.image_size - - def __call__(self, text: str, images: Union[Image, List[Image]], - **kwargs): - from vllm.model_executor.models.internvl import ( - IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values) - images = [images] if isinstance(images, Image) else images - pixel_values = [ - image_to_pixel_values(image, self.image_size, self.min_num, - self.max_num, - self.use_thumbnail).to(self.dtype) - for image in images - ] - num_patches_list = [ - pixel_value.shape[0] for pixel_value in pixel_values - ] - pixel_values = torch.cat(pixel_values, dim=0) - for num_patches in num_patches_list: - context_tokens = IMG_CONTEXT * self.num_image_token \ - * num_patches - image_tokens = IMG_START + context_tokens + IMG_END - text = text.replace('', image_tokens, 1) - prompt = self.tokenizer(text, return_tensors="pt") - prompt.update({"pixel_values": pixel_values}) - return prompt - - # max_model_len should be greater than image_feature_size - with vllm_runner(model, - max_model_len=4096, - dtype=dtype, - limit_mm_per_prompt={"image": mm_limit}, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: - vllm_outputs_per_image = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images) - for prompts, images in inputs - ] - - with hf_runner(model, dtype=dtype) as hf_model: - img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids( - "") - hf_model.model.img_context_token_id = img_context_token_id - hf_model.processor = InternVLProcessor(hf_model) - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.language_model.get_output_embeddings() - hf_model.model.generate = types.MethodType(generate, hf_model.model) - eos_token_id = hf_model.tokenizer.eos_token_id - hf_outputs_per_image = [ - hf_model.generate_greedy_logprobs_limit(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=hf_images, - eos_token_id=eos_token_id) - for prompts, hf_images in inputs - ] - - for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, - vllm_outputs_per_image): - # TODO: Check whether using original CLIPVisionModel can improve - # consistency against HF - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) def run_awq_test( @@ -253,123 +84,6 @@ def run_awq_test( ) -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize( - "size_factors", - [ - # No image - [], - # Single-scale - [1.0], - # Single-scale, batched - [1.0, 1.0, 1.0], - # Multi-scale - [0.25, 0.5, 1.0], - ], -) -@pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -@torch.inference_mode() -def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, - dtype: str, max_tokens: int, num_logprobs: int) -> None: - images = [asset.pil_image 
for asset in image_assets] - - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] - - run_test( - hf_runner, - vllm_runner, - inputs_per_image, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - mm_limit=1, - tensor_parallel_size=1, - ) - - -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize( - "size_factors", - [ - # No image - [], - # Single-scale - [1.0], - # Single-scale, batched - [1.0, 1.0, 1.0], - # Multi-scale - [0.5, 0.75, 1.0], - ], -) -@pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -@torch.inference_mode() -def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, - size_factors, dtype: str, max_tokens: int, - num_logprobs: int) -> None: - images = [asset.pil_image for asset in image_assets] - - inputs_per_case = [ - ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], - [[rescale_image_size(image, factor) for image in images] - for factor in size_factors]) - ] - - run_test( - hf_runner, - vllm_runner, - inputs_per_case, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - mm_limit=2, - tensor_parallel_size=1, - ) - - -@pytest.mark.parametrize("model", ["OpenGVLab/InternVL2-2B"]) -@pytest.mark.parametrize("size_factors", [[0.5, 1.0]]) -@pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -@torch.inference_mode() -def test_different_num_patches(hf_runner, vllm_runner, image_assets, model, - size_factors, dtype: str, max_tokens: int, - num_logprobs: int) -> None: - images = [asset.pil_image.resize((896, 896)) for asset in image_assets] - - inputs_batching = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] - - inputs_multi_images = [ - ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], - [[rescale_image_size(image, factor) for image in images] - for factor in size_factors]) - ] - for inputs in [inputs_batching, inputs_multi_images]: - run_test( - hf_runner, - vllm_runner, - inputs, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - mm_limit=2, - tensor_parallel_size=1, - ) - - @pytest.mark.parametrize( "models", [("OpenGVLab/InternVL2-2B", "OpenGVLab/InternVL2-2B-AWQ")]) @pytest.mark.parametrize( diff --git a/tests/models/decoder_only/vision_language/test_llava.py b/tests/models/decoder_only/vision_language/test_llava.py deleted file mode 100644 index fd28a9367b4b2..0000000000000 --- a/tests/models/decoder_only/vision_language/test_llava.py +++ /dev/null @@ -1,313 +0,0 @@ -from typing import List, Optional, Tuple, Type, overload - -import pytest -from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer, - BatchEncoding) - -from vllm.multimodal.utils import rescale_image_size -from vllm.sequence import SampleLogprobs -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE - -from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, - _ImageAssets) -from ...utils import check_logprobs_close - -_LIMIT_IMAGE_PER_PROMPT = 4 - -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "USER: \nWhat's the content of the image?\nASSISTANT:", - "cherry_blossom": - "USER: \nWhat is the season?\nASSISTANT:", -}) - -models = 
[ - "llava-hf/llava-1.5-7b-hf", - # TODO: Get this model to produce meaningful output in vLLM - # "TIGER-Lab/Mantis-8B-siglip-llama3", -] - - -def vllm_to_hf_output(vllm_output: Tuple[List[int], str, - Optional[SampleLogprobs]], - model: str): - """Sanitize vllm output to be comparable with hf output.""" - output_ids, output_str, out_logprobs = vllm_output - - config = AutoConfig.from_pretrained(model) - image_token_id = config.image_token_index - - tokenizer = AutoTokenizer.from_pretrained(model) - eos_token_id = tokenizer.eos_token_id - - hf_output_ids = [ - token_id for idx, token_id in enumerate(output_ids) - if token_id != image_token_id or output_ids[idx - 1] != image_token_id - ] - - assert output_str[0] == " " - hf_output_str = output_str[1:] - if hf_output_ids[-1] == eos_token_id: - hf_output_str = hf_output_str + tokenizer.decode(eos_token_id) - - return hf_output_ids, hf_output_str, out_logprobs - - -@overload -def run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - image_assets: _ImageAssets, - model: str, - *, - size_factors: List[float], - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - ... - - -@overload -def run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - image_assets: _ImageAssets, - model: str, - *, - sizes: List[Tuple[int, int]], - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - ... - - -def run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - image_assets: _ImageAssets, - model: str, - *, - size_factors: Optional[List[float]] = None, - sizes: Optional[List[Tuple[int, int]]] = None, - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - images = [asset.pil_image for asset in image_assets] - - if size_factors is not None: - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] - elif sizes is not None: - inputs_per_image = [( - [prompt for _ in sizes], - [image.resize(size) for size in sizes], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] - else: - raise ValueError("You must provide either `size_factors` or `sizes`") - - _run_test(hf_runner, - vllm_runner, - inputs_per_image, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend) - - -def _run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - inputs: List[Tuple[List[str], PromptImageInput]], - model: str, - *, - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - """Inference result should be the same between hf and vllm. - - All the image fixtures for the test are from IMAGE_ASSETS. - For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects - and corresponding MultiModalConfig as input. - Note, the text input is also adjusted to abide by vllm contract. - The text output is sanitized to be able to compare with hf. 
- """ - # NOTE: For local use; this isn't tested in CI yet (see TODO above) - if model.startswith("TIGER-Lab/Mantis"): - from mantis.models.mllava import MLlavaProcessor - - torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] - mantis_processor = MLlavaProcessor.from_pretrained( - model, torch_dtype=torch_dtype) - assert isinstance(mantis_processor, MLlavaProcessor) - else: - mantis_processor = None - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default method). - - # max_model_len should be greater than image_feature_size - with vllm_runner(model, - dtype=dtype, - max_model_len=4096, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True, - limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT - }) as vllm_model: - vllm_outputs_per_image = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images) - for prompts, images in inputs - ] - - if mantis_processor is not None: - - def process(hf_inputs: BatchEncoding): - hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \ - .to(torch_dtype) # type: ignore - return hf_inputs - else: - - def process(hf_inputs: BatchEncoding): - return hf_inputs - - with hf_runner(model, - dtype=dtype, - postprocess_inputs=process, - auto_cls=AutoModelForVision2Seq) as hf_model: - hf_outputs_per_image = [ - hf_model.generate_greedy_logprobs_limit(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images) - for prompts, images in inputs - ] - - for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, - vllm_outputs_per_image): - # TODO: Check whether using original CLIPVisionModel can improve - # consistency against HF - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, model) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - ) - - -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize( - "size_factors", - [ - # No image - [], - # Single-scale - [1.0], - # Single-scale, batched - [1.0, 1.0, 1.0], - # Multi-scale - [0.25, 0.5, 1.0], - ], -) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, - dtype, max_tokens, num_logprobs) -> None: - run_test( - hf_runner, - vllm_runner, - image_assets, - model, - size_factors=size_factors, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - ) - - -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets, - model, dtype, max_tokens, - num_logprobs) -> None: - stop_sign = image_assets[0].pil_image - cherry_blossom = image_assets[1].pil_image - - inputs = [( - [ - "USER: \nDescribe 2 images.\nASSISTANT:", - "USER: \nDescribe 2 images.\nASSISTANT:", - "USER: \nDescribe 4 images.\nASSISTANT:", # noqa: E501 - "USER: \nWhat is the season?\nASSISTANT:", - ], - [ - [stop_sign, cherry_blossom], - # Images with different sizes and aspect-ratios - [ - rescale_image_size(stop_sign, 0.1), - stop_sign, - ], - [ - 
stop_sign, - rescale_image_size(stop_sign, 0.25), - cherry_blossom.resize((183, 488)), - cherry_blossom.resize((488, 183)) - ], - cherry_blossom, - ])] - - _run_test( - hf_runner, - vllm_runner, - inputs, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=1, - ) - - -@pytest.mark.parametrize("model", models) -def test_context_length_too_short(vllm_runner, image_assets, model): - images = [asset.pil_image for asset in image_assets] - - with pytest.raises(ValueError, match="too long to fit into the model"): - vllm_model = vllm_runner( - model, - max_model_len=128, # LLaVA has a feature size of 576 - enforce_eager=True, - ) - - with vllm_model: - vllm_model.generate_greedy([HF_IMAGE_PROMPTS[0]], - max_tokens=1, - images=[images[0]]) diff --git a/tests/models/decoder_only/vision_language/test_llava_image_embeds.py b/tests/models/decoder_only/vision_language/test_llava_image_embeds.py deleted file mode 100644 index 66414032509ed..0000000000000 --- a/tests/models/decoder_only/vision_language/test_llava_image_embeds.py +++ /dev/null @@ -1,158 +0,0 @@ -from typing import List, Optional, Tuple, Type - -import pytest -from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer - -from vllm.sequence import SampleLogprobs - -from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets -from ...utils import check_logprobs_close - -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "USER: \nWhat's the content of the image?\nASSISTANT:", - "cherry_blossom": - "USER: \nWhat is the season?\nASSISTANT:", -}) - -models = [ - "llava-hf/llava-1.5-7b-hf", -] - - -def vllm_to_hf_output(vllm_output: Tuple[List[int], str, - Optional[SampleLogprobs]], - model: str): - """Sanitize vllm output to be comparable with hf output.""" - output_ids, output_str, out_logprobs = vllm_output - - config = AutoConfig.from_pretrained(model) - image_token_id = config.image_token_index - - tokenizer = AutoTokenizer.from_pretrained(model) - eos_token_id = tokenizer.eos_token_id - - hf_output_ids = [ - token_id for idx, token_id in enumerate(output_ids) - if token_id != image_token_id or output_ids[idx - 1] != image_token_id - ] - - assert output_str[0] == " " - hf_output_str = output_str[1:] - if hf_output_ids[-1] == eos_token_id: - hf_output_str = hf_output_str + tokenizer.decode(eos_token_id) - - return hf_output_ids, hf_output_str, out_logprobs - - -def run_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - image_assets: _ImageAssets, - model: str, - *, - size_factors: List[float], - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - """Inference result should be the same between hf and vllm. - - All the image fixtures for the test are from IMAGE_ASSETS. - For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects - and corresponding vision language config as input. - Note, the text input is also adjusted to abide by vllm contract. - The text output is sanitized to be able to compare with hf. 
-    """
-
-    # vLLM to load from image embeddings
-    vllm_images = [asset.image_embeds for asset in image_assets]
-
-    # transformers to load from PIL images
-    hf_images = [asset.pil_image for asset in image_assets]
-
-    vllm_inputs_per_image = [(
-        [prompt for _ in size_factors],
-        [image for _ in size_factors],
-    ) for image, prompt in zip(vllm_images, HF_IMAGE_PROMPTS)]
-
-    hf_inputs_per_image = [(
-        [prompt for _ in size_factors],
-        [image for _ in size_factors],
-    ) for image, prompt in zip(hf_images, HF_IMAGE_PROMPTS)]
-
-    # NOTE: take care of the order. run vLLM first, and then run HF.
-    # vLLM needs a fresh new process without cuda initialization.
-    # if we run HF first, the cuda initialization will be done and it
-    # will hurt multiprocessing backend with fork method (the default method).
-
-    # max_model_len should be greater than image_feature_size
-    with vllm_runner(model,
-                     dtype=dtype,
-                     tensor_parallel_size=tensor_parallel_size,
-                     distributed_executor_backend=distributed_executor_backend,
-                     enforce_eager=True) as vllm_model:
-        vllm_outputs_per_image = [
-            vllm_model.generate_greedy_logprobs(prompts,
-                                                max_tokens,
-                                                num_logprobs=num_logprobs,
-                                                images=images)
-            for prompts, images in vllm_inputs_per_image
-        ]
-
-    with hf_runner(model, dtype=dtype,
-                   auto_cls=AutoModelForVision2Seq) as hf_model:
-        hf_outputs_per_image = [
-            hf_model.generate_greedy_logprobs_limit(prompts,
-                                                    max_tokens,
-                                                    num_logprobs=num_logprobs,
-                                                    images=images)
-            for prompts, images in hf_inputs_per_image
-        ]
-
-    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
-                                        vllm_outputs_per_image):
-        # TODO: Check whether using original CLIPVisionModel can improve
-        # consistency against HF
-        check_logprobs_close(
-            outputs_0_lst=hf_outputs,
-            outputs_1_lst=[
-                vllm_to_hf_output(vllm_output, model)
-                for vllm_output in vllm_outputs
-            ],
-            name_0="hf",
-            name_1="vllm",
-        )
-
-
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize(
-    "size_factors",
-    [
-        # No image
-        [],
-        # Single-scale
-        [1.0],
-        # Single-scale, batched
-        [1.0, 1.0, 1.0],
-    ],
-)
-@pytest.mark.parametrize("dtype", ["half"])
-@pytest.mark.parametrize("max_tokens", [128])
-@pytest.mark.parametrize("num_logprobs", [5])
-def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
-                dtype: str, max_tokens: int, num_logprobs: int) -> None:
-    run_test(
-        hf_runner,
-        vllm_runner,
-        image_assets,
-        model,
-        size_factors=size_factors,
-        dtype=dtype,
-        max_tokens=max_tokens,
-        num_logprobs=num_logprobs,
-        tensor_parallel_size=1,
-    )
diff --git a/tests/models/decoder_only/vision_language/test_llava_next.py b/tests/models/decoder_only/vision_language/test_llava_next.py
deleted file mode 100644
index aa9b297c5dd4e..0000000000000
--- a/tests/models/decoder_only/vision_language/test_llava_next.py
+++ /dev/null
@@ -1,347 +0,0 @@
-from typing import List, Optional, Tuple, Type, overload
-
-import pytest
-from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
-
-from vllm.inputs import InputContext
-from vllm.multimodal.utils import rescale_image_size
-from vllm.sequence import SampleLogprobs
-
-from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
-                          _ImageAssets)
-from ...utils import build_model_context, check_logprobs_close
-
-_LIMIT_IMAGE_PER_PROMPT = 4
-
-HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
-    "stop_sign":
-    "[INST] <image>\nWhat's the content of the image? [/INST]",
-    "cherry_blossom":
-    "[INST] <image>\nWhat is the season? [/INST]",
-})
-
-models = ["llava-hf/llava-v1.6-mistral-7b-hf"]
-
-
-@pytest.fixture()
-def get_max_llava_next_image_tokens():
-    from vllm.model_executor.models.llava_next import (
-        get_max_llava_next_image_tokens)
-    return get_max_llava_next_image_tokens
-
-
-@pytest.fixture()
-def dummy_data_for_llava_next():
-    from vllm.model_executor.models.llava_next import dummy_data_for_llava_next
-    return dummy_data_for_llava_next
-
-
-def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
-                                         Optional[SampleLogprobs]],
-                      model: str):
-    """Sanitize vllm output to be comparable with hf output."""
-    output_ids, output_str, out_logprobs = vllm_output
-
-    config = AutoConfig.from_pretrained(model)
-    image_token_id = config.image_token_index
-
-    tokenizer = AutoTokenizer.from_pretrained(model)
-    eos_token_id = tokenizer.eos_token_id
-
-    hf_output_ids = [
-        token_id for idx, token_id in enumerate(output_ids)
-        if token_id != image_token_id or output_ids[idx - 1] != image_token_id
-    ]
-
-    assert output_str[0] == " "
-    hf_output_str = output_str[1:]
-    if hf_output_ids[-1] == eos_token_id:
-        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
-
-    return hf_output_ids, hf_output_str, out_logprobs
-
-
-@overload
-def run_test(
-    hf_runner: Type[HfRunner],
-    vllm_runner: Type[VllmRunner],
-    image_assets: _ImageAssets,
-    model: str,
-    *,
-    size_factors: List[float],
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-    tensor_parallel_size: int,
-    distributed_executor_backend: Optional[str] = None,
-):
-    ...
-
-
-@overload
-def run_test(
-    hf_runner: Type[HfRunner],
-    vllm_runner: Type[VllmRunner],
-    image_assets: _ImageAssets,
-    model: str,
-    *,
-    sizes: List[Tuple[int, int]],
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-    tensor_parallel_size: int,
-    distributed_executor_backend: Optional[str] = None,
-):
-    ...
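The vllm_to_hf_output helper in the deleted file above collapses runs of consecutive image placeholder ids so that vLLM's expanded prompt portion can be compared token by token with the HF output. A minimal, self-contained sketch of that collapsing step follows (not part of this patch); the token ids here are made up for illustration, whereas the real tests read the placeholder id from config.image_token_index.

# Hypothetical ids for illustration only; the deleted tests above obtain
# image_token_id from AutoConfig (config.image_token_index).
image_token_id = 32000
output_ids = [1, 32000, 32000, 32000, 450, 4259, 13]

# Keep a token unless it is an image token immediately preceded by another
# image token, i.e. reduce each run of image tokens to a single entry.
hf_output_ids = [
    token_id for idx, token_id in enumerate(output_ids)
    if token_id != image_token_id or output_ids[idx - 1] != image_token_id
]

assert hf_output_ids == [1, 32000, 450, 4259, 13]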
-
-
-def run_test(
-    hf_runner: Type[HfRunner],
-    vllm_runner: Type[VllmRunner],
-    image_assets: _ImageAssets,
-    model: str,
-    *,
-    size_factors: Optional[List[float]] = None,
-    sizes: Optional[List[Tuple[int, int]]] = None,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-    tensor_parallel_size: int,
-    distributed_executor_backend: Optional[str] = None,
-):
-    images = [asset.pil_image for asset in image_assets]
-
-    if size_factors is not None:
-        inputs_per_image = [(
-            [prompt for _ in size_factors],
-            [rescale_image_size(image, factor) for factor in size_factors],
-        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
-    elif sizes is not None:
-        inputs_per_image = [(
-            [prompt for _ in sizes],
-            [image.resize(size) for size in sizes],
-        ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
-    else:
-        raise ValueError("You must provide either `size_factors` or `sizes`")
-
-    _run_test(hf_runner,
-              vllm_runner,
-              inputs_per_image,
-              model,
-              dtype=dtype,
-              max_tokens=max_tokens,
-              num_logprobs=num_logprobs,
-              tensor_parallel_size=tensor_parallel_size,
-              distributed_executor_backend=distributed_executor_backend)
-
-
-def _run_test(
-    hf_runner: Type[HfRunner],
-    vllm_runner: Type[VllmRunner],
-    inputs: List[Tuple[List[str], PromptImageInput]],
-    model: str,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-    tensor_parallel_size: int,
-    distributed_executor_backend: Optional[str] = None,
-):
-    # max_model_len should be greater than image_feature_size
-    with vllm_runner(model,
-                     dtype=dtype,
-                     max_model_len=10240,
-                     tensor_parallel_size=tensor_parallel_size,
-                     distributed_executor_backend=distributed_executor_backend,
-                     enforce_eager=True,
-                     limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
-                                          }) as vllm_model:
-        vllm_outputs_per_image = [
-            vllm_model.generate_greedy_logprobs(prompts,
-                                                max_tokens,
-                                                num_logprobs=num_logprobs,
-                                                images=images)
-            for prompts, images in inputs
-        ]
-
-    with hf_runner(model, dtype=dtype,
-                   auto_cls=AutoModelForVision2Seq) as hf_model:
-        hf_outputs_per_image = [
-            hf_model.generate_greedy_logprobs_limit(prompts,
-                                                    max_tokens,
-                                                    num_logprobs=num_logprobs,
-                                                    images=images)
-            for prompts, images in inputs
-        ]
-
-    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
-                                        vllm_outputs_per_image):
-        # TODO: Check whether using original CLIPVisionModel can improve
-        # consistency against HF
-        check_logprobs_close(
-            outputs_0_lst=hf_outputs,
-            outputs_1_lst=[
-                vllm_to_hf_output(vllm_output, model)
-                for vllm_output in vllm_outputs
-            ],
-            name_0="hf",
-            name_1="vllm",
-        )
-
-
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize(
-    "size_factors",
-    [
-        # No image
-        [],
-        # Single-scale
-        [1.0],
-        # Single-scale, batched
-        [1.0, 1.0, 1.0],
-        # Multi-scale
-        [0.25, 0.5, 1.0],
-    ],
-)
-@pytest.mark.parametrize("dtype", ["half"])
-@pytest.mark.parametrize("max_tokens", [128])
-@pytest.mark.parametrize("num_logprobs", [5])
-def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
-                dtype, max_tokens, num_logprobs) -> None:
-    """Inference result should be the same between hf and vllm.
-
-    All the image fixtures for the test are from IMAGE_ASSETS.
-    For huggingface runner, we provide the PIL images as input.
-    For vllm runner, we provide MultiModalDataDict objects
-    and corresponding MultiModalConfig as input.
-    Note, the text input is also adjusted to abide by vllm contract.
-    The text output is sanitized to be able to compare with hf.
-    """
-    run_test(
-        hf_runner,
-        vllm_runner,
-        image_assets,
-        model,
-        size_factors=size_factors,
-        dtype=dtype,
-        max_tokens=max_tokens,
-        num_logprobs=num_logprobs,
-        tensor_parallel_size=1,
-    )
-
-
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize(
-    "sizes",
-    [[(1669, 2560), (2560, 1669), (183, 488), (488, 183)]],
-)
-@pytest.mark.parametrize("dtype", ["half"])
-@pytest.mark.parametrize("max_tokens", [128])
-@pytest.mark.parametrize("num_logprobs", [5])
-def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes,
-                            dtype, max_tokens, num_logprobs) -> None:
-    run_test(
-        hf_runner,
-        vllm_runner,
-        image_assets,
-        model,
-        sizes=sizes,
-        dtype=dtype,
-        max_tokens=max_tokens,
-        num_logprobs=num_logprobs,
-        tensor_parallel_size=1,
-    )
-
-
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("dtype", ["half"])
-@pytest.mark.parametrize("max_tokens", [128])
-@pytest.mark.parametrize("num_logprobs", [5])
-def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
-                                      model, dtype, max_tokens,
-                                      num_logprobs) -> None:
-    stop_sign = image_assets[0].pil_image
-    cherry_blossom = image_assets[1].pil_image
-
-    inputs = [(
-        [
-            "[INST] <image><image>\nDescribe 2 images. [/INST]",
-            "[INST] <image><image>\nDescribe 2 images. [/INST]",
-            "[INST] <image><image><image><image>\nDescribe 4 images. [/INST]",
-            "[INST] <image>\nWhat is the season? [/INST]"
-        ],
-        [
-            [stop_sign, cherry_blossom],
-            # Images with different sizes and aspect-ratios
-            [
-                rescale_image_size(stop_sign, 0.1),
-                stop_sign,
-            ],
-            [
-                stop_sign,
-                rescale_image_size(stop_sign, 0.25),
-                cherry_blossom.resize((183, 488)),
-                cherry_blossom.resize((488, 183))
-            ],
-            cherry_blossom,
-        ])]
-
-    _run_test(
-        hf_runner,
-        vllm_runner,
-        inputs,
-        model,
-        dtype=dtype,
-        max_tokens=max_tokens,
-        num_logprobs=num_logprobs,
-        tensor_parallel_size=1,
-    )
-
-
-@pytest.mark.parametrize("gridpoints,expected_max_tokens", [
-    ([[336, 336]], 1176),
-    ([[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]], 2928),
-])
-def test_get_max_llava_next_image_tokens(gridpoints, expected_max_tokens,
-                                         get_max_llava_next_image_tokens):
-    ctx = build_model_context(model_name="llava-hf/llava-v1.6-mistral-7b-hf")
-
-    # Update the config image_grid_pinpoints
-    # and calculate the resulting max tokens
-    ctx.model_config.hf_config.image_grid_pinpoints = gridpoints
-
-    actual_max_tokens = get_max_llava_next_image_tokens(
-        InputContext(ctx.model_config))
-
-    assert expected_max_tokens == actual_max_tokens
-
-
-@pytest.mark.parametrize(
-    "gridpoints,expected_size",
-    [
-        # One point; it has to be the largest
-        ([[336, 336]], (336, 336)),
-        # Default for most llava next models; the 2x2 tile is the largest
-        ([[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]],
-         (672, 672)),
-        # If two rectangular gridpoints are the same, the more vertical
-        # one has the higher feature count due to newline features
-        ([[336, 672], [672, 336]], (672, 336))
-    ])
-def test_dummy_data_for_llava_next_feature_size(dummy_data_for_llava_next,
-                                                gridpoints, expected_size):
-    ctx = build_model_context(model_name="llava-hf/llava-v1.6-mistral-7b-hf")
-
-    # Update the config image_grid_pinpoints
-    ctx.model_config.hf_config.image_grid_pinpoints = gridpoints
-    seq_len = 5000  # bigger than the max feature size for any image
-
-    seq_data, mm_data = dummy_data_for_llava_next(
-        ctx,
-        seq_len=seq_len,
-        mm_counts={"image": 1},
-    )
-
-    # The dummy data dims should match the gridpoint with the biggest feat size
-    assert mm_data["image"].height == expected_size[0]
-    assert mm_data["image"].width == expected_size[1]
-    assert len(seq_data.get_token_ids()) >= seq_len
diff --git a/tests/models/decoder_only/vision_language/test_llava_next_video.py b/tests/models/decoder_only/vision_language/test_llava_next_video.py
deleted file mode 100644
index 7b7b23c783e2a..0000000000000
--- a/tests/models/decoder_only/vision_language/test_llava_next_video.py
+++ /dev/null
@@ -1,226 +0,0 @@
-from typing import List, Optional, Tuple, Type, overload
-
-import pytest
-from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
-
-from vllm.multimodal.utils import (rescale_video_size, resize_video,
-                                   sample_frames_from_video)
-from vllm.sequence import SampleLogprobs
-
-from ....conftest import VIDEO_ASSETS, HfRunner, VllmRunner, _VideoAssets
-from ...utils import check_logprobs_close
-
-_PREFACE = (
-    "A chat between a curious human and an artificial intelligence assistant. "
-    "The assistant gives helpful, detailed, and polite answers to the human's "
-    "questions.")
-
-HF_VIDEO_PROMPTS = VIDEO_ASSETS.prompts({
-    "sample_demo_1":
-    f"{_PREFACE}USER: