support video frames. #151

Open · wants to merge 4 commits into main
3 changes: 3 additions & 0 deletions .gitignore
@@ -171,3 +171,6 @@ cython_debug/
 **/*.pyc
 *.json
 __pycache__
+controller.log.*
+tinychat/serve/gradio_web_server.log.*
+tinychat/serve/model_worker_*
64 changes: 59 additions & 5 deletions tinychat/serve/gradio_web_server.py
@@ -199,17 +199,69 @@ def clear_after_click_example_3_image_icl(imagebox, imagebox_2, imagebox_3, text


 def add_images(
-    state, imagebox, imagebox_2, imagebox_3, image_process_mode, request: gr.Request
+    state, imagebox, imagebox_2, imagebox_3, videobox, image_process_mode, request: gr.Request
 ):
     if state.image_loaded:
         # return (state,) + (None,) * IMAGE_BOX_NUM
         return state

+    def extract_frames(video_path):
+        import cv2
+        from PIL import Image
+
+        vidcap = cv2.VideoCapture(video_path)
+        fps = vidcap.get(cv2.CAP_PROP_FPS)
+        frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+        duration = frame_count / fps
+
+        # Sample roughly uniformly: keep one frame out of every `frame_interval`.
+        frame_interval = max(frame_count // 10, 1)
+        print("duration:", duration, "frames:", frame_count, "interval:", frame_interval)
+
+        def get_frame(max_frames):
+            images = []
+            count = 0
+            success = True
+            while success:
+                success, frame = vidcap.read()
+                # Only convert frames that were actually read and that fall on
+                # the sampling interval.
+                if success and count % frame_interval == 0:
+                    img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                    images.append(Image.fromarray(img))
+                    if len(images) == max_frames:
+                        return images
+                count += 1
+            return images
+
+        return get_frame(8)
+
+    frames = []
+    if videobox is not None:
+        # Add the sampled video frames as regular images.
+        frames = extract_frames(videobox)
+        logger.info(f"Got videobox: {videobox}.")
+
     logger.info(f"add_image. ip: {request.client.host}.")
+    image_list = [imagebox, imagebox_2, imagebox_3, *frames]
+    logger.info(f"image_list: {image_list}")
+
     im_count = 0
-    for image in [imagebox, imagebox_2, imagebox_3]:
+    for image in image_list:
         if image is not None:
             im_count += 1
-    for image in [imagebox, imagebox_2, imagebox_3]:
+
+    for image in image_list:
         if image is not None:
             if args.auto_pad_image_token or im_count == 1:
                 text = (AUTO_FILL_IM_TOKEN_HOLDER, image, image_process_mode)
@@ -222,6 +274,7 @@ def add_images(
     # state.append_message(state.roles[0], text)
     # state.append_message(state.roles[1], None)
     # state.skip_next = False
+    logger.info(f"im_count {im_count}. ip: {request.client.host}.")
     state.image_loaded = True
     # return (state,) + (None,) * IMAGE_BOX_NUM
     return state
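The core of this hunk is the frame extraction: open the uploaded video with OpenCV, compute a sampling interval of roughly one tenth of the clip, and collect up to 8 frames as PIL images that are then appended to the regular image list. Below is a minimal standalone sketch of that technique, not code from the PR; the helper name sample_frames and the file name clip.mp4 are placeholders, and it assumes opencv-python and Pillow are installed. It can be handy for testing the sampling behavior outside the Gradio server.

import cv2
from PIL import Image


def sample_frames(video_path, num_frames=8):
    # Open the clip and find how many frames it contains.
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Pick evenly spaced frame indices; guard against very short clips.
    step = max(total // max(num_frames, 1), 1)
    wanted = set(range(0, total, step)[:num_frames]) if total else set()
    frames, index = [], 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        if index in wanted:
            # OpenCV decodes to BGR; convert to RGB before handing to PIL.
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        index += 1
    cap.release()
    return frames


if __name__ == "__main__":
    for i, frame in enumerate(sample_frames("clip.mp4")):  # placeholder path
        frame.save(f"frame_{i}.png")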
@@ -564,6 +617,7 @@ def build_demo(embed_mode):
            imagebox = gr.Image(type="pil")
            imagebox_2 = gr.Image(type="pil")
            imagebox_3 = gr.Image(type="pil")
+            videobox = gr.Video(label="1 video = 8 frames")
            image_process_mode = gr.Radio(
                ["Crop", "Resize", "Pad", "Default"],
                value="Default",
@@ -841,7 +895,7 @@ def build_demo(embed_mode):
        clear_text_history, [state, prompt_style_btn], [state, chatbot], queue=False
    ).then(
        add_images,
-        [state, imagebox, imagebox_2, imagebox_3, image_process_mode],
+        [state, imagebox, imagebox_2, imagebox_3, videobox, image_process_mode],
        [state],
        queue=False,
    ).then(
@@ -863,7 +917,7 @@ def build_demo(embed_mode):
        clear_text_history, [state, prompt_style_btn], [state, chatbot], queue=False
    ).then(
        add_images,
-        [state, imagebox, imagebox_2, imagebox_3, image_process_mode],
+        [state, imagebox, imagebox_2, imagebox_3, videobox, image_process_mode],
        [state],
        queue=False,
    ).then(
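The UI hunks above wire the new component into the existing event chain: a gr.Video widget is added next to the three image boxes, and its value is passed as an extra input to add_images in both .then() calls. A gr.Video component hands the callback a file path, which is why extract_frames can open it directly with cv2.VideoCapture. The snippet below is an illustrative sketch of that pattern, not the PR's UI; the describe() callback and component layout are placeholders, and it assumes the gradio package is installed.

import gradio as gr


def describe(image, video):
    # `image` is a PIL image (or None); `video` is a file path string (or None).
    parts = []
    if image is not None:
        parts.append("got an image")
    if video is not None:
        parts.append(f"got a video file at {video}")
    return ", ".join(parts) or "nothing uploaded"


with gr.Blocks() as demo:
    imagebox = gr.Image(type="pil")
    videobox = gr.Video(label="1 video = 8 frames")
    output = gr.Textbox()
    btn = gr.Button("Submit")
    # The video box is simply another input to the handler, mirroring how
    # videobox is threaded into add_images above.
    btn.click(describe, [imagebox, videobox], [output])

if __name__ == "__main__":
    demo.launch()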
2 changes: 1 addition & 1 deletion tinychat/utils/constants.py
@@ -3,7 +3,7 @@

 def init():
     global max_seq_len, max_batch_size, llama_multiple_of, mem_efficient_load
-    max_seq_len = 2048
+    max_seq_len = 5120
     max_batch_size = 1
     llama_multiple_of = 256
     mem_efficient_load = False # Whether to load the checkpoint in a layer-wise manner. Activate this if you are facing OOM issues on edge devices (e.g., Jetson Orin).
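Raising max_seq_len from 2048 to 5120 presumably makes room for the visual tokens contributed by the 8 sampled frames on top of the text prompt. The arithmetic below is only a back-of-envelope reading of that choice, not something stated in the PR; the per-frame token count and text budget are assumptions (576 is typical for LLaVA-style encoders at 336 px, but the actual figure depends on the vision tower and projector used here).

# Assumed figures, not from the PR.
TOKENS_PER_FRAME = 576  # assumed visual tokens per frame
NUM_FRAMES = 8          # frames sampled per video
TEXT_BUDGET = 512       # assumed room for the prompt and the reply

# 8 * 576 + 512 = 5120, which would exactly fill the new max_seq_len.
print(NUM_FRAMES * TOKENS_PER_FRAME + TEXT_BUDGET)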