diff --git a/utils/metrics.py b/utils/metrics.py index 43b723e7079..e17747b703f 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -90,7 +90,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names p, r, f1 = p[:, i], r[:, i], f1[:, i] tp = (r * nt).round() # true positives fp = (tp / (p + eps) - tp).round() # false positives - return tp, fp, p, r, f1, ap, unique_classes.astype('int32') + return tp, fp, p, r, f1, ap, unique_classes.astype(int) def compute_ap(recall, precision): @@ -156,7 +156,7 @@ def process_batch(self, detections, labels): matches = np.zeros((0, 3)) n = matches.shape[0] > 0 - m0, m1, _ = matches.transpose().astype(np.int16) + m0, m1, _ = matches.transpose().astype(int) for i, gc in enumerate(gt_classes): j = m0 == i if n and sum(j) == 1: diff --git a/val.py b/val.py index d886cf302df..42b0e1813e4 100644 --- a/val.py +++ b/val.py @@ -79,16 +79,17 @@ def process_batch(detections, labels, iouv): """ correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device) iou = box_iou(labels[:, 1:], detections[:, :4]) - x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5])) # IoU above threshold and classes match - if x[0].shape[0]: - matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detection, iou] - if x[0].shape[0] > 1: - matches = matches[matches[:, 2].argsort()[::-1]] - matches = matches[np.unique(matches[:, 1], return_index=True)[1]] - # matches = matches[matches[:, 2].argsort()[::-1]] - matches = matches[np.unique(matches[:, 0], return_index=True)[1]] - matches = torch.from_numpy(matches).to(iouv.device) - correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv + correct_class = labels[:, 0:1] == detections[:, 5] + for i in range(len(iouv)): + x = torch.where((iou >= iouv[i]) & correct_class) # IoU > threshold and classes match + if x[0].shape[0]: + matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detect, iou] + if x[0].shape[0] > 1: + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 1], return_index=True)[1]] + # matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + correct[matches[:, 1].astype(int), i] = True return correct @@ -265,7 +266,7 @@ def run( tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names) ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95 mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean() - nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class + nt = np.bincount(stats[3].astype(int), minlength=nc) # number of targets per class else: nt = torch.zeros(1)