# run_inference.py
# Forked from ClementPinard/SfmLearner-Pytorch.
import torch
from imageio import imread, imsave
from PIL import Image
import numpy as np
from path import Path
import argparse
from tqdm import tqdm
from models import DispNetS
from utils import tensor2array
# Command-line interface: which maps to save, where the weights and images
# live, and the resolution images are resized to before inference.
parser = argparse.ArgumentParser(
    description=(
        "Inference script for DispNet learned with "
        "Structure from Motion Learner inference on KITTI and CityScapes Dataset"
    ),
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# What to save (at least one of the two is required by main()).
parser.add_argument(
    "--output-disp", action="store_true", help="save disparity img"
)
parser.add_argument(
    "--output-depth", action="store_true", help="save depth img"
)

# Network weights.
parser.add_argument(
    "--pretrained", required=True, type=str, help="pretrained DispNet path"
)

# Input resolution handling.
parser.add_argument(
    "--img-height", default=128, type=int, help="Image height"
)
parser.add_argument(
    "--img-width", default=416, type=int, help="Image width"
)
parser.add_argument(
    "--no-resize", action="store_true", help="no resizing is done"
)

# Where the inputs come from and where the results go.
parser.add_argument(
    "--dataset-list", default=None, type=str, help="Dataset list file"
)
parser.add_argument(
    "--dataset-dir", default=".", type=str, help="Dataset directory"
)
parser.add_argument(
    "--output-dir", default="output", type=str, help="Output directory"
)
parser.add_argument(
    "--img-exts",
    default=["png", "jpg", "bmp"],
    nargs="*",
    type=str,
    help="images extensions to glob",
)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@torch.no_grad()
def main():
    """Run a pretrained DispNetS over a set of images and save the results.

    Command-line arguments are read from the module-level ``parser``. For
    every input image the network output is saved as a colour-mapped
    disparity image (``--output-disp``) and/or a colour-mapped depth image
    (``--output-depth``) inside ``--output-dir``. If neither flag is given,
    a message is printed and the function returns without doing anything.

    Input images come either from the lines of ``--dataset-list`` (paths
    relative to ``--dataset-dir``) or, when no list is given, from a
    recursive search of ``--dataset-dir`` for the extensions in
    ``--img-exts``.
    """
    args = parser.parse_args()
    if not (args.output_disp or args.output_depth):
        print('You must at least output one value !')
        return

    disp_net = DispNetS().to(device)
    # map_location lets a checkpoint that was saved on a GPU be loaded on a
    # CPU-only machine (the script explicitly supports the CPU fallback).
    weights = torch.load(args.pretrained, map_location=device)
    disp_net.load_state_dict(weights['state_dict'])
    disp_net.eval()

    dataset_dir = Path(args.dataset_dir)
    output_dir = Path(args.output_dir)
    output_dir.makedirs_p()

    if args.dataset_list is not None:
        with open(args.dataset_list, 'r') as f:
            test_files = [dataset_dir/file for file in f.read().splitlines()]
    else:
        # Same ordering as before (grouped by extension), without the
        # repeated list copying of sum(..., []).
        test_files = [
            found
            for ext in args.img_exts
            for found in dataset_dir.walkfiles('*.{}'.format(ext))
        ]

    print('{} files to test'.format(len(test_files)))

    for file in tqdm(test_files):
        img = imread(file)

        # The unpacking requires a 3-D (height, width, channels) array.
        h, w, _ = img.shape
        if (not args.no_resize) and (h != args.img_height or w != args.img_width):
            # PIL's Image.resize takes (width, height); PIL images have no
            # `imresize` method (that was the removed scipy.misc helper).
            resized = Image.fromarray(img).resize((args.img_width, args.img_height))
            img = np.array(resized)
        # HWC -> CHW, then add a leading batch dimension of size 1.
        img = np.transpose(img, (2, 0, 1))

        tensor_img = torch.from_numpy(img.astype(np.float32)).unsqueeze(0)
        # Rescale pixel values from 0..255 to -1..1 before feeding the network.
        tensor_img = ((tensor_img/255 - 0.5)/0.5).to(device)

        # Index 0 selects the single item of the size-1 batch.
        output = disp_net(tensor_img)[0]

        # Build a flat output name from the path relative to the dataset
        # directory: components after the first splitall() entry, joined
        # with '-'.
        file_path, file_ext = file.relpath(args.dataset_dir).splitext()
        print(file_path)
        print(file_path.splitall())
        file_name = '-'.join(file_path.splitall()[1:])
        print(file_name)

        # NOTE(review): tensor2array is assumed to return a channel-first
        # float array scaled to [0, 1]; the code multiplies by 255 and
        # transposes to HWC before saving — confirm against utils.
        if args.output_disp:
            disp = (255*tensor2array(output, max_value=None, colormap='bone')).astype(np.uint8)
            imsave(output_dir/'{}_disp{}'.format(file_name, file_ext), np.transpose(disp, (1, 2, 0)))
        if args.output_depth:
            # Depth is the reciprocal of the network output.
            depth = 1/output
            depth = (255*tensor2array(depth, max_value=10, colormap='rainbow')).astype(np.uint8)
            imsave(output_dir/'{}_depth{}'.format(file_name, file_ext), np.transpose(depth, (1, 2, 0)))
# Run inference only when executed as a script, not when imported.
if __name__ == "__main__":
    main()