Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gradio demo for groma #36

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 195 additions & 0 deletions groma/serve/gradio_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
import os
import copy
import torch
import random
import requests
from io import BytesIO
from PIL import Image, ImageDraw
import gradio as gr
from transformers.image_transforms import center_to_corners_format
from transformers import AutoTokenizer, AutoImageProcessor, BitsAndBytesConfig
from groma.utils import disable_torch_init
from groma.model.groma import GromaModel
from groma.constants import DEFAULT_TOKENS
from groma.data.conversation import conv_templates
import argparse


def load_model(model_name, quant_type):
# Model
disable_torch_init()
model_name = os.path.expanduser(model_name)
vis_processor = AutoImageProcessor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

kwargs = {}
if quant_type == 'fp16':
kwargs['torch_dtype'] = torch.float16
elif quant_type == '8bit':
kwargs['load_in_8bit'] = True
elif quant_type == '4bit':
int4_quant_cfg = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_storage=torch.uint8,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type='nf4'
)
kwargs = {'quantization_config': int4_quant_cfg}

if quant_type == '8bit' or quant_type == '4bit':
model = GromaModel.from_pretrained(model_name, **kwargs)
else:
model = GromaModel.from_pretrained(model_name, **kwargs).cuda()
print('----------Model loaded to GPU')

model.init_special_token_id(tokenizer)
return model, vis_processor, tokenizer


def parse_args():
parser = argparse.ArgumentParser(description="Demo")
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str,
default="FoundationVision/groma-7b-finetune")
parser.add_argument("--image_dir", type=str, default=None)
parser.add_argument("--image_file", type=str, default=None)
parser.add_argument("--output_dir", type=str, default='output')
parser.add_argument("--query", type=str, default=None)
parser.add_argument("--quant_type", type=str, default='none')

args = parser.parse_args()
return args


def load_image(image_file):
if image_file.startswith('http') or image_file.startswith('https'):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
image = Image.open(image_file).convert('RGB')
return image


def generate_random_color():
"""Generate a random RGB color."""
return tuple(random.randint(0, 255) for _ in range(3))


def draw_box(box, image, color):
"""Draw a bounding box with a given color."""
w, h = image.size
box = [box[0] * w, box[1] * h, box[2] * w, box[3] * h]
draw = ImageDraw.Draw(image)
draw.rectangle(box, outline=color, width=3)
return image


def eval_model(model, vis_processor, tokenizer, quant_type, image_file, query):

conversations = []
instruct = "Here is an image with region crops from it. "
instruct += "Image: {}. ".format(DEFAULT_TOKENS['image'])
instruct += "Regions: {}.".format(DEFAULT_TOKENS['region'])
answer = 'Thank you for the image! How can I assist you with it?'
conversations.append((conv_templates['llava'].roles[0], instruct))
conversations.append((conv_templates['llava'].roles[1], answer))
conversations.append((conv_templates['llava'].roles[0], query))
conversations.append((conv_templates['llava'].roles[1], ''))
prompt = conv_templates['llava'].get_prompt(conversations)

inputs = tokenizer([prompt])
input_ids = torch.as_tensor(inputs.input_ids).cuda()

raw_image = load_image(image_file)
raw_image = raw_image.resize((448, 448))
image = vis_processor.preprocess(raw_image, return_tensors='pt')[
'pixel_values'].to('cuda')

with torch.inference_mode():
with torch.autocast(device_type="cuda"):
outputs = model.generate(
input_ids,
images=image,
use_cache=True,
do_sample=False,
max_new_tokens=1024,
return_dict_in_generate=True,
output_hidden_states=True,
generation_config=model.generation_config,
)

output_ids = outputs.sequences
input_token_len = input_ids.shape[1]
pred_boxes = outputs.hidden_states[0][-1]['pred_boxes'][0].cpu()
pred_boxes = center_to_corners_format(pred_boxes)

img_copy = copy.deepcopy(raw_image)
box_idx_token_ids = model.box_idx_token_ids
selected_box_inds = [box_idx_token_ids.index(
id) for id in output_ids[0] if id in box_idx_token_ids]
selected_box_inds = [x for x in selected_box_inds if x < len(pred_boxes)]

for i, box in enumerate(pred_boxes[selected_box_inds, :]):
color = generate_random_color()
img_copy = draw_box(box, img_copy, color)

n_diff_input_output = (
input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(
f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')

outputs = tokenizer.batch_decode(
output_ids[:, input_token_len:], skip_special_tokens=False)[0]
outputs = outputs.strip()

return img_copy, outputs


args = parse_args()
model, vis_processor, tokenizer = load_model(args.model_name, args.quant_type)


def chatbot_interface(image, query, model_name="checkpoints/groma-finetune/", quant_type='none'):
img_result, response = eval_model(
model, vis_processor, tokenizer, quant_type, image, query)
return img_result, response


title = """<h1 align="center"> Groma demo </h1>"""
description = '''
Welcome to Groma Demo!
Upload an image, ask a question, and receive the models response with bounding boxes on the relevant image areas.
'''


iface = gr.Interface(
fn=chatbot_interface,

inputs=[
gr.inputs.Image(type="filepath", label="Upload Image"),
gr.inputs.Textbox(lines=2, placeholder="Query here.", label="Query"),
gr.inputs.Dropdown(choices=[
'none', 'fp16', '8bit', '4bit'], default="none", label="Quantization Type")
],
outputs=[
gr.outputs.Image(type="pil", label="Processed Image"),
gr.outputs.Textbox(label="Model Response")
],
title=title,
description=description,
examples=[
["examples_v2/1.jpg", "[vqa]please locate female wearing a green T-shirt?", "none"],
["examples_v2/2.jpg",
"[grounding]Please detect all the males and describe their activities.", "none"],
["examples_v2/4.jpg", "[grounding]locate females walking in the image / [grounding]locate the person sitting in the image", "none"],
["examples_v2/5.jpg", "locate the female sitting in wheelchair.", "fp16"]
]
)

with gr.Blocks() as demo:
gr.Markdown("""![](file/examples_v2/logo.jpg)""")
iface.render()

demo.launch(share=True, enable_queue=True)