forked from ultralytics/yolov5
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Metric-Confidence plots feature addition (ultralytics#2057)
* Metric-Confidence plots feature addition * cleanup * Metric-Confidence plots feature addition * cleanup * Update run-once lines * cleanup * save all 4 curves to wandb
- Loading branch information
1 parent
9aac44d
commit 486a2ac
Showing
3 changed files
with
40 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -215,7 +215,7 @@ def test(data, | |
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy | ||
if len(stats) and stats[0].any(): | ||
p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names) | ||
p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1) # [P, R, [email protected], [email protected]:0.95] | ||
ap50, ap = ap[:, 0], ap.mean(1) # [email protected], [email protected]:0.95 | ||
mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean() | ||
nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class | ||
else: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,7 @@ def fitness(x): | |
return (x[:, :4] * w).sum(1) | ||
|
||
|
||
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision-recall_curve.png', names=[]): | ||
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()): | ||
""" Compute the average precision, given the recall and precision curves. | ||
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics. | ||
# Arguments | ||
|
@@ -35,12 +35,11 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision | |
|
||
# Find unique classes | ||
unique_classes = np.unique(target_cls) | ||
nc = unique_classes.shape[0] # number of classes, number of detections | ||
|
||
# Create Precision-Recall curve and compute AP for each class | ||
px, py = np.linspace(0, 1, 1000), [] # for plotting | ||
pr_score = 0.1 # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898 | ||
s = [unique_classes.shape[0], tp.shape[1]] # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95) | ||
ap, p, r = np.zeros(s), np.zeros(s), np.zeros(s) | ||
ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000)) | ||
for ci, c in enumerate(unique_classes): | ||
i = pred_cls == c | ||
n_l = (target_cls == c).sum() # number of labels | ||
|
@@ -55,25 +54,28 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision | |
|
||
# Recall | ||
recall = tpc / (n_l + 1e-16) # recall curve | ||
r[ci] = np.interp(-pr_score, -conf[i], recall[:, 0]) # r at pr_score, negative x, xp because xp decreases | ||
r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases | ||
|
||
# Precision | ||
precision = tpc / (tpc + fpc) # precision curve | ||
p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0]) # p at pr_score | ||
p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score | ||
|
||
# AP from recall-precision curve | ||
for j in range(tp.shape[1]): | ||
ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j]) | ||
if plot and (j == 0): | ||
if plot and j == 0: | ||
py.append(np.interp(px, mrec, mpre)) # precision at [email protected] | ||
|
||
# Compute F1 score (harmonic mean of precision and recall) | ||
# Compute F1 (harmonic mean of precision and recall) | ||
f1 = 2 * p * r / (p + r + 1e-16) | ||
|
||
if plot: | ||
plot_pr_curve(px, py, ap, save_dir, names) | ||
plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names) | ||
plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1') | ||
plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision') | ||
plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall') | ||
|
||
return p, r, ap, f1, unique_classes.astype('int32') | ||
i = f1.mean(0).argmax() # max F1 index | ||
return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32') | ||
|
||
|
||
def compute_ap(recall, precision): | ||
|
@@ -181,13 +183,14 @@ def print(self): | |
|
||
# Plots ---------------------------------------------------------------------------------------------------------------- | ||
|
||
def plot_pr_curve(px, py, ap, save_dir='.', names=()): | ||
def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()): | ||
# Precision-recall curve | ||
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) | ||
py = np.stack(py, axis=1) | ||
|
||
if 0 < len(names) < 21: # show mAP in legend if < 10 classes | ||
if 0 < len(names) < 21: # display per-class legend if < 21 classes | ||
for i, y in enumerate(py.T): | ||
ax.plot(px, y, linewidth=1, label=f'{names[i]} %.3f' % ap[i, 0]) # plot(recall, precision) | ||
ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision) | ||
else: | ||
ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision) | ||
|
||
|
@@ -197,4 +200,24 @@ def plot_pr_curve(px, py, ap, save_dir='.', names=()): | |
ax.set_xlim(0, 1) | ||
ax.set_ylim(0, 1) | ||
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") | ||
fig.savefig(Path(save_dir) / 'precision_recall_curve.png', dpi=250) | ||
fig.savefig(Path(save_dir), dpi=250) | ||
|
||
|
||
def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence', ylabel='Metric'): | ||
# Metric-confidence curve | ||
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) | ||
|
||
if 0 < len(names) < 21: # display per-class legend if < 21 classes | ||
for i, y in enumerate(py): | ||
ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric) | ||
else: | ||
ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric) | ||
|
||
y = py.mean(0) | ||
ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}') | ||
ax.set_xlabel(xlabel) | ||
ax.set_ylabel(ylabel) | ||
ax.set_xlim(0, 1) | ||
ax.set_ylim(0, 1) | ||
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") | ||
fig.savefig(Path(save_dir), dpi=250) |