-
Notifications
You must be signed in to change notification settings - Fork 0
/
pr_curve.py
104 lines (95 loc) · 3.66 KB
/
pr_curve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import numpy as np
import os
import PIL.Image as Image
import pdb
import matplotlib.pyplot as plt
def main(datasets=('ECSSD',), root='/home/rabbit/Desktop/DUT_train/PRE', name='result'):
    """Evaluate predicted saliency maps per dataset and plot PR / F-measure curves.

    For each dataset, the layout ``<root>/<dataset>/test2/`` is expected to hold
    predictions in ``mask2/`` and ground truth in ``gt/``. Results are saved to
    ``<root>/<dataset>/test2/<name>.npz`` by ``evaluate``. A two-panel figure
    (precision-recall and threshold-vs-F-measure) is written to ``<dataset>.pdf``
    in the current working directory.

    Args:
        datasets: dataset directory names under ``root``.
        root: parent directory containing one sub-directory per dataset.
        name: base name of the ``.npz`` results file and label of the curves.
    """
    for dataset in datasets:
        print(dataset)
        base_dir = os.path.join(root, dataset, 'test2')
        input_dir = os.path.join(base_dir, 'mask2')
        gt_dir = os.path.join(base_dir, 'gt')
        output_dir = base_dir

        fig = plt.figure(figsize=(9, 3))
        ax1 = fig.add_subplot(121)
        ax2 = fig.add_subplot(122)

        evaluate(input_dir, gt_dir, output_dir, name=name)
        # Close the archive promptly; the arrays are copied into memory on access.
        with np.load(os.path.join(output_dir, '%s.npz' % name)) as sb:
            m_recs = sb['m_recs']
            m_pres = sb['m_pres']
            m_fms = sb['m_fms']
            m_thfm = float(sb['m_thfm'])
            m_mea = float(sb['m_mea'])

        # Labels are required, otherwise the legend below has no entries.
        ax1.plot(m_recs, m_pres, linewidth=1, label=name)
        # The F-measure curve was sampled at evenly spaced thresholds in [0, 1].
        ax2.plot(np.linspace(0, 1, len(m_fms)), m_fms, linewidth=1, label=name)
        print(' fm: %.4f, mea: %.4f' % (m_thfm, m_mea))

        ax1.grid(True)
        ax1.set_xlabel('Recall', fontsize=14)
        ax1.set_ylabel('Precision', fontsize=14)
        ax2.grid(True)
        ax2.set_xlabel('Threshold', fontsize=14)
        ax2.set_ylabel('F-measure', fontsize=14)

        handles, labels = ax1.get_legend_handles_labels()
        lgd = ax1.legend(handles, labels, loc='center left', bbox_to_anchor=(0.5, -0.5), ncol=8, fontsize=14)
        # bbox_extra_artists keeps the out-of-axes legend inside the saved bounds.
        fig.savefig('%s.pdf' % dataset, bbox_extra_artists=(lgd,), bbox_inches='tight')
        # Release the figure so looping over many datasets does not accumulate them.
        plt.close(fig)
def _load_saliency_map(path, eps):
    """Load a predicted map as a 2-D float64 array, min-max rescaled to [0, 1].

    Multi-channel images keep only their first channel. ``eps`` in the
    denominator keeps a constant map from dividing by zero (it becomes all 0).
    """
    # The context manager closes the file handle that Image.open keeps for lazy loading.
    with Image.open(path) as img:
        # np.float (removed in NumPy 1.24) was an alias of the builtin float, i.e. float64.
        mask = np.array(img, dtype=np.float64)
    if mask.ndim != 2:
        mask = mask[:, :, 0]
    return (mask - mask.min()) / (mask.max() - mask.min() + eps)


def _load_ground_truth(path):
    """Load a ground-truth mask as a 2-D uint8 array in which nonzero pixels become 1."""
    with Image.open(path) as img:
        gt = np.array(img, dtype=np.uint8)
    if gt.ndim != 2:
        # Same channel reduction as the prediction, so the two arrays line up.
        gt = gt[:, :, 0]
    gt[gt != 0] = 1
    return gt


def _pr_fmeasure(binary, gt, eps, beta2=0.3):
    """Return ``(precision, recall, F-measure)`` of a binary map against ``gt``.

    F-measure is ``(1 + beta2) * P * R / (beta2 * P + R)``; ``eps`` guards
    every denominator against empty predictions or empty ground truth.
    """
    hits = (binary * gt).sum()
    pre = hits / (binary.sum() + eps)
    rec = hits / (gt.sum() + eps)
    fm = (1 + beta2) * pre * rec / (beta2 * pre + rec + eps)
    return pre, rec, fm


def evaluate(input_dir, gt_dir, output_dir=None, name=None):
    """Compute mean PR / F-measure curves, adaptive F-measure and MAE.

    Every ``.png`` in ``input_dir`` is a predicted saliency map. It is compared
    against the file with the same name in ``gt_dir``. Precision, recall and
    F-measure (beta^2 = 0.3) are computed at 21 evenly spaced thresholds in
    [0, 1]. The adaptive F-measure uses the threshold ``min(2 * mean(mask), 1)``.
    All quantities are averaged over the evaluated images.

    Args:
        input_dir: directory of predicted maps.
        gt_dir: directory of ground-truth masks (any nonzero pixel is foreground).
        output_dir: directory for ``<name>.npz``. It is created (with parents)
            when given. Nothing is saved unless both ``output_dir`` and
            ``name`` are provided.
        name: base name of the saved ``.npz`` file. It holds the keys
            ``m_mea``, ``m_thfm``, ``m_recs``, ``m_pres`` and ``m_fms``.

    Returns:
        ``(m_thfm, m_mea)``: the mean adaptive F-measure and the mean absolute
        error. Both are ``0`` when no ``.png`` file is found.
    """
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
    eps = np.finfo(float).eps
    thresholds = np.linspace(0, 1, 21)
    m_pres = np.zeros(len(thresholds))
    m_recs = np.zeros(len(thresholds))
    m_fms = np.zeros(len(thresholds))
    m_thfm = 0
    m_mea = 0
    it = 0
    # Sorted for a reproducible accumulation order.
    for filename in sorted(os.listdir(input_dir)):
        if not filename.endswith('.png'):
            continue
        mask = _load_saliency_map(os.path.join(input_dir, filename), eps)
        gt = _load_ground_truth(os.path.join(gt_dir, filename))

        mea = np.abs(gt - mask).mean()

        # Adaptive threshold: twice the mean saliency, capped at 1.
        adaptive_th = min(2 * mask.mean(), 1)
        _, _, thfm = _pr_fmeasure((mask >= adaptive_th).astype(np.float64), gt, eps)

        curves = np.array([_pr_fmeasure((mask >= th).astype(np.float64), gt, eps)
                           for th in thresholds])
        pres, recs, fms = curves.T

        # Incremental (running) mean over the images processed so far.
        it += 1
        m_mea = m_mea * (it - 1) / it + mea / it
        m_fms = m_fms * (it - 1) / it + fms / it
        m_recs = m_recs * (it - 1) / it + recs / it
        m_pres = m_pres * (it - 1) / it + pres / it
        m_thfm = m_thfm * (it - 1) / it + thfm / it

    if output_dir is not None and name is not None:
        np.savez(os.path.join(output_dir, '%s.npz' % name),
                 m_mea=m_mea, m_thfm=m_thfm, m_recs=m_recs, m_pres=m_pres, m_fms=m_fms)
    return m_thfm, m_mea
if __name__ == "__main__":
    # Run the evaluation and plotting only when executed as a script,
    # not when this module is imported for its functions.
    main()