import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

from image import save_masks
from loss import compute_loss, get_metrics, print_metrics


def evaluation(model, dataset, device, save_mask=True, plot_roc=True, print_metric=True):
    """
    Evaluate a trained model. Computes the metrics listed in the dictionary
    to_plot_metrics and plots the ROC curve over a range of thresholds.
    :param model: a trained model
    :param dataset: dataset of images
    :param device: GPU or CPU; used to transfer the data to the right device
    :param save_mask: whether to call save_masks to save the masks predicted by the model
    :param plot_roc: whether to plot and save the ROC curve computed over the thresholds
    :param print_metric: whether to print the metrics computed over the thresholds
    :return: the dictionary containing the metrics
    """
    # Set model modules to eval mode
    model.eval()
    loss = 0
    last_masks = [None] * len(dataset)
    last_truths = [None] * len(dataset)
    # Probability thresholds applied to the feature maps to classify the pixels
    thresholds = [0, 0.0000001, 0.000001, 0.000005, 0.00001, 0.000025, 0.00005, 0.0001, 0.00025, 0.0005, 0.001,
                  0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.2, 0.4, 0.6, 0.8, 1]
    n_thresholds = len(thresholds)
    # All metrics and measures to be computed for each threshold
    to_plot_metrics = {"F1": np.zeros(n_thresholds), "Recall": np.zeros(n_thresholds),
                       "Precision": np.zeros(n_thresholds), "TP": np.zeros(n_thresholds),
                       "TN": np.zeros(n_thresholds), "FP": np.zeros(n_thresholds),
                       "FN": np.zeros(n_thresholds), "AUC": 0,
                       "TPR": np.zeros(n_thresholds), "FPR": np.zeros(n_thresholds)}
    with tqdm(desc='Validation', unit='img') as progress_bar:
        for i, (image, ground_truth) in enumerate(dataset):
            # Drop the batch dimension; images are processed one at a time
            image = image[0, ...]
            ground_truth = ground_truth[0, ...]
            last_truths[i] = ground_truth
            image = image.to(device)
            ground_truth = ground_truth.to(device)
            with torch.no_grad():
                mask_predicted = model(image)
            last_masks[i] = mask_predicted
            # Class weights for the BCE part of the loss
            bce_weight = torch.Tensor([1, 8]).to(device)
            loss += compute_loss(mask_predicted, ground_truth, bce_weight=bce_weight)
            progress_bar.set_postfix(**{'loss': float(loss)})
            get_metrics(mask_predicted[0, 0], ground_truth[0], to_plot_metrics, thresholds)
            progress_bar.update()
    if save_mask:
        save_masks(last_masks, last_truths, str(device), max_img=50, shuffle=False, color="red",
                   filename="mask_predicted_test.png", threshold=thresholds[np.argmax(to_plot_metrics["F1"])])
    if print_metric:
        print_metrics(to_plot_metrics, len(dataset), "test set")
    # Average the metrics over the dataset
    nb_images = len(dataset)
    for k, v in to_plot_metrics.items():
        to_plot_metrics[k] = v / nb_images
    # ROC curve
    if plot_roc:
        plt.title('Receiver Operating Characteristic')
        plt.plot(to_plot_metrics["FPR"], to_plot_metrics["TPR"], 'b',
                 label='AUC = %0.2f' % to_plot_metrics["AUC"])
        plt.legend(loc='lower right')
        plt.plot([0, 1], [0, 1], 'r--')
        plt.xlim([0, 1])
        plt.ylim([0, 1])
        plt.ylabel('True Positive Rate')
        plt.xlabel('False Positive Rate')
        plt.savefig("ROC.png")
        plt.show()
        plt.close()
    loss /= len(dataset)
    to_plot_metrics["loss"] = loss
    to_plot_metrics["best_threshold"] = thresholds[np.argmax(to_plot_metrics["F1"])]
    return to_plot_metrics
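

# Example usage (a minimal sketch, not part of the original file). `UNet` and
# `load_test_set` are hypothetical stand-ins for this project's model class and
# data pipeline, and the checkpoint path is illustrative. The loader must yield
# (image, ground_truth) pairs with a leading batch dimension, since the loop in
# evaluation() strips it with image[0, ...].
if __name__ == "__main__":
    from model import UNet          # hypothetical: the project's model class
    from data import load_test_set  # hypothetical: returns an (image, truth) loader

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = UNet().to(device)
    net.load_state_dict(torch.load("checkpoint.pt", map_location=device))  # illustrative path

    metrics = evaluation(net, load_test_set(), device)
    print("Best threshold: %s, loss: %s" % (metrics["best_threshold"], metrics["loss"]))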