From 2ae25f79cf1e8d21f7bcba097e4c039463c22be4 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 30 Sep 2024 13:01:20 +0800 Subject: [PATCH] [Model] Expose InternVL2 max_dynamic_patch as a mm_processor_kwarg (#8946) --- ...e_inference_vision_language_multi_image.py | 1 + vllm/model_executor/models/internvl.py | 150 +++++++++++------- 2 files changed, 90 insertions(+), 61 deletions(-) diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index 1e99c02234d01..66936ab125b81 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -115,6 +115,7 @@ def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData: trust_remote_code=True, max_model_len=4096, limit_mm_per_prompt={"image": len(image_urls)}, + mm_processor_kwargs={"max_dynamic_patch": 4}, ) placeholders = "\n".join(f"Image-{i}: \n" diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index b1748700d481a..e84990a2ab109 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -5,8 +5,9 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- import re -from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, - TypedDict, Union) +from functools import partial +from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, + Tuple, TypedDict, Union) import torch import torch.nn as nn @@ -122,6 +123,20 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int, return blocks, target_width, target_height +def calculate_num_blocks_wrapper(hf_config: Dict[str, Any], + max_dynamic_patch: Optional[int] = None): + if max_dynamic_patch is None: + max_dynamic_patch = hf_config.max_dynamic_patch + min_num = hf_config.min_dynamic_patch + image_size = hf_config.vision_config.image_size + use_thumbnail = hf_config.use_thumbnail + return partial(calculate_num_blocks, + min_num=min_num, + max_num=max_dynamic_patch, + image_size=image_size, + use_thumbnail=use_thumbnail) + + # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int, image_size: int, @@ -168,62 +183,85 @@ def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int, return pixel_values -def get_internvl_num_patches(image_size: int, patch_size: int, - downsample_ratio: float): +def image_to_pixel_values_wrapper(hf_config: Dict[str, Any], + max_dynamic_patch: Optional[int] = None): + image_size = hf_config.vision_config.image_size + min_num = hf_config.min_dynamic_patch + if max_dynamic_patch is None: + max_dynamic_patch = hf_config.max_dynamic_patch + use_thumbnail = hf_config.use_thumbnail + return partial(image_to_pixel_values, + input_size=image_size, + min_num=min_num, + max_num=max_dynamic_patch, + use_thumbnail=use_thumbnail) + + +def get_internvl_num_patches(hf_config: Dict[str, Any]): + vision_config = hf_config.vision_config + downsample_ratio = hf_config.downsample_ratio + image_size = vision_config.image_size + patch_size = vision_config.patch_size return int( get_clip_num_patches(image_size=image_size, patch_size=patch_size) * (downsample_ratio**2)) -def get_max_internvl_image_tokens(ctx: InputContext): +def get_max_internvl_image_tokens(ctx: InputContext, + *, + max_dynamic_patch: Optional[int] = None): hf_config = ctx.get_hf_config() - vision_config = hf_config.vision_config + if max_dynamic_patch is None: + max_dynamic_patch = hf_config.max_dynamic_patch use_thumbnail = hf_config.use_thumbnail - max_dynamic_patch = hf_config.max_dynamic_patch - if use_thumbnail: + if use_thumbnail and max_dynamic_patch > 1: max_dynamic_patch += 1 - downsample_ratio = hf_config.downsample_ratio - image_size = vision_config.image_size - patch_size = vision_config.patch_size - num_patches = get_internvl_num_patches(image_size, patch_size, - downsample_ratio) + num_patches = get_internvl_num_patches(hf_config) return num_patches * max_dynamic_patch -def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): +def get_max_internvl_image_size(ctx: InputContext, + *, + max_dynamic_patch: Optional[int] = None): + hf_config = ctx.get_hf_config() + image_size = hf_config.vision_config.image_size + + if max_dynamic_patch is None: + max_dynamic_patch = hf_config.max_dynamic_patch + use_thumbnail = hf_config.use_thumbnail + if use_thumbnail and max_dynamic_patch > 1: + max_dynamic_patch += 1 + width = image_size * max_dynamic_patch + height = image_size + return width, height + + +def input_processor_for_internvl(ctx: InputContext, + llm_inputs: LLMInputs, + *, + max_dynamic_patch: Optional[int] = None): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs model_config = ctx.model_config hf_config = ctx.get_hf_config() - vision_config = hf_config.vision_config - - image_size = vision_config.image_size - patch_size = vision_config.patch_size - downsample_ratio = hf_config.downsample_ratio - num_patches = get_internvl_num_patches(image_size, patch_size, - downsample_ratio) image_data = multi_modal_data["image"] - min_num = hf_config.min_dynamic_patch - max_num = hf_config.max_dynamic_patch - use_thumbnail = hf_config.use_thumbnail + num_patches = get_internvl_num_patches(hf_config) + num_blocks_calculator = calculate_num_blocks_wrapper( + hf_config, max_dynamic_patch) if isinstance(image_data, Image.Image): width, height = image_data.size - num_blocks, _, _ = calculate_num_blocks(width, height, min_num, - max_num, image_size, - use_thumbnail) + num_blocks, _, _ = num_blocks_calculator(width, height) image_feature_size = [num_blocks * num_patches] elif is_list_of(image_data, Image.Image): image_feature_size = [] for image in image_data: width, height = image.size - num_blocks, _, _ = calculate_num_blocks(width, height, min_num, - max_num, image_size, - use_thumbnail) + num_blocks, _, _ = num_blocks_calculator(width, height) image_feature_size.append(num_blocks * num_patches) elif isinstance(image_data, torch.Tensor): num_images, image_feature_size, hidden_size = image_data.shape @@ -253,31 +291,21 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): multi_modal_data=multi_modal_data) -def input_mapper_for_internvl(ctx: InputContext, data: object): +def input_mapper_for_internvl(ctx: InputContext, + data: object, + *, + max_dynamic_patch: Optional[int] = None): hf_config = ctx.get_hf_config() - use_thumbnail = hf_config.use_thumbnail - min_num = hf_config.min_dynamic_patch - max_num = hf_config.max_dynamic_patch - image_size = hf_config.vision_config.image_size - + image_pixel_values_mapper = image_to_pixel_values_wrapper( + hf_config, max_dynamic_patch) if isinstance(data, Image.Image): - data = image_to_pixel_values(data, - image_size, - min_num, - max_num, - use_thumbnail=use_thumbnail) + data = image_pixel_values_mapper(data) # Add an N dimension for number of images per prompt (currently 1). data = data.unsqueeze(0) elif is_list_of(data, Image.Image): # we can't stack here because the images may have different num_patches - data = [ - image_to_pixel_values(img, - image_size, - min_num, - max_num, - use_thumbnail=use_thumbnail) for img in data - ] + data = [image_pixel_values_mapper(img) for img in data] model_config = ctx.model_config tokenizer = cached_get_tokenizer( model_config.tokenizer, @@ -292,20 +320,24 @@ def input_mapper_for_internvl(ctx: InputContext, data: object): }) -def dummy_data_for_internvl(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): +def dummy_data_for_internvl(ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], + *, + max_dynamic_patch: Optional[int] = None): num_images = mm_counts["image"] - image_feature_size = get_max_internvl_image_tokens(ctx) - model_config = ctx.model_config hf_config = ctx.get_hf_config() - vision_config = hf_config.vision_config + + image_feature_size = get_max_internvl_image_tokens( + ctx, max_dynamic_patch=max_dynamic_patch) + model_config = ctx.model_config tokenizer = cached_get_tokenizer( model_config.tokenizer, trust_remote_code=model_config.trust_remote_code) seq_data = dummy_seq_data_for_clip( - vision_config, + hf_config.vision_config, seq_len, num_images, image_token_id=tokenizer.encode(IMG_CONTEXT, @@ -313,14 +345,11 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int, image_feature_size_override=image_feature_size, ) - image_size = vision_config.image_size - min_num = hf_config.min_dynamic_patch - max_num = hf_config.max_dynamic_patch - max_image_width = max_num * image_size - max_image_height = min_num * image_size + max_image_width, max_image_height = get_max_internvl_image_size( + ctx, max_dynamic_patch=max_dynamic_patch) mm_data = dummy_image_for_clip( - vision_config, + hf_config.vision_config, num_images, image_width_override=max_image_width, image_height_override=max_image_height, @@ -470,7 +499,6 @@ def _process_image_input( self, image_input: InternVLImageInputs, ) -> torch.Tensor: - if image_input["type"] == "image_embeds": return image_input["data"]