"""统一所有方法,对一对图片进行配准测试
Usage:
$ python main_multi_sopatch.py -m
log-dir:
log/multirun
"""
import glob
import json
import logging
import time

import cv2  # type: ignore
import hydra
import numpy as np
from tqdm import tqdm

from components import load_component

# from lib.config import Config
from lib.rootpath import rootPath
from lib.utils import corr_MMA, pix2pix_RMSE
from utils.evaluation_utils import estimate_homo


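# Config fields consumed below (inferred from main()):
#   method.name, method.extractor, method.matcher, method.ransac  (each group with a .name)
#   dataset.dataset_path, dataset.name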
@hydra.main(
    version_base=None,
    config_path="conf",
    config_name="config_multi_method_ablation.yaml",
    # config_name="test",
)
def main(config):
    log = logging.getLogger(__name__)

    # Collect the test image paths; sorting keeps each optical image paired
    # with the SAR image of the same name.
    method = config.method
    dataset = config.dataset.dataset_path
    imgfiles1 = sorted(glob.glob(str(rootPath / dataset) + r"/test/opt/*.png"))
    imgfiles2 = sorted(glob.glob(str(rootPath / dataset) + r"/test/sar/*.png"))

    # %% load components
    extractor = load_component("extractor", method.extractor.name, method.extractor)
    matcher = load_component("matcher", method.matcher.name, method.matcher)
    ransac = load_component("ransac", method.ransac.name, method.ransac)
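
    # As used below, the loaded components are duck-typed:
    #   extractor.run(img_path)              -> (keypoints, descriptors)
    #   extractor.run(img1_path, img2_path)  -> (corr1, corr2)  (detector-free methods such as LoFTR)
    #   matcher.run({"x1", "x2", "desc1", "desc2", "size1", "size2"}) -> (corr1, corr2)
    #   ransac.run(corr1, corr2)             -> (H_pred, inlier_corr1, inlier_corr2)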

    # %% evaluation metrics
    # average number of extracted keypoints per image pair
    mean_extract_nums = []
    # average number of putative matches per image pair
    mean_match_nums = []
    # average number of matches left after outlier removal
    mean_ransac_nums = []
    # MMA of the putative matches
    m_corr_mma = {}
    m_ransac_mma = {}
    # homography accuracy evaluated on corner points, same protocol as MMA
    m_homo_mma = {}
    # dists_homo = []
    # failure counts: [extraction failed, matching failed, outlier removal failed, final matching failed]
    failed_nums = [0, 0, 0, 0]
    # metrics at the 3-pixel threshold
    mRMSE = []
    mNCM = []
    mCMR = []
    corr_pair_num = 0
    homo_pair_num = 0
    success_num = 0
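
    # The MMA dicts are presumably keyed by pixel threshold and hold, for each
    # threshold, the fraction of matches whose error falls below it; per-pair
    # values are summed during the loop and averaged over the dataset afterwards.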

    # %% run over all image pairs
    start_time = time.perf_counter()
    for i in tqdm(range(len(imgfiles1))):
        img1_path = imgfiles1[i]
        img2_path = imgfiles2[i]
        img1, img2 = cv2.imread(img1_path), cv2.imread(img2_path)
        # image sizes as (width, height)
        size1, size2 = np.flip(np.asarray(img1.shape[:2])), np.flip(
            np.asarray(img2.shape[:2])
        )

        # %% extractor: keypoints and descriptors
        if method.name == "LoFTR":
            # detector-free method: the extractor returns correspondences directly
            corr1, corr2 = extractor.run(img1_path, img2_path)
            mean_extract_nums.append(np.mean([len(corr1), len(corr2)], dtype=int))
        else:
            kpt1, desc1 = extractor.run(img1_path)
            kpt2, desc2 = extractor.run(img2_path)
            if len(kpt1) == 0 or len(kpt2) == 0:
                failed_nums[0] += 1
                continue
            mean_extract_nums.append(np.mean([len(kpt1), len(kpt2)], dtype=int))

            # %% matcher
            test_data = {
                "x1": kpt1,
                "x2": kpt2,
                "desc1": desc1,
                "desc2": desc2,
                "size1": size1,
                "size2": size2,
            }
            corr1, corr2 = matcher.run(test_data)

        # too few putative matches: count as a matching failure
        if len(corr1) <= 4 or len(corr2) <= 4:
            failed_nums[1] += 1
            continue
        mean_match_nums.append(np.mean([len(corr1), len(corr2)], dtype=int))

        # accumulate MMA of the putative matches
        corr_mma = corr_MMA(corr1, corr2)
        corr_pair_num += 1
        for key in corr_mma:
            m_corr_mma[key] = m_corr_mma.get(key, 0) + corr_mma[key]

        # %% ransac: estimate the homography and reject outliers
        H_pred, corr1, corr2 = ransac.run(corr1, corr2)
        if len(corr1) <= 4 or len(corr2) <= 4 or H_pred is None:
            failed_nums[2] += 1
            continue
        mean_ransac_nums.append(np.mean([len(corr1), len(corr2)], dtype=int))

        # %% evaluate the estimated homography (per-threshold accuracy dict)
        homo_mma = estimate_homo(img1, H_pred)
        homo_pair_num += 1
        for key in homo_mma:
            m_homo_mma[key] = m_homo_mma.get(key, 0) + homo_mma[key]
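
        # NCM/CMR/RMSE below are presumably the usual registration metrics:
        # NCM is the number of correct matches (error under the 3-pixel threshold),
        # CMR the ratio of correct matches to retained matches, and RMSE the
        # root-mean-square error of the correct matches.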

        # %% pixel-level evaluation of the retained matches
        RMSE, NCM, CMR, bool_list, ransac_mma = pix2pix_RMSE(corr1, corr2)
        if NCM >= 5:
            # a pair counts as successfully registered only with at least 5 correct matches
            success_num += 1
            mNCM.append(NCM)
            mCMR.append(CMR)
            mRMSE.append(RMSE)
            for key in ransac_mma:
                m_ransac_mma[key] = m_ransac_mma.get(key, 0) + ransac_mma[key]
        else:
            failed_nums[3] += 1

    # %% dataset-level accuracy evaluation
    # average number of extracted keypoints
    mean_extract_nums = int(np.mean(mean_extract_nums))
    # average number of putative matches
    mean_match_nums = int(np.mean(mean_match_nums))
    # average number of matches after outlier removal
    mean_ransac_nums = int(np.mean(mean_ransac_nums))
    for key in m_corr_mma:
        m_corr_mma[key] = round(m_corr_mma[key] / corr_pair_num, 3)
    for key in m_homo_mma:
        m_homo_mma[key] = round(m_homo_mma[key] / homo_pair_num, 3)
    for key in m_ransac_mma:
        m_ransac_mma[key] = round(m_ransac_mma[key] / success_num, 3)
    end_time = time.perf_counter()

    # %% aggregate the results and save them as JSON
    data = {
        "method": f"{method.name}",
        "dataset": f"{config.dataset.name}",
        "dataset_size": len(imgfiles1),
        "success_num": success_num,
        "failed_nums": failed_nums,
        "success_rate": round(success_num / len(imgfiles1), 3),
        "mean_extract_nums": mean_extract_nums,
        "mean_match_nums": mean_match_nums,
        "mean_ransac_nums": mean_ransac_nums,
        "corr_mma": m_corr_mma,
        "ransac_mma": m_ransac_mma,
        "homo_mma": m_homo_mma,
        "NCM": f"{np.mean(mNCM):.1f}",
        "CMR": f"{np.mean(mCMR):.2f}",
        "RMSE": f"{np.mean(mRMSE):.2f}",
        "Time": f"{int(end_time - start_time)}s",
    }
    with open(
        f"result/{method.name}_{config.dataset.name}.json",
        "w",
    ) as f:
        json.dump(data, f)
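
    # NOTE: depending on the Hydra version and the hydra.job.chdir setting, the
    # relative "result/" path above may resolve against the original working
    # directory or against the Hydra run directory.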

    log.info(f"Method: {method.name}")
    log.info(f"Dataset: {config.dataset.name}")
    log.info(f"Dataset size: {len(imgfiles1)}")
    log.info(f"Success num: {success_num}")
    log.info(f"failed_nums: {failed_nums}")
    log.info(f"failed_nums_sum: {sum(failed_nums)}")
    log.info(f"Success rate: {success_num / len(imgfiles1):.3f}")
    log.info(f"mean_extract_nums: {mean_extract_nums}")
    log.info(f"mean_match_nums: {mean_match_nums}")
    log.info(f"mean_ransac_nums: {mean_ransac_nums}")
    log.info(f"corr_mma: {m_corr_mma}")
    log.info(f"ransac_mma: {m_ransac_mma}")
    log.info(f"homo_mma: {m_homo_mma}")
    log.info(f"NCM: {np.mean(mNCM):.1f}")
    log.info(f"CMR: {np.mean(mCMR):.2f}")
    log.info(f"RMSE: {np.mean(mRMSE):.2f}")
    log.info(f"Time: {int(end_time - start_time)}s")
    log.info("\n")


if __name__ == "__main__":
    main()
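
# Hypothetical multirun sweep (the actual method/dataset group values live under conf/):
#   python main_multi_sopatch.py -m method=sift,loftr dataset=osdataset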