
Commit

[Model] Expose InternVL2 max_dynamic_patch as a mm_processor_kwarg (v…
Isotr0py authored Sep 30, 2024
1 parent 8e60afa commit 2ae25f7
Showing 2 changed files with 90 additions and 61 deletions.
1 change: 1 addition & 0 deletions examples/offline_inference_vision_language_multi_image.py
@@ -115,6 +115,7 @@ def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
         trust_remote_code=True,
         max_model_len=4096,
         limit_mm_per_prompt={"image": len(image_urls)},
+        mm_processor_kwargs={"max_dynamic_patch": 4},
     )

     placeholders = "\n".join(f"Image-{i}: <image>\n"
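
The one-line change above is the whole user-facing surface: mm_processor_kwargs set at engine construction is forwarded to the model's registered input mapper, input processor, and dummy-data factory. A minimal offline sketch of the override (the model name is illustrative; lowering max_dynamic_patch caps how many tiles each image is split into, and therefore how many image tokens it consumes):

    from vllm import LLM

    # Cap each image at 4 dynamic tiles (plus a thumbnail tile when the
    # model config enables use_thumbnail) instead of the checkpoint default.
    llm = LLM(
        model="OpenGVLab/InternVL2-2B",
        trust_remote_code=True,
        mm_processor_kwargs={"max_dynamic_patch": 4},
    )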
150 changes: 89 additions & 61 deletions vllm/model_executor/models/internvl.py
@@ -5,8 +5,9 @@
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
 import re
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
-                    TypedDict, Union)
+from functools import partial
+from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
+                    Tuple, TypedDict, Union)

 import torch
 import torch.nn as nn
@@ -122,6 +123,20 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
     return blocks, target_width, target_height


+def calculate_num_blocks_wrapper(hf_config: Dict[str, Any],
+                                 max_dynamic_patch: Optional[int] = None):
+    if max_dynamic_patch is None:
+        max_dynamic_patch = hf_config.max_dynamic_patch
+    min_num = hf_config.min_dynamic_patch
+    image_size = hf_config.vision_config.image_size
+    use_thumbnail = hf_config.use_thumbnail
+    return partial(calculate_num_blocks,
+                   min_num=min_num,
+                   max_num=max_dynamic_patch,
+                   image_size=image_size,
+                   use_thumbnail=use_thumbnail)
+
+
 # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
 def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
                        image_size: int,
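
Both wrapper factories introduced by this commit share one pattern: resolve the max_dynamic_patch override against the HF config once, then bind the resolved values with functools.partial so call sites pass only per-image arguments. A toy, self-contained sketch of that pattern (the function and numbers are invented for illustration, not vLLM code):

    from functools import partial

    def toy_num_blocks(width: int, height: int, max_num: int) -> int:
        # Stand-in for calculate_num_blocks: tiles needed for this aspect ratio.
        return min(max_num, max(1, round(width / height)))

    # Config-derived values are bound once; callers supply only image dimensions.
    calculator = partial(toy_num_blocks, max_num=4)
    print(calculator(1792, 448))  # -> 4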
@@ -168,62 +183,85 @@ def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
     return pixel_values


-def get_internvl_num_patches(image_size: int, patch_size: int,
-                             downsample_ratio: float):
+def image_to_pixel_values_wrapper(hf_config: Dict[str, Any],
+                                  max_dynamic_patch: Optional[int] = None):
+    image_size = hf_config.vision_config.image_size
+    min_num = hf_config.min_dynamic_patch
+    if max_dynamic_patch is None:
+        max_dynamic_patch = hf_config.max_dynamic_patch
+    use_thumbnail = hf_config.use_thumbnail
+    return partial(image_to_pixel_values,
+                   input_size=image_size,
+                   min_num=min_num,
+                   max_num=max_dynamic_patch,
+                   use_thumbnail=use_thumbnail)
+
+
+def get_internvl_num_patches(hf_config: Dict[str, Any]):
+    vision_config = hf_config.vision_config
+    downsample_ratio = hf_config.downsample_ratio
+    image_size = vision_config.image_size
+    patch_size = vision_config.patch_size
     return int(
         get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
         (downsample_ratio**2))


-def get_max_internvl_image_tokens(ctx: InputContext):
+def get_max_internvl_image_tokens(ctx: InputContext,
+                                  *,
+                                  max_dynamic_patch: Optional[int] = None):
     hf_config = ctx.get_hf_config()
-    vision_config = hf_config.vision_config

+    if max_dynamic_patch is None:
+        max_dynamic_patch = hf_config.max_dynamic_patch
     use_thumbnail = hf_config.use_thumbnail
-    max_dynamic_patch = hf_config.max_dynamic_patch
-    if use_thumbnail:
+    if use_thumbnail and max_dynamic_patch > 1:
         max_dynamic_patch += 1
-    downsample_ratio = hf_config.downsample_ratio

-    image_size = vision_config.image_size
-    patch_size = vision_config.patch_size
-    num_patches = get_internvl_num_patches(image_size, patch_size,
-                                           downsample_ratio)
+    num_patches = get_internvl_num_patches(hf_config)
     return num_patches * max_dynamic_patch


-def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
+def get_max_internvl_image_size(ctx: InputContext,
+                                *,
+                                max_dynamic_patch: Optional[int] = None):
+    hf_config = ctx.get_hf_config()
+    image_size = hf_config.vision_config.image_size
+
+    if max_dynamic_patch is None:
+        max_dynamic_patch = hf_config.max_dynamic_patch
+    use_thumbnail = hf_config.use_thumbnail
+    if use_thumbnail and max_dynamic_patch > 1:
+        max_dynamic_patch += 1
+    width = image_size * max_dynamic_patch
+    height = image_size
+    return width, height
+
+
+def input_processor_for_internvl(ctx: InputContext,
+                                 llm_inputs: LLMInputs,
+                                 *,
+                                 max_dynamic_patch: Optional[int] = None):
     multi_modal_data = llm_inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
         return llm_inputs

     model_config = ctx.model_config
     hf_config = ctx.get_hf_config()
-    vision_config = hf_config.vision_config
-
-    image_size = vision_config.image_size
-    patch_size = vision_config.patch_size
-    downsample_ratio = hf_config.downsample_ratio
-    num_patches = get_internvl_num_patches(image_size, patch_size,
-                                           downsample_ratio)

     image_data = multi_modal_data["image"]
-    min_num = hf_config.min_dynamic_patch
-    max_num = hf_config.max_dynamic_patch
-    use_thumbnail = hf_config.use_thumbnail
+    num_patches = get_internvl_num_patches(hf_config)
+    num_blocks_calculator = calculate_num_blocks_wrapper(
+        hf_config, max_dynamic_patch)
     if isinstance(image_data, Image.Image):
         width, height = image_data.size
-        num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
-                                                max_num, image_size,
-                                                use_thumbnail)
+        num_blocks, _, _ = num_blocks_calculator(width, height)
         image_feature_size = [num_blocks * num_patches]
     elif is_list_of(image_data, Image.Image):
         image_feature_size = []
         for image in image_data:
             width, height = image.size
-            num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
-                                                    max_num, image_size,
-                                                    use_thumbnail)
+            num_blocks, _, _ = num_blocks_calculator(width, height)
             image_feature_size.append(num_blocks * num_patches)
     elif isinstance(image_data, torch.Tensor):
         num_images, image_feature_size, hidden_size = image_data.shape
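
To make the budget concrete, assume typical InternVL2 config values (image_size=448, patch_size=14, downsample_ratio=0.5, use_thumbnail=True; these are assumptions about the checkpoint and are read from hf_config at runtime). Each tile then costs int((448 / 14) ** 2 * 0.5 ** 2) = 256 tokens, so the worst case grows linearly with max_dynamic_patch:

    # Worst-case image-token budget under the assumed InternVL2 defaults.
    image_size, patch_size, downsample_ratio = 448, 14, 0.5
    num_patches = int((image_size // patch_size) ** 2 * downsample_ratio ** 2)

    for max_dynamic_patch in (4, 12):
        tiles = max_dynamic_patch + 1  # use_thumbnail adds one tile when > 1
        print(max_dynamic_patch, tiles * num_patches)
    # 4  -> 1280 image tokens
    # 12 -> 3328 image tokens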
@@ -253,31 +291,21 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
         multi_modal_data=multi_modal_data)


-def input_mapper_for_internvl(ctx: InputContext, data: object):
+def input_mapper_for_internvl(ctx: InputContext,
+                              data: object,
+                              *,
+                              max_dynamic_patch: Optional[int] = None):
     hf_config = ctx.get_hf_config()

-    use_thumbnail = hf_config.use_thumbnail
-    min_num = hf_config.min_dynamic_patch
-    max_num = hf_config.max_dynamic_patch
-    image_size = hf_config.vision_config.image_size
-
+    image_pixel_values_mapper = image_to_pixel_values_wrapper(
+        hf_config, max_dynamic_patch)
     if isinstance(data, Image.Image):
-        data = image_to_pixel_values(data,
-                                     image_size,
-                                     min_num,
-                                     max_num,
-                                     use_thumbnail=use_thumbnail)
+        data = image_pixel_values_mapper(data)
         # Add an N dimension for number of images per prompt (currently 1).
         data = data.unsqueeze(0)
     elif is_list_of(data, Image.Image):
         # we can't stack here because the images may have different num_patches
-        data = [
-            image_to_pixel_values(img,
-                                  image_size,
-                                  min_num,
-                                  max_num,
-                                  use_thumbnail=use_thumbnail) for img in data
-        ]
+        data = [image_pixel_values_mapper(img) for img in data]
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(
         model_config.tokenizer,
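
In isolation, the refactored mapper path reads as below (a sketch: hf_config stands for the model's HF config object obtained elsewhere, and the shape comment assumes 448-pixel tiles):

    from PIL import Image

    # The factory resolves the override once, then returns a one-argument callable.
    mapper = image_to_pixel_values_wrapper(hf_config, max_dynamic_patch=4)
    pixel_values = mapper(Image.open("example.jpg"))
    # pixel_values: torch.Tensor of shape (num_blocks, 3, 448, 448), with at
    # most 4 dynamic tiles plus one thumbnail tile.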
@@ -292,35 +320,36 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
         })


-def dummy_data_for_internvl(ctx: InputContext, seq_len: int,
-                            mm_counts: Mapping[str, int]):
+def dummy_data_for_internvl(ctx: InputContext,
+                            seq_len: int,
+                            mm_counts: Mapping[str, int],
+                            *,
+                            max_dynamic_patch: Optional[int] = None):
     num_images = mm_counts["image"]

-    image_feature_size = get_max_internvl_image_tokens(ctx)
-    model_config = ctx.model_config
     hf_config = ctx.get_hf_config()
-    vision_config = hf_config.vision_config
+
+    image_feature_size = get_max_internvl_image_tokens(
+        ctx, max_dynamic_patch=max_dynamic_patch)
+    model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(
         model_config.tokenizer,
         trust_remote_code=model_config.trust_remote_code)

     seq_data = dummy_seq_data_for_clip(
-        vision_config,
+        hf_config.vision_config,
         seq_len,
         num_images,
         image_token_id=tokenizer.encode(IMG_CONTEXT,
                                         add_special_tokens=False)[0],
         image_feature_size_override=image_feature_size,
     )

-    image_size = vision_config.image_size
-    min_num = hf_config.min_dynamic_patch
-    max_num = hf_config.max_dynamic_patch
-    max_image_width = max_num * image_size
-    max_image_height = min_num * image_size
+    max_image_width, max_image_height = get_max_internvl_image_size(
+        ctx, max_dynamic_patch=max_dynamic_patch)

     mm_data = dummy_image_for_clip(
-        vision_config,
+        hf_config.vision_config,
         num_images,
         image_width_override=max_image_width,
         image_height_override=max_image_height,
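
The profiling image is sized by get_max_internvl_image_size, which lays the maximum number of tiles side by side, keeping memory profiling consistent with an overridden tile budget. Under the same assumed defaults (image_size=448, use_thumbnail=True):

    # Assumed values; mirrors the arithmetic in get_max_internvl_image_size.
    image_size, max_dynamic_patch = 448, 4
    max_dynamic_patch += 1  # thumbnail bump (use_thumbnail and max > 1)
    print(image_size * max_dynamic_patch, image_size)  # 2240 448 (width, height)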
@@ -470,7 +499,6 @@ def _process_image_input(
         self,
         image_input: InternVLImageInputs,
     ) -> torch.Tensor:
-
         if image_input["type"] == "image_embeds":
             return image_input["data"]

