-
-
Notifications
You must be signed in to change notification settings - Fork 16.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Explicitly compute TP, FP in val.py (#5727)
- Loading branch information
1 parent
eb51ffd
commit 36d12a5
Showing
2 changed files
with
16 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,7 @@ def fitness(x): | |
return (x[:, :4] * w).sum(1) | ||
|
||
|
||
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()): | ||
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16): | ||
""" Compute the average precision, given the recall and precision curves. | ||
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics. | ||
# Arguments | ||
|
@@ -37,15 +37,15 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names | |
tp, conf, pred_cls = tp[i], conf[i], pred_cls[i] | ||
|
||
# Find unique classes | ||
unique_classes = np.unique(target_cls) | ||
unique_classes, nt = np.unique(target_cls, return_counts=True) | ||
nc = unique_classes.shape[0] # number of classes, number of detections | ||
|
||
# Create Precision-Recall curve and compute AP for each class | ||
px, py = np.linspace(0, 1, 1000), [] # for plotting | ||
ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000)) | ||
for ci, c in enumerate(unique_classes): | ||
i = pred_cls == c | ||
n_l = (target_cls == c).sum() # number of labels | ||
n_l = nt[ci] # number of labels | ||
n_p = i.sum() # number of predictions | ||
|
||
if n_p == 0 or n_l == 0: | ||
|
@@ -56,7 +56,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names | |
tpc = tp[i].cumsum(0) | ||
|
||
# Recall | ||
recall = tpc / (n_l + 1e-16) # recall curve | ||
recall = tpc / (n_l + eps) # recall curve | ||
r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases | ||
|
||
# Precision | ||
|
@@ -70,7 +70,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names | |
py.append(np.interp(px, mrec, mpre)) # precision at [email protected] | ||
|
||
# Compute F1 (harmonic mean of precision and recall) | ||
f1 = 2 * p * r / (p + r + 1e-16) | ||
f1 = 2 * p * r / (p + r + eps) | ||
names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data | ||
names = {i: v for i, v in enumerate(names)} # to dict | ||
if plot: | ||
|
@@ -80,7 +80,10 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names | |
plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall') | ||
|
||
i = f1.mean(0).argmax() # max F1 index | ||
return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32') | ||
p, r, f1 = p[:, i], r[:, i], f1[:, i] | ||
tp = (r * nt).round() # true positives | ||
fp = (tp / (p + eps) - tp).round() # false positives | ||
return tp, fp, p, r, f1, ap, unique_classes.astype('int32') | ||
|
||
|
||
def compute_ap(recall, precision): | ||
|
@@ -162,6 +165,12 @@ def process_batch(self, detections, labels): | |
def matrix(self): | ||
return self.matrix | ||
|
||
def tp_fp(self): | ||
tp = self.matrix.diagonal() # true positives | ||
fp = self.matrix.sum(1) - tp # false positives | ||
# fn = self.matrix.sum(0) - tp # false negatives (missed detections) | ||
return tp[:-1], fp[:-1] # remove background class | ||
|
||
def plot(self, normalize=True, save_dir='', names=()): | ||
try: | ||
import seaborn as sn | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -237,7 +237,7 @@ def run(data, | |
# Compute metrics | ||
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy | ||
if len(stats) and stats[0].any(): | ||
p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names) | ||
tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names) | ||
ap50, ap = ap[:, 0], ap.mean(1) # [email protected], [email protected]:0.95 | ||
mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean() | ||
nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class | ||
|