Skip to content

Commit

Permalink
[Model][VLM] Add multi-video support for LLaVA-Onevision (vllm-projec…
Browse files Browse the repository at this point in the history
…t#8905)

Co-authored-by: litianjian <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>
Signed-off-by: Linkun Chen <[email protected]>
  • Loading branch information
3 people authored and Linkun Chen committed Nov 4, 2024
1 parent 5c31233 commit b3ec999
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 162 deletions.
173 changes: 48 additions & 125 deletions tests/models/decoder_only/vision_language/test_llava_onevision.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple, Type, overload
from typing import List, Optional, Tuple, Type

import pytest
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
Expand All @@ -9,9 +9,8 @@
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

from ....conftest import (VIDEO_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_VideoAssets)
from ....utils import large_gpu_test
from ....conftest import (VIDEO_ASSETS, HfRunner, PromptImageInput,
PromptVideoInput, VllmRunner)
from ...utils import check_logprobs_close

# Video test
Expand All @@ -20,7 +19,7 @@
"<|im_start|>user\n<video>\nwhy is this video funny?<|im_end|>\n<|im_start|>assistant\n" # noqa: E501
})

models = ["llava-hf/llava-onevision-qwen2-7b-ov-hf"]
models = ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Expand All @@ -47,50 +46,16 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
return hf_output_ids, hf_output_str, out_logprobs


@overload
def run_video_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
video_assets: _VideoAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
num_frames: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...


@overload
def run_video_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
video_assets: _VideoAssets,
model: str,
*,
sizes: List[Tuple[int, int]],
dtype: str,
max_tokens: int,
num_logprobs: int,
num_frames: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...
# Video test
_LIMIT_VIDEO_PER_PROMPT = 4


def run_video_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
video_assets: _VideoAssets,
inputs: List[Tuple[List[str], PromptVideoInput]],
model: str,
*,
size_factors: Optional[List[float]] = None,
sizes: Optional[List[Tuple[int, int]]] = None,
dtype: str,
max_tokens: int,
num_logprobs: int,
Expand All @@ -99,38 +64,20 @@ def run_video_test(
distributed_executor_backend: Optional[str] = None,
):
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

videos = [
sample_frames_from_video(asset.np_ndarrays, num_frames)
for asset in video_assets
]

if size_factors is not None:
inputs_per_video = [(
[prompt for _ in size_factors],
[rescale_video_size(video, factor) for factor in size_factors],
) for video, prompt in zip(videos, HF_VIDEO_PROMPTS)]
elif sizes is not None:
inputs_per_video = [(
[prompt for _ in sizes],
[resize_video(video, size) for size in sizes],
) for video, prompt in zip(videos, HF_VIDEO_PROMPTS)]
else:
raise ValueError("You must provide either `size_factors` or `sizes`")

# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
max_model_len=4096,
max_model_len=16384,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
vllm_outputs_per_video = [
enforce_eager=True,
limit_mm_per_prompt={"video": _LIMIT_VIDEO_PER_PROMPT
}) as vllm_model:
vllm_outputs_per_input = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
videos=videos)
for prompts, videos in inputs_per_video
for prompts, videos in inputs
]

def process(hf_inputs: BatchEncoding):
Expand All @@ -142,16 +89,16 @@ def process(hf_inputs: BatchEncoding):
dtype=dtype,
postprocess_inputs=process,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_video = [
hf_outputs_per_input = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
videos=videos)
for prompts, videos in inputs_per_video
for prompts, videos in inputs
]

for hf_outputs, vllm_outputs in zip(hf_outputs_per_video,
vllm_outputs_per_video):
for hf_outputs, vllm_outputs in zip(hf_outputs_per_input,
vllm_outputs_per_input):
# TODO: Check whether using original CLIPVisionModel can improve
# consistency against HF
check_logprobs_close(
Expand All @@ -165,74 +112,51 @@ def process(hf_inputs: BatchEncoding):
)


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No video
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("num_frames", [16])
def test_models(hf_runner, vllm_runner, video_assets, model, size_factors,
dtype, max_tokens, num_logprobs, num_frames) -> None:
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/videos.
For huggingface runner, we provide the np.ndarray as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
def test_models_multiple_video_inputs(hf_runner, vllm_runner, video_assets,
model, dtype, max_tokens, num_logprobs,
num_frames) -> None:
video = sample_frames_from_video(video_assets[0].np_ndarrays, num_frames)
inputs = [(
[
"<|im_start|>user <video><video>\nDescribe 2 videos. \
<|im_end|><|im_start|>assistant\n",
"<|im_start|>user <video><video>\nDescribe 2 videos. \
<|im_end|><|im_start|>assistant\n",
"<|im_start|>user <video><video><video><video>\nDescribe 4 videos. \
<|im_end|><|im_start|>assistant\n",
"<|im_start|>user <video>\nwhy is this video funny? \
<|im_end|><|im_start|>assistant\n",
],
[
[video, video],
# Images with different sizes and aspect-ratios
[
rescale_video_size(video, 0.1),
video,
],
[
video,
rescale_video_size(video, 0.25),
resize_video(video, (183, 488)),
resize_video(video, (488, 183))
],
video,
])]
run_video_test(
hf_runner,
vllm_runner,
video_assets,
inputs,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
num_frames=num_frames,
tensor_parallel_size=1,
)


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"sizes",
[[(1669, 2560), (2560, 1669), (183, 488), (488, 183)]],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("num_frames", [16])
def test_models_fixed_sizes(hf_runner, vllm_runner, video_assets, model, sizes,
dtype, max_tokens, num_logprobs,
num_frames) -> None:
run_video_test(
hf_runner,
vllm_runner,
video_assets,
model,
sizes=sizes,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
num_frames=num_frames,
tensor_parallel_size=1,
)


Expand Down Expand Up @@ -303,7 +227,6 @@ def process(hf_inputs: BatchEncoding):
)


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
Expand Down
4 changes: 3 additions & 1 deletion vllm/model_executor/models/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def dummy_image_for_clip(
def dummy_video_for_clip(
hf_config: CLIPVisionConfig,
num_frames: int,
num_videos: int = 1,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
Expand All @@ -99,7 +100,8 @@ def dummy_video_for_clip(
image_height_override=image_height_override)
np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
mm_data = {"video": mm_data_per_video}
video_data = [mm_data_per_video] * num_videos
mm_data = {"video": video_data}
return mm_data


Expand Down
Loading

0 comments on commit b3ec999

Please sign in to comment.