-
Notifications
You must be signed in to change notification settings - Fork 9
/
app.py
139 lines (118 loc) · 6.82 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
'''
@author: Zhigang Jiang
@time: 2022/05/23
@description:
'''
import gradio as gr
import numpy as np
import os
import torch
from PIL import Image
from utils.logger import get_logger
from config.defaults import get_config
from inference import preprocess, run_one_inference
from models.build import build_model
from argparse import Namespace
import gdown
def down_ckpt(model_cfg, ckpt_dir):
    """Download the pre-trained checkpoint for ``model_cfg`` if it is not on disk yet.

    The checkpoint is stored as ``best.pkl`` inside ``ckpt_dir``. Nothing is
    downloaded when that file already exists, and nothing happens at all when
    ``model_cfg`` is not one of the config paths listed below.

    Relies on the module-level ``logger`` being set before it is called.

    Args:
        model_cfg: Config path used to look up the Google Drive file id.
        ckpt_dir: Directory that should contain ``best.pkl``; created on demand.
    """
    # [config path, Google Drive file id] pairs of the published checkpoints.
    known_checkpoints = [
        ['src/config/mp3d.yaml', '1o97oAmd-yEP5bQrM0eAWFPLq27FjUDbh'],
        ['src/config/zind.yaml', '1PzBj-dfDfH_vevgSkRe5kczW0GVl_43I'],
        ['src/config/pano.yaml', '1JoeqcPbm_XBPOi6O9GjjWi3_rtyPZS8m'],
        ['src/config/s2d3d.yaml', '1PfJzcxzUsbwwMal7yTkBClIFgn8IdEzI'],
        ['src/config/ablation_study/full.yaml', '1U16TxUkvZlRwJNaJnq9nAUap-BhCVIha'],
    ]

    for entry in known_checkpoints:
        cfg_path, drive_id = entry
        if cfg_path == model_cfg:
            target = os.path.join(ckpt_dir, 'best.pkl')
            if os.path.exists(target):
                # Already downloaded on a previous run.
                continue
            logger.info(f"Downloading {entry}")
            os.makedirs(ckpt_dir, exist_ok=True)
            gdown.download(f"https://drive.google.com/uc?id={drive_id}", target, False)
def greet(img_path, pre_processing, weight_name, post_processing, visualization, mesh_format, mesh_resolution):
    """Run layout inference on one panorama and return the paths of the result files.

    This is the Gradio callback. It relies on the module-level ``args``,
    ``logger``, ``mp3d_model`` and ``zind_model`` created in the ``__main__``
    section, and it overwrites ``args.pre_processing`` / ``args.post_processing``
    on every call.

    Args:
        img_path: Path of the input panorama image.
        pre_processing: When truthy, the image is passed through ``preprocess``
            and the vanishing-point cache is written under ``src/output``.
        weight_name: ``'mp3d'`` or ``'zind'``; selects the pre-loaded model.
        post_processing: Post-processing method name, forwarded through ``args``.
        visualization: Collection of selected 2d visualizations; checked for
            ``'depth-normal-gradient'`` and ``'2d-floorplan'``.
        mesh_format: Mesh file suffix (including the dot), e.g. ``'.gltf'``.
        mesh_resolution: Mesh resolution; converted with ``int()``.

    Returns:
        A list of five paths, in the order of the interface outputs: the 2d
        prediction image, the 3d mesh (twice, once for the viewer and once for
        the download), the vanishing-point file and the layout json.
        NOTE(review): these files are presumably written by
        ``run_one_inference`` -- confirm against its implementation.

    Raises:
        NotImplementedError: If ``weight_name`` is neither ``'mp3d'`` nor ``'zind'``.
    """
    args.pre_processing = pre_processing
    args.post_processing = post_processing

    if weight_name == 'mp3d':
        model = mp3d_model
    elif weight_name == 'zind':
        model = zind_model
    else:
        logger.error("unknown pre-trained weight name")
        raise NotImplementedError

    img_name = os.path.basename(img_path).split('.')[0]

    # Open the file in a context manager so the handle is released right away,
    # and convert to RGB before resizing: this always yields a (512, 1024, 3)
    # array. Slicing `[..., :3]` on the raw array only works for RGB/RGBA
    # inputs; for grayscale or palette images the array is 2-D and the slice
    # would cut image columns instead of channels.
    with Image.open(img_path) as pil_img:
        rgb_img = pil_img.convert('RGB').resize((1024, 512), Image.Resampling.BICUBIC)
        img = np.array(rgb_img)

    # Without pre-processing the bundled default vanishing-point file is returned.
    vp_cache_path = 'src/demo/default_vp.txt'
    if args.pre_processing:
        vp_cache_path = os.path.join('src/output', f'{img_name}_vp.txt')
        logger.info("pre-processing ...")
        img, vp = preprocess(img, vp_cache_path=vp_cache_path)

    # Scale 8-bit pixel values to float32 in [0, 1].
    img = (img / 255.0).astype(np.float32)
    run_one_inference(img, model, args, img_name,
                      logger=logger, show=False,
                      show_depth='depth-normal-gradient' in visualization,
                      show_floorplan='2d-floorplan' in visualization,
                      mesh_format=mesh_format, mesh_resolution=int(mesh_resolution))

    # The mesh path appears twice: once for the 3d viewer, once for the file download.
    return [os.path.join(args.output_dir, f"{img_name}_pred.png"),
            os.path.join(args.output_dir, f"{img_name}_3d{mesh_format}"),
            os.path.join(args.output_dir, f"{img_name}_3d{mesh_format}"),
            vp_cache_path,
            os.path.join(args.output_dir, f"{img_name}_pred.json")]
