Initial support for llava-next-video
* Inference with llava-hf/llava-next-video* (with bugs)
* Add VideoPlugin.
* Add example for llava-next-video
TKONIY committed Aug 15, 2024
1 parent fc93e56 commit d95701b
Showing 6 changed files with 590 additions and 11 deletions.
67 changes: 57 additions & 10 deletions examples/offline_inference_vision_language.py
@@ -9,19 +9,15 @@

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.utils import FlexibleArgumentParser

-# Input image and question
-image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
-question = "What is the content of this image?"


# LLaVA-1.5
def run_llava(question):

    prompt = f"USER: <image>\n{question}\nASSISTANT:"

-    llm = LLM(model="llava-hf/llava-1.5-7b-hf")
    llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf")
    stop_token_ids = None
    return llm, prompt, stop_token_ids

@@ -34,6 +30,13 @@ def run_llava_next(question):
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# LLaVA-NeXT-Video
# Currently only supports video input
def run_llava_next_video(question):
    prompt = f"[INST] <video>\n{question} [/INST]"
    llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf")
    stop_token_ids = None
    return llm, prompt, stop_token_ids

# Fuyu
def run_fuyu(question):
@@ -162,6 +165,7 @@ def run_blip2(question):
model_example_map = {
    "llava": run_llava,
    "llava-next": run_llava_next,
    "llava-next-video": run_llava_next_video,
    "fuyu": run_fuyu,
    "phi3_v": run_phi3v,
    "paligemma": run_paligemma,
@@ -171,12 +175,48 @@ def run_blip2(question):
    "internvl_chat": run_internvl,
}

def get_multi_modal_input(args):
    """
    return {
        "data": image or video,
        "question": question,
    }
    """
    if args.modality == "image":
        # Input image and question
        image = ImageAsset("cherry_blossom") \
            .pil_image.convert("RGB")
        img_question = "What is the content of this image?"

        return {
            "data": image,
            "question": img_question,
        }

    if args.modality == "video":
        # Input video and question
        video = VideoAsset(name="sample_demo_1.mp4",
                           num_frames=args.num_frames).pil_images
        vid_question = "Why is this video funny?"

        return {
            "data": video,
            "question": vid_question,
        }

    msg = f"Modality {args.modality} is not supported."
    raise ValueError(msg)

def main(args):
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    modality = args.modality
    mm_input = get_multi_modal_input(args)
    data = mm_input["data"]
    question = mm_input["question"]

    llm, prompt, stop_token_ids = model_example_map[model](question)

    # We set temperature to 0.2 so that outputs can be different
@@ -191,7 +231,7 @@ def main(args):
        inputs = {
            "prompt": prompt,
            "multi_modal_data": {
-                "image": image
                modality: data
            },
        }

@@ -200,7 +240,7 @@
        inputs = [{
            "prompt": prompt,
            "multi_modal_data": {
-                "image": image
                modality: data
            },
        } for _ in range(args.num_prompts)]

@@ -218,13 +258,20 @@
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
-                        default="llava",
                        default="llava-next-video",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=1,
                        help='Number of prompts to run.')

    parser.add_argument('--modality',
                        type=str,
                        default="video",
                        help='Modality of the input.')
    parser.add_argument('--num-frames',
                        type=int,
                        default=16,
                        help='Number of frames to extract from the video.')
    args = parser.parse_args()
    main(args)
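
For context (not part of the commit), the video path of the updated example boils down to roughly the sketch below; the prompt format and model name are taken from run_llava_next_video above, while the sampling settings are illustrative. With the new defaults, the script itself should be runnable as, e.g., python examples/offline_inference_vision_language.py --model-type llava-next-video --modality video --num-frames 16.

# Condensed sketch of the new video flow (illustrative, not from the diff):
from vllm import LLM, SamplingParams
from vllm.assets.video import VideoAsset

question = "Why is this video funny?"
prompt = f"[INST] <video>\n{question} [/INST]"

# VideoAsset yields a list of PIL frames uniformly sampled from the demo clip.
video = VideoAsset(name="sample_demo_1.mp4", num_frames=16).pil_images

llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf")
sampling_params = SamplingParams(temperature=0.2, max_tokens=64)  # values are illustrative

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"video": video},
    },
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)
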
81 changes: 81 additions & 0 deletions vllm/assets/video.py
@@ -0,0 +1,81 @@
from dataclasses import dataclass
from functools import lru_cache
from typing import Literal, List
from huggingface_hub import hf_hub_download
import cv2
import numpy as np

from PIL import Image

from .base import get_cache_dir


@lru_cache
def download_video_asset(filename: str) -> str:
    """
    Download a video from the Hugging Face dataset repo
    raushan-testing-hf/videos-test and return its local path.
    """
    video_directory = get_cache_dir() / "video-example-data"
    video_directory.mkdir(parents=True, exist_ok=True)

    video_path = video_directory / filename
    if not video_path.exists():
        video_path = hf_hub_download(
            repo_id="raushan-testing-hf/videos-test",
            filename=filename,
            repo_type="dataset",
            cache_dir=video_directory,
        )
    return video_path


def video_to_ndarrays_list(path: str, num_frames: int) -> List[np.ndarray]:
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video file {path}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # Decode every frame; OpenCV returns frames in BGR order.
    frames = []
    for i in range(total_frames):
        ret, frame = cap.read()
        if ret:
            frames.append(frame)

    cap.release()

    # Uniformly sample `num_frames` frames across the whole video.
    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    frames = [frames[i] for i in frame_indices if i < len(frames)]

    if len(frames) < num_frames:
        raise ValueError(
            f"Could not read enough frames from video file {path}"
            f" (expected {num_frames} frames, got {len(frames)})")

    return frames


def video_to_pil_images_list(path: str,
                             num_frames: int) -> List[Image.Image]:
    frames = video_to_ndarrays_list(path, num_frames)
    # Convert BGR (OpenCV) to RGB before building PIL images.
    return [Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            for frame in frames]


@dataclass(frozen=True)
class VideoAsset:
    name: Literal["sample_demo_1.mp4"]
    num_frames: int

    @property
    def pil_images(self) -> List[Image.Image]:
        video_path = download_video_asset(self.name)
        return video_to_pil_images_list(video_path, self.num_frames)

    @property
    def np_ndarrays(self) -> List[np.ndarray]:
        video_path = download_video_asset(self.name)
        return video_to_ndarrays_list(video_path, self.num_frames)
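
As a quick illustration (not part of the commit) of how the new asset is meant to be consumed: both properties download the clip on first use and return num_frames uniformly sampled frames.

# Illustrative usage of the new VideoAsset helper:
from vllm.assets.video import VideoAsset

asset = VideoAsset(name="sample_demo_1.mp4", num_frames=8)
pil_frames = asset.pil_images   # 8 RGB PIL.Image frames
np_frames = asset.np_ndarrays   # 8 HxWx3 uint8 arrays in OpenCV's BGR order

assert len(pil_frames) == 8 and len(np_frames) == 8
print(pil_frames[0].size, np_frames[0].shape)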


2 changes: 2 additions & 0 deletions vllm/model_executor/models/__init__.py
@@ -44,6 +44,8 @@
    ("llava", "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration":
    ("llava_next", "LlavaNextForConditionalGeneration"),
    "LlavaNextVideoForConditionalGeneration":
    ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
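
For readers unfamiliar with this table: vLLM resolves the architectures field of a checkpoint's HF config against this mapping to pick the implementing module. As a sketch (not part of the diff), the new entry is what lets a LLaVA-NeXT-Video checkpoint load once the model file exists:

# Illustrative only: the architecture name in the checkpoint's config.json
# ("LlavaNextVideoForConditionalGeneration") is looked up in the table above
# and dispatched to vllm/model_executor/models/llava_next_video.py.
from vllm import LLM

llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf")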