diff --git a/tests/conftest.py b/tests/conftest.py
index 999ca60d07a4f..c7a349f1e9e2a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,7 +3,7 @@
 import os
 import sys
 from collections import UserList
-from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar
+from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union
 
 import pytest
 import torch
@@ -508,7 +508,8 @@ def generate_greedy_logprobs(
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-        images: Optional[List[Image.Image]] = None,
+        images: Optional[Union[List[Image.Image],
+                               List[List[Image.Image]]]] = None,
         stop_token_ids: Optional[List[int]] = None,
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         greedy_logprobs_params = SamplingParams(temperature=0.0,
diff --git a/tests/models/test_minicpmv.py b/tests/models/test_minicpmv.py
index c57f0f8c08548..c3b2a7bcbaafd 100644
--- a/tests/models/test_minicpmv.py
+++ b/tests/models/test_minicpmv.py
@@ -14,6 +14,18 @@
 
 pytestmark = pytest.mark.vlm
 
+
+class NestedInputs(UserDict):
+
+    def __init__(self, model_inputs: BatchFeature):
+        super().__init__({"model_inputs": model_inputs})
+
+        self.model_inputs = model_inputs
+
+    def to(self, device: torch.types.Device):
+        return NestedInputs(self.model_inputs.to(device))
+
+
 # The image token is placed before "user" on purpose so that the test can pass
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "stop_sign":
@@ -23,7 +35,7 @@
     "cherry_blossom":
     "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
     "(<image>./</image>)\nWhat is the season?<|eot_id|>" \
-    "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    "<|start_header_id|>assistant<|end_header_id|>\n\n",
 })
 
 models = ["openbmb/MiniCPM-Llama3-V-2_5"]
@@ -94,22 +106,10 @@ def run_test(
         ]
 
     with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
-
-        class NestedInputs(UserDict):
-
-            def __init__(self, model_inputs: BatchFeature):
-                super().__init__({"model_inputs": model_inputs})
-
-                self.model_inputs = model_inputs
-
-            def to(self, device: torch.types.Device):
-                return NestedInputs(self.model_inputs.to(device))
-
         hf_processor = hf_model.processor
         hf_model.processor = lambda **kw: NestedInputs(
             hf_processor(**kw)  # type: ignore
         )
-
         hf_outputs_per_image = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,
@@ -161,3 +161,123 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
         num_logprobs=num_logprobs,
         tensor_parallel_size=1,
     )
+
+
+HF_MULTIIMAGE_IMAGE_PROMPT = \
+    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
+    "(<image>./</image>)\n(<image>./</image>)\n" \
+    "Describe these images.<|eot_id|>" \
+    "<|start_header_id|>assistant<|end_header_id|>\n\n"
+
+
+def run_multi_image_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    size_factors: List[float],
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    """Inference result should be the same between hf and vllm.
+
+    All the image fixtures for the test are under tests/images.
+    For huggingface runner, we provide the PIL images as input.
+    For vllm runner, we provide MultiModalDataDict objects
+    and corresponding vision language config as input.
+    Note, the text input is also adjusted to abide by vllm contract.
+    The text output is sanitized to be able to compare with hf.
+    """
+    images = [asset.pil_image for asset in image_assets]
+
+    inputs_per_case = [
+        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
+         [[rescale_image_size(image, factor) for image in images]
+          for factor in size_factors])
+    ]
+
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
+
+    # max_model_len should be greater than image_feature_size
+    with vllm_runner(model,
+                     max_model_len=4096,
+                     max_num_seqs=1,
+                     dtype=dtype,
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend,
+                     enforce_eager=True) as vllm_model:
+        tokenizer = vllm_model.model.get_tokenizer()
+        stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
+        vllm_outputs_per_case = [
+            vllm_model.generate_greedy_logprobs(prompts,
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                images=images,
+                                                stop_token_ids=stop_token_ids)
+            for prompts, images in inputs_per_case
+        ]
+
+    with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
+        hf_processor = hf_model.processor
+        hf_model.processor = lambda **kw: NestedInputs(
+            hf_processor(**kw)  # type: ignore
+        )
+        hf_outputs_per_case = [
+            hf_model.generate_greedy_logprobs_limit(prompts,
+                                                    max_tokens,
+                                                    num_logprobs=num_logprobs,
+                                                    images=images,
+                                                    tokenizer=tokenizer)
+            for prompts, images in inputs_per_case
+        ]
+
+    for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
+                                        vllm_outputs_per_case):
+        check_logprobs_close(
+            outputs_0_lst=[
+                trunc_hf_output(hf_output) for hf_output in hf_outputs
+            ],
+            outputs_1_lst=vllm_outputs,
+            name_0="hf",
+            name_1="vllm",
+        )
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize(
+    "size_factors",
+    [
+        # No image
+        [],
+        # Single-scale
+        [1.0],
+        # Single-scale, batched
+        [1.0, 1.0, 1.0],
+        # Multi-scale
+        [0.25, 0.5, 1.0],
+    ],
+)
+@pytest.mark.parametrize("dtype", [target_dtype])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
+                             size_factors, dtype: str, max_tokens: int,
+                             num_logprobs: int) -> None:
+    run_multi_image_test(
+        hf_runner,
+        vllm_runner,
+        image_assets,
+        model,
+        size_factors=size_factors,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 095bb49f6ba76..0388259595628 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -392,6 +392,20 @@ def forward(self, x: torch.Tensor,
         return x
 
 
+def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
+    version_float = getattr(config, "version", None)
+
+    # The old configs do not include version number
+    # TODO: Remove this after the HF repos are updated
+    if version_float is None:
+        if config.hidden_size == 2304 and config.query_num == 64:
+            return (2, 0)
+        return (2, 5)
+
+    version_str = str(version_float)
+    return tuple(int(x) for x in version_str.split("."))
+
+
 def get_max_minicpmv_image_tokens(ctx: InputContext):
     hf_config = ctx.get_hf_config(PretrainedConfig)
     return getattr(hf_config, "query_num", 64)
@@ -421,36 +435,43 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
     multi_modal_data = llm_inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
         return llm_inputs
-
     model_config = ctx.model_config
-
+    version = get_version_by_config(model_config.hf_config)
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)
+    image_processor = cached_get_image_processor(model_config.tokenizer)
+
+    def get_placeholder(image_size: Tuple[int, int], num_image: int):
+        if version == (2, 0) or version == (2, 5):
+            return image_processor. \
+                get_slice_image_placeholder(image_size)
+        return image_processor. \
+            get_slice_image_placeholder(image_size, num_image)
 
     prompt = llm_inputs.get("prompt")
     if prompt is None:
         token_ids = llm_inputs.get("prompt_token_ids")
         prompt = tokenizer.decode(token_ids)
-    image_processor = cached_get_image_processor(model_config.tokenizer)
 
     pattern = "(<image>./</image>)"
-    image = multi_modal_data["image"]
+    images = multi_modal_data["image"]
+    if isinstance(images, Image.Image):
+        images = [images]
     image_tags = re.findall(pattern, prompt)
 
     if len(image_tags) == 0:
         new_token_ids = token_ids
         new_prompt = prompt
     else:
-        if len(image_tags) > 1:
-            logger.warning("Multiple image input is not supported yet, "
-                           "so any extra image tokens will be treated "
-                           "as plain text.")
-
         text_chunks = prompt.split(pattern)
-        new_prompt = (text_chunks[0] +
-                      image_processor.get_slice_image_placeholder(image.size) +
-                      "".join(text_chunks[1:]))
-
+        new_prompt_chunks: List[str] = []
+        for i in range(len(images)):
+            new_prompt_chunks += [
+                text_chunks[i],
+                get_placeholder(images[i].size, i)
+            ]
+        new_prompt_chunks.append(text_chunks[-1])
+        new_prompt = "".join(new_prompt_chunks)
 
         new_token_ids = tokenizer.encode(new_prompt)
 
     llm_inputs = LLMInputs(
@@ -478,14 +499,7 @@ def __init__(
         self.config = config
         self.multimodal_config = multimodal_config
 
-        if not hasattr(self.config, "version"):
-            if self.config.hidden_size == 2304 and self.config.query_num == 64:
-                self.version = (2, 0)
-            else:
-                self.version = (2, 5)
-        else:
-            self.version = str(self.config.version).split(".")
-            self.version = tuple([int(x) for x in self.version])
+        self.version = get_version_by_config(self.config)
         self.llm = self.init_llm(config, cache_config, quant_config)
         self.vpm = self.init_vision_module()
         param_dtype = torch.get_default_dtype()
diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py
index 3b37ce9149fb8..b6a3909e95632 100644
--- a/vllm/multimodal/image.py
+++ b/vllm/multimodal/image.py
@@ -113,7 +113,7 @@ def _get_hf_image_processor(self, model_config: ModelConfig):
     def _default_input_mapper(self, ctx: InputContext,
                               data: object) -> MultiModalInputs:
         model_config = ctx.model_config
-        if isinstance(data, Image.Image):
+        if isinstance(data, (Image.Image, list)):
             image_processor = self._get_hf_image_processor(model_config)
             if image_processor is None:
                 raise RuntimeError("No HuggingFace processor is available "