def get_model(args):
    """Build the model described by ``args.cfg``.

    Loads the config via ``get_config``, calls ``down_ckpt`` for the
    corresponding checkpoint directory, and falls back to the CPU when CUDA is
    requested (by ``args.device`` or the config) but not available. In that
    case both ``args.device`` and ``config.TRAIN.DEVICE`` are set to ``"cpu"``.

    Relies on the module-level ``logger``.

    Args:
        args: Namespace with at least ``cfg`` and ``device`` attributes;
            ``device`` may be modified in place.

    Returns:
        The first element of the 4-tuple returned by ``build_model``.
    """
    cfg = get_config(args)
    down_ckpt(args.cfg, cfg.CKPT.DIR)

    cuda_requested = 'cuda' in args.device or 'cuda' in cfg.TRAIN.DEVICE
    if cuda_requested and not torch.cuda.is_available():
        logger.info(f'The {args.device} is not available, will use cpu ...')
        # The config is frozen; unlock it just long enough to switch the device.
        cfg.defrost()
        args.device = "cpu"
        cfg.TRAIN.DEVICE = "cpu"
        cfg.freeze()

    built = build_model(cfg, logger)
    model, _, _, _ = built
    return model
if __name__ == '__main__':
    # `logger`, `args`, `mp3d_model` and `zind_model` are module globals that
    # greet(), get_model() and down_ckpt() read at call time.
    logger = get_logger()
    args = Namespace(device='cuda', output_dir='src/output', visualize_3d=False, output_3d=True)
    os.makedirs(args.output_dir, exist_ok=True)

    # Build both models up front; args.cfg selects the config for each call.
    args.cfg = 'src/config/mp3d.yaml'
    mp3d_model = get_model(args)
    args.cfg = 'src/config/zind.yaml'
    zind_model = get_model(args)

    description = (
        "This demo of the project "
        "<a href='https://github.com/zhigangjiang/LGT-Net' target='_blank'>LGT-Net</a>. "
        "It uses the Geometry-Aware Transformer Network to predict the 3d room layout of an rgb panorama."
    )

    # Input widgets, in the order of greet()'s parameters.
    input_widgets = [
        gr.Image(type='filepath', label='input rgb panorama', value='src/demo/pano_demo1.png'),
        gr.Checkbox(label='pre-processing', value=True),
        gr.Radio(['mp3d', 'zind'], label='pre-trained weight', value='mp3d'),
        gr.Radio(['manhattan', 'atalanta', 'original'], label='post-processing method', value='manhattan'),
        gr.CheckboxGroup(['depth-normal-gradient', '2d-floorplan'],
                         label='2d-visualization',
                         value=['depth-normal-gradient', '2d-floorplan']),
        gr.Radio(['.gltf', '.obj', '.glb'], label='output format of 3d mesh', value='.gltf'),
        gr.Radio(['128', '256', '512', '1024'], label='output resolution of 3d mesh', value='256'),
    ]

    # Output widgets, in the order of the list returned by greet().
    output_widgets = [
        gr.Image(label='predicted result 2d-visualization', type='filepath'),
        gr.Model3D(label='3d mesh reconstruction', clear_color=[1.0, 1.0, 1.0, 1.0]),
        gr.File(label='3d mesh file'),
        gr.File(label='vanishing point information'),
        gr.File(label='layout json'),
    ]

    # (image stem, pre-processing, weight name, post-processing) per example;
    # every example uses both 2d visualizations, '.gltf' and resolution '256'.
    example_specs = [
        ('pano_demo1', True, 'mp3d', 'manhattan'),
        ('mp3d_demo1', False, 'mp3d', 'manhattan'),
        ('mp3d_demo2', False, 'mp3d', 'manhattan'),
        ('mp3d_demo3', False, 'mp3d', 'manhattan'),
        ('zind_demo1', True, 'zind', 'manhattan'),
        ('zind_demo2', False, 'zind', 'atalanta'),
        ('zind_demo3', True, 'zind', 'manhattan'),
        ('other_demo1', False, 'mp3d', 'manhattan'),
        ('other_demo2', True, 'mp3d', 'manhattan'),
    ]
    example_rows = [
        [f'src/demo/{stem}.png', use_pre, weights, post_method,
         ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256']
        for stem, use_pre, weights, post_method in example_specs
    ]

    demo = gr.Interface(
        fn=greet,
        inputs=input_widgets,
        outputs=output_widgets,
        examples=example_rows,
        title='LGT-Net',
        allow_flagging="never",
        cache_examples=False,
        description=description,
    )
    demo.launch(debug=True, enable_queue=False)