From 1396876744c59c94352e273908a9edb5d201b028 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 27 Sep 2024 16:15:58 +0800 Subject: [PATCH] [Bugfix][VLM] Fix Fuyu batching inference with `max_num_seqs>1` (#8892) --- .../decoder_only/vision_language/test_fuyu.py | 6 +-- vllm/model_executor/models/fuyu.py | 51 +++++++++++++------ 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_fuyu.py b/tests/models/decoder_only/vision_language/test_fuyu.py index 94b8431424db5..7827ecb19a744 100644 --- a/tests/models/decoder_only/vision_language/test_fuyu.py +++ b/tests/models/decoder_only/vision_language/test_fuyu.py @@ -65,8 +65,8 @@ def run_test( # max_model_len should be greater than image_feature_size with vllm_runner(model, - max_model_len=2560, - max_num_seqs=1, + max_model_len=2048, + max_num_seqs=2, dtype=dtype, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, @@ -80,8 +80,6 @@ def run_test( ] with hf_runner(model, dtype=dtype) as hf_model: - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.language_model.get_output_embeddings() eos_token_id = hf_model.processor.tokenizer.eos_token_id hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index d50f4fb9e6ed4..9f4dca78d435d 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -42,7 +42,7 @@ SequenceData) from .interfaces import SupportsMultiModal -from .utils import merge_multimodal_embeddings +from .utils import flatten_bn, merge_multimodal_embeddings # Cannot find the following 2 numbers from hf config. _IMAGE_TOKEN_ID = 71011 @@ -165,7 +165,7 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): model_config.model) model_image_input = _fuyu_image_preprocess(image_processor, image_data) - image_patches = torch.stack([ + image_patches = torch.cat([ image_patch[0] for image_patch in model_image_input["image_patches"] ]) @@ -210,7 +210,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object): ]) # image has been processed with prompt in input processor - return MultiModalInputs({"image_patches": data}) + return MultiModalInputs({"pixel_values": data}) @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu) @@ -242,23 +242,42 @@ def __init__(self, cache_config=cache_config, quant_config=quant_config) + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + + h = w = self.config.patch_size + num_channels = self.config.num_channels + expected_dims = num_channels * h * w + + def _validate_shape(d: torch.Tensor): + actual_dims = d.size(-1) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + "The expected shape of pixel values per image per batch " + f" per patch is {expected_expr}. " + f"You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data.to(self.vision_embed_tokens.weight.dtype) + def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[FuyuImagePixelInputs]: - image_patches = kwargs.pop("image_patches", None) + pixel_values = kwargs.pop("pixel_values", None) - if isinstance(image_patches, torch.Tensor): - # Remove the N dimension until multiple images are supported. - image_patches = image_patches.squeeze(1) + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of image patches. " + f"Got type: {type(pixel_values)}") + + return FuyuImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), + ) - expected_feature_size = self.image_feature_size - if image_patches.size(-1) != expected_feature_size: - raise ValueError( - f"Expected image patches to have the last dimension of " - f"{expected_feature_size}, got {image_patches.size(-1)}") - image_patches = image_patches.to( - self.vision_embed_tokens.weight.dtype) - return FuyuImagePixelInputs(type="pixel_values", - data=image_patches) return None def _process_image_input(