-
Notifications
You must be signed in to change notification settings - Fork 1
/
make_load.py
152 lines (133 loc) · 5.61 KB
/
make_load.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from pathlib import Path
import argparse
import random
import numpy as np
import matplotlib.cm as cm
import torch
import torch.nn as nn
from torch.autograd import Variable
import os
import torch.multiprocessing
from tqdm import tqdm
import cv2
from scipy.spatial.distance import cdist
from models.utils import (compute_pose_error, compute_epipolar_error,
estimate_pose, make_matching_plot,
error_colormap, AverageTimer, pose_auc, read_image,
rotate_intrinsics, rotate_pose_inplane,
scale_intrinsics, read_image_modified, frame2tensor)
from models.matching import Matching
from models.matchingsuperglue import Matching_ori
from sjlee.loss import loss_superglue
# Keep autograd enabled globally for everything this module runs.
torch.set_grad_enabled(True)
# Share tensors between processes through the file system instead of file
# descriptors, which avoids exhausting the per-process descriptor limit when
# many tensors are shared (see torch.multiprocessing "sharing strategies").
torch.multiprocessing.set_sharing_strategy('file_system')

# Command-line interface. Besides the SuperPoint/SuperGlue matching and
# evaluation options, it also carries training options
# (--learning_rate, --batch_size, --train_path, --epoch).
parser = argparse.ArgumentParser(
    description='Image pair matching and pose evaluation with SuperGlue',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# --- output / evaluation mode -------------------------------------------
parser.add_argument(
    '--viz', action='store_true',
    help='Visualize the matches and dump the plots')
parser.add_argument(
    '--eval', action='store_true',
    help='Perform the evaluation'
         ' (requires ground truth pose and intrinsics)')

# --- model configuration --------------------------------------------------
# A list (not a set) keeps the order shown in --help/usage deterministic;
# set iteration order for strings varies between runs with hash randomization.
parser.add_argument(
    '--superglue', choices=['indoor', 'outdoor'], default='indoor',
    help='SuperGlue weights')
parser.add_argument(
    '--max_keypoints', type=int, default=1023,
    help='Maximum number of keypoints detected by Superpoint'
         ' (\'-1\' keeps all keypoints)')
parser.add_argument(
    '--keypoint_threshold', type=float, default=0.005,
    help='SuperPoint keypoint detector confidence threshold')
parser.add_argument(
    '--nms_radius', type=int, default=4,
    help='SuperPoint Non Maximum Suppression (NMS) radius'
         ' (Must be positive)')
parser.add_argument(
    '--sinkhorn_iterations', type=int, default=20,
    help='Number of Sinkhorn iterations performed by SuperGlue')
parser.add_argument(
    '--match_threshold', type=float, default=0.2,
    help='SuperGlue match threshold')

# --- image preprocessing --------------------------------------------------
parser.add_argument(
    '--resize', type=int, nargs='+', default=[640, 480],
    help='Resize the input image before running inference. If two numbers, '
         'resize to the exact dimensions, if one number, resize the max '
         'dimension, if -1, do not resize')
parser.add_argument(
    '--resize_float', action='store_true',
    help='Resize the image after casting uint8 to float')

# --- caching / visualization ----------------------------------------------
parser.add_argument(
    '--cache', action='store_true',
    help='Skip the pair if output .npz files are already found')
parser.add_argument(
    '--show_keypoints', action='store_true',
    help='Plot the keypoints in addition to the matches')
parser.add_argument(
    '--fast_viz', action='store_true',
    help='Use faster image visualization based on OpenCV instead of Matplotlib')
parser.add_argument(
    '--viz_extension', type=str, default='png', choices=['png', 'pdf'],
    help='Visualization file extension. Use pdf for highest-quality.')
parser.add_argument(
    '--opencv_display', action='store_true',
    help='Visualize via OpenCV before saving output images')

# --- evaluation data --------------------------------------------------------
parser.add_argument(
    '--eval_pairs_list', type=str, default='assets/scannet_sample_pairs_with_gt.txt',
    help='Path to the list of image pairs for evaluation')
parser.add_argument(
    '--shuffle', action='store_true',
    help='Shuffle ordering of pairs before processing')
parser.add_argument(
    '--max_length', type=int, default=-1,
    help='Maximum number of pairs to evaluate')
parser.add_argument(
    '--eval_input_dir', type=str, default='assets/scannet_sample_images/',
    help='Path to the directory that contains the images')
parser.add_argument(
    '--eval_output_dir', type=str, default='test_matches',
    help='Path to the directory in which the .npz results and, optionally, '
         'visualizations are written')

# --- training ---------------------------------------------------------------
parser.add_argument(
    '--learning_rate', type=float, default=0.0001,
    help='Learning rate')
parser.add_argument(
    '--batch_size', type=int, default=1,
    help='batch_size')
# NOTE(review): machine-specific absolute default; override on other hosts.
parser.add_argument(
    '--train_path', type=str, default='/home/cvlab09/projects/seungjun_an/dataset/train2014/',
    help='Path to the directory of training imgs.')
parser.add_argument(
    '--epoch', type=int, default=1,
    help='Number of epochs')
if __name__ == '__main__':
    # Convert a checkpoint that stores a whole pickled model into a file that
    # holds only its state_dict.
    import inspect  # used to probe torch.load's signature below

    opt = parser.parse_args()
    print(opt)

    # make sure the flags are properly used
    assert not (opt.opencv_display and not opt.viz), 'Must use --viz with --opencv_display'
    assert not (opt.opencv_display and not opt.fast_viz), 'Cannot use --opencv_display without --fast_viz'
    assert not (opt.fast_viz and not opt.viz), 'Must use --viz with --fast_viz'
    assert not (opt.fast_viz and opt.viz_extension == 'pdf'), 'Cannot use pdf extension with --fast_viz'

    # store viz results
    eval_output_dir = Path(opt.eval_output_dir)
    eval_output_dir.mkdir(exist_ok=True, parents=True)
    print('Will write visualization images to',
          'directory \"{}\"'.format(eval_output_dir))

    # NOTE(review): machine-specific input path; output goes to the CWD.
    checkpoint_path = '/home/cvlab09/projects/seungjun_an/superglue_test/model_epoch_1.pth'
    state_dict_path = 'model_state_dict_epoch_1.pth'

    # The checkpoint is fully unpickled (arbitrary code can run), so only load
    # trusted files. PyTorch >= 2.6 defaults torch.load to weights_only=True,
    # which refuses pickled nn.Module objects; opt out explicitly when the
    # keyword exists (it was added in 1.13; older releases would reject it).
    load_kwargs = {}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    matching = torch.load(checkpoint_path, **load_kwargs)

    # The original flow calls .state_dict() on the loaded object, so it must be
    # a module; fail with a clear message otherwise (e.g. if the file already
    # holds a plain state_dict).
    if not isinstance(matching, torch.nn.Module):
        raise TypeError(
            'Expected {} to contain a pickled torch.nn.Module, got {}'.format(
                checkpoint_path, type(matching).__name__))
    torch.save(matching.state_dict(), state_dict_path)