-
Notifications
You must be signed in to change notification settings - Fork 62
/
app.py
65 lines (54 loc) · 2.14 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import argparse
import glob
import gradio as gr
import numpy as np
from inference import get_mask
from train import AnimeSegmentation, net_names
def rmbg_fn(img, img_size):
    """Remove the background of ``img`` using the currently loaded model.

    Args:
        img: image array from the ``gr.Image`` input component (presumably an
            RGB ``uint8`` array of shape (H, W, 3) — TODO confirm), or ``None``
            when nothing has been uploaded.
        img_size: inference size taken from the slider; converted to ``int``.

    Returns:
        ``(mask, img)``: ``mask`` is the predicted mask scaled to 0-255 and
        repeated 3 times along axis 2; ``img`` is the input blended towards
        white (255) where the mask is low, with the 0-255 mask concatenated as
        an extra channel along axis 2 (shown by the RGBA output component).

    Raises:
        gr.Error: if no model has been loaded yet or no image was supplied,
            so the UI shows a readable message instead of a raw traceback.
    """
    if model is None:
        raise gr.Error("No model loaded. Open 'Model option' and click 'Load' first.")
    if img is None:
        raise gr.Error("Please upload an input image.")
    # NOTE(review): get_mask presumably returns a float mask in [0, 1] with a
    # trailing channel axis (it is concatenated/repeated along axis 2 below).
    mask = get_mask(model, img, False, int(img_size))
    # Composite the foreground onto a white (255) background.
    img = (mask * img + 255 * (1 - mask)).astype(np.uint8)
    mask = (mask * 255).astype(np.uint8)
    img = np.concatenate([img, mask], axis=2, dtype=np.uint8)
    mask = mask.repeat(3, axis=2)
    return mask, img
def load_model(path, net_name, img_size):
    """Load a segmentation checkpoint and make it the active global ``model``.

    Args:
        path: checkpoint path chosen in the "model" dropdown (``None`` or empty
            when the user has not picked one).
        net_name: network architecture name chosen in the "model type" dropdown.
        img_size: image size from the slider; converted to ``int``.

    Returns:
        The status string ``"success"``, displayed in the message textbox.

    Raises:
        gr.Error: if no checkpoint path was selected.
    """
    global model
    if not path:
        raise gr.Error("Please select a model checkpoint first (click 'Get Models').")
    # map_location="cpu" is passed through to try_load (presumably forwarded to
    # torch.load) so the checkpoint is loaded onto the CPU.
    loaded = AnimeSegmentation.try_load(
        net_name=net_name, img_size=int(img_size), ckpt_path=path, map_location="cpu"
    )
    loaded.eval()
    # Only replace the active model once loading and eval() have succeeded.
    model = loaded
    return "success"
def get_model_path():
    """Refresh the checkpoint dropdown with every ``*.ckpt`` file found.

    Searches recursively from the current working directory and returns a
    Gradio update that replaces the dropdown's choices with the sorted paths.
    """
    model_paths = sorted(glob.glob("**/*.ckpt", recursive=True))
    # gr.update works on both Gradio 3.x and 4.x. The per-component
    # `Dropdown.update` classmethod was removed in Gradio 4. gr.update also
    # does not depend on the script-level `model_path_input` global.
    return gr.update(choices=model_paths)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=6006, help="gradio server port,")
    opt = parser.parse_args()
    # Module-level global: read by rmbg_fn and replaced by load_model.
    model = None
    app = gr.Blocks()
    with app:
        with gr.Accordion(label="Model option", open=False):
            load_model_path_btn = gr.Button("Get Models")
            model_path_input = gr.Dropdown(label="model")
            model_type = gr.Dropdown(
                label="model type",
                value="isnet_is",
                choices=net_names,
            )
            # An inference size of 0 is not a usable resolution, so the slider
            # starts at one step (32) and stays on multiples of 32.
            model_image_size = gr.Slider(
                label="image size", value=1024, minimum=32, maximum=1280, step=32
            )
            load_model_path_btn.click(get_model_path, [], model_path_input)
            load_model_btn = gr.Button("Load")
            model_msg = gr.Textbox()
            load_model_btn.click(
                load_model, [model_path_input, model_type, model_image_size], model_msg
            )
        input_img = gr.Image(label="input image")
        run_btn = gr.Button(variant="primary")
        with gr.Row():
            output_mask = gr.Image(label="mask")
            output_img = gr.Image(label="result", image_mode="RGBA")
        run_btn.click(rmbg_fn, [input_img, model_image_size], [output_mask, output_img])
    app.launch(server_port=opt.port)