# inference_gradio.py
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from my_model import unet_2d_condition
import json
from PIL import Image
import hydra
import os
import numpy as np
import random
from gradio.components.image_editor import Brush
import gradio as gr
import datetime
from inference import inference
from utils import concat_images, setup_logger, colors, draw_traces
import cv2


def setup_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN kernels for reproducible runs."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

@hydra.main(version_base=None, config_path="conf", config_name="base_config")
def main(cfg):
setup_seed(cfg.inference.rand_seed)
# build and load model
with open(cfg.general.unet_config) as f:
unet_config = json.load(f)
    # Load the customized UNet weights from the pretrained checkpoint's "unet" subfolder.
    unet = unet_2d_condition.UNet2DConditionModel(**unet_config).from_pretrained(
        cfg.general.model_path, subfolder="unet"
    )
tokenizer = CLIPTokenizer.from_pretrained(cfg.general.model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(cfg.general.model_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(cfg.general.model_path, subfolder="vae")
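    # Run on the GPU when one is available; diffusion inference is very slow on CPU.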
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet.to(device)
text_encoder.to(device)
vae.to(device)
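    # Optional idea (an assumption, not part of the original config): on GPU you could halve memory
    # with `unet.half(); text_encoder.half(); vae.half()`, provided inference() handles fp16 inputs.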

    def progress(img_dict_np, prompt, phrase):
        # One timestamped output directory per request, so runs do not overwrite each other.
        cfg.general.save_path = './traces_output_' + datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        os.makedirs(cfg.general.save_path, exist_ok=True)
img_np = img_dict_np['composite']
img_np = cv2.resize(img_np, (512,512), interpolation=cv2.INTER_NEAREST)
maps = []
pil = Image.new('RGB', (512, 512), color=(167, 179, 195))
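        # Each brush color corresponds to one phrase: build a binary 512x512 mask per color actually
        # present in the sketch, and redraw the traces onto a neutral canvas for logging.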
for i in range(len(colors)):
non_zero_indices = np.where(np.all(img_np == colors[i], axis=-1))
if non_zero_indices[0].size == 0:
continue
bk = np.zeros((512, 512), dtype=np.uint8)
bk[non_zero_indices] = 1
maps.append(bk)
for y,x in zip(non_zero_indices[0],non_zero_indices[1]):
pil.putpixel((x, y), colors[i])
pil.save(os.path.join(cfg.general.save_path,"traces.jpg"))
# Prepare examples
        examples = {
            "prompt": prompt,
            "phrases": phrase,
            "save_path": cfg.general.save_path,
        }
# Prepare the save path
if not os.path.exists(cfg.general.save_path):
os.makedirs(cfg.general.save_path)
logger = setup_logger(cfg.general.save_path, __name__)
logger.info(cfg)
pil_images = inference(device, unet, vae, tokenizer, text_encoder, examples['prompt'], maps, examples['phrases'], cfg, logger)
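        # Append an extra image showing the input traces drawn over the first generated result.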
pil_images.append(draw_traces(pil_images[0].copy(),maps, examples['phrases']))
for i,img in enumerate(pil_images):
image_path = os.path.join(cfg.general.save_path, 'example_{}.png'.format(i))
img.save(image_path)
horizontal_concatenated = concat_images(pil_images, examples['prompt'])
return horizontal_concatenated
    # White canvas for the sketchpad; note that np.zeros(...) * 255 would stay all-black.
    white_image_np = np.ones((512, 512, 3), dtype=np.uint8) * 255
iface = gr.Interface(
fn=progress,
inputs=[
            gr.Sketchpad(
                value=white_image_np,
                height="80%",
                width="80%",
                type='numpy',
                brush=Brush(
                    colors=[
                        "#90EE90",
                        "#FFA500",
                        "#FF7F50",
                        "#FF0000",
                        "#0000FF"
                    ],
                    # "fixed" limits the brush palette to the trace colors listed above
                    # (the original "selecte" is not a valid Brush color_mode).
                    color_mode="fixed",
                    default_size=3
                ),
                image_mode='RGB'
            ),
gr.Textbox(label='prompt'),
gr.Textbox(label='phrase')
],
outputs='image',
title='Traces-guidance'
)
iface.launch()


if __name__ == "__main__":
    main()
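
# Usage sketch (assumes the Hydra defaults in conf/base_config.yaml; override keys as needed), e.g.:
#   python inference_gradio.py general.model_path=/path/to/stable-diffusion inference.rand_seed=42
# The Gradio demo then serves locally: draw traces with the brush colors above, enter a prompt and
# the matching phrases, and outputs are written to a timestamped ./traces_output_* directory.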