Skip to content

Commit

Permalink
Metric-Confidence plots feature addition (ultralytics#2057)
Browse files Browse the repository at this point in the history
* Metric-Confidence plots feature addition

* cleanup

* Metric-Confidence plots feature addition

* cleanup

* Update run-once lines

* cleanup

* save all 4 curves to wandb
  • Loading branch information
glenn-jocher authored Jan 28, 2021
1 parent bcf7ee1 commit 6ea7d59
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 17 deletions.
2 changes: 1 addition & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def test(data,
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
if len(stats) and stats[0].any():
p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1) # [P, R, [email protected], [email protected]:0.95]
ap50, ap = ap[:, 0], ap.mean(1) # [email protected], [email protected]:0.95
mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class
else:
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
if plots:
plot_results(save_dir=save_dir) # save as results.png
if wandb:
files = ['results.png', 'precision_recall_curve.png', 'confusion_matrix.png']
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
if (save_dir / f).exists()]})
if opt.log_artifacts:
Expand Down
53 changes: 38 additions & 15 deletions utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def fitness(x):
return (x[:, :4] * w).sum(1)


def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision-recall_curve.png', names=[]):
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()):
""" Compute the average precision, given the recall and precision curves.
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
# Arguments
Expand All @@ -35,12 +35,11 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision

# Find unique classes
unique_classes = np.unique(target_cls)
nc = unique_classes.shape[0] # number of classes, number of detections

# Create Precision-Recall curve and compute AP for each class
px, py = np.linspace(0, 1, 1000), [] # for plotting
pr_score = 0.1 # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
s = [unique_classes.shape[0], tp.shape[1]] # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95)
ap, p, r = np.zeros(s), np.zeros(s), np.zeros(s)
ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
for ci, c in enumerate(unique_classes):
i = pred_cls == c
n_l = (target_cls == c).sum() # number of labels
Expand All @@ -55,25 +54,28 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision

# Recall
recall = tpc / (n_l + 1e-16) # recall curve
r[ci] = np.interp(-pr_score, -conf[i], recall[:, 0]) # r at pr_score, negative x, xp because xp decreases
r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases

# Precision
precision = tpc / (tpc + fpc) # precision curve
p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0]) # p at pr_score
p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score

# AP from recall-precision curve
for j in range(tp.shape[1]):
ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
if plot and (j == 0):
if plot and j == 0:
py.append(np.interp(px, mrec, mpre)) # precision at [email protected]

# Compute F1 score (harmonic mean of precision and recall)
# Compute F1 (harmonic mean of precision and recall)
f1 = 2 * p * r / (p + r + 1e-16)

if plot:
plot_pr_curve(px, py, ap, save_dir, names)
plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names)
plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1')
plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision')
plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall')

return p, r, ap, f1, unique_classes.astype('int32')
i = f1.mean(0).argmax() # max F1 index
return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32')


def compute_ap(recall, precision):
Expand Down Expand Up @@ -181,13 +183,14 @@ def print(self):

# Plots ----------------------------------------------------------------------------------------------------------------

def plot_pr_curve(px, py, ap, save_dir='.', names=()):
def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()):
# Precision-recall curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
py = np.stack(py, axis=1)

if 0 < len(names) < 21: # show mAP in legend if < 10 classes
if 0 < len(names) < 21: # display per-class legend if < 21 classes
for i, y in enumerate(py.T):
ax.plot(px, y, linewidth=1, label=f'{names[i]} %.3f' % ap[i, 0]) # plot(recall, precision)
ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision)
else:
ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision)

Expand All @@ -197,4 +200,24 @@ def plot_pr_curve(px, py, ap, save_dir='.', names=()):
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
fig.savefig(Path(save_dir) / 'precision_recall_curve.png', dpi=250)
fig.savefig(Path(save_dir), dpi=250)


def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence', ylabel='Metric'):
# Metric-confidence curve
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)

if 0 < len(names) < 21: # display per-class legend if < 21 classes
for i, y in enumerate(py):
ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric)
else:
ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric)

y = py.mean(0)
ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
fig.savefig(Path(save_dir), dpi=250)

0 comments on commit 6ea7d59

Please sign in to comment.