From 4fad59cfb45253a651dd536bd468fd09ff09ffa9 Mon Sep 17 00:00:00 2001 From: Laughing-q <1185102784@qq.com> Date: Wed, 24 Aug 2022 10:52:53 +0800 Subject: [PATCH 1/3] speed up evaluation --- segment/val.py | 19 ++++++++----------- utils/segment/plots.py | 3 ++- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/segment/val.py b/segment/val.py index 12f7c9fc3476..d014131b7ddd 100644 --- a/segment/val.py +++ b/segment/val.py @@ -122,6 +122,7 @@ def process_batch_masks(predn, pred_masks, gt_masks, labels, iouv, overlap): mode="bilinear", align_corners=False, ).squeeze(0) + gt_masks = gt_masks.gt_(0.5) iou = mask_iou( gt_masks.view(gt_masks.shape[0], -1), @@ -171,7 +172,7 @@ def run( mask_downsample_ratio=1, compute_loss=None, ): - process = process_mask_upsample if plots else process_mask + process = process_mask_upsample if save_json else process_mask # Initialize/load model and set device training = model is not None if training: # called by train.py @@ -304,9 +305,6 @@ def run( proto_out = train_out[1][si] pred_masks = process(proto_out, pred[:, 6:], pred[:, :4], shape=im[si].shape[1:]).permute(2, 0, 1).contiguous().float() - if plots and batch_i < 3: - # filter top 15 to plot - plot_masks.append(torch.as_tensor(pred_masks[:15], dtype=torch.uint8).cpu()) # Predictions if single_cls: @@ -326,6 +324,12 @@ def run( stats.append( (correct_masks, correct_bboxes, pred[:, 4], pred[:, 5], labels[:, 0])) # (correct, conf, pcls, tcls) + # convert pred_masks to uint8 + pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8) + if plots and batch_i < 3: + # filter top 15 to plot + plot_masks.append(pred_masks[:15].cpu()) + # Save/log if save_txt: save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt')) @@ -336,13 +340,6 @@ def run( # Plot images if plots and batch_i < 3: - if masks.shape[1:] != im.shape[2:]: - masks = F.interpolate( - masks.unsqueeze(0).float(), - im.shape[2:], - mode="bilinear", - align_corners=False, - ).squeeze(0) plot_images_and_masks(im, targets, masks, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names) if len(plot_masks): plot_masks = torch.cat(plot_masks, dim=0) diff --git a/utils/segment/plots.py b/utils/segment/plots.py index 11b7081f4995..6303103ed084 100644 --- a/utils/segment/plots.py +++ b/utils/segment/plots.py @@ -132,7 +132,8 @@ def plot_images_and_masks(images, targets, masks, paths=None, fname='images.jpg' for j, box in enumerate(boxes.T.tolist()): if labels or conf[j] > 0.25: # 0.25 conf thresh color = colors(classes[j]) - if scale < 1: + mh, mw = image_masks[j].shape[:2] + if mh != h or mw != w: mask = image_masks[j].astype(np.uint8) mask = cv2.resize(mask, (w, h)) mask = mask.astype(np.bool) From ce6d849cf536e11172e95a45102c40c48612bf16 Mon Sep 17 00:00:00 2001 From: Laughing-q <1185102784@qq.com> Date: Wed, 24 Aug 2022 11:26:31 +0800 Subject: [PATCH 2/3] fix process_mask --- utils/segment/general.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/segment/general.py b/utils/segment/general.py index fe4898b2cdd4..80286e3fd94b 100644 --- a/utils/segment/general.py +++ b/utils/segment/general.py @@ -44,7 +44,7 @@ def process_mask_upsample(proto_out, out_masks, bboxes, shape): """ c, mh, mw = proto_out.shape # CHW - masks = (out_masks.tanh() @ proto_out.view(c, -1)).sigmoid().view(-1, mh, mw) + masks = (out_masks.tanh() @ proto_out.float().view(c, -1)).sigmoid().view(-1, mh, mw) masks = F.interpolate(masks[None], shape, mode='bilinear', align_corners=False)[0] # CHW masks = crop(masks.permute(1, 2, 0).contiguous(), bboxes) # HWC return masks.gt_(0.5) @@ -63,7 +63,7 @@ def process_mask(proto_out, out_masks, bboxes, shape, upsample=False): c, mh, mw = proto_out.shape # CHW ih, iw = shape - masks = (out_masks.tanh() @ proto_out.view(c, -1)).sigmoid().view(-1, mh, mw) # CHW + masks = (out_masks.tanh() @ proto_out.float().view(c, -1)).sigmoid().view(-1, mh, mw) # CHW downsampled_bboxes = bboxes.clone() downsampled_bboxes[:, 0] *= mw / iw From 61212a6a22aab6965e28dae25a5ec841965031eb Mon Sep 17 00:00:00 2001 From: Laughing-q <1185102784@qq.com> Date: Wed, 24 Aug 2022 11:26:40 +0800 Subject: [PATCH 3/3] fix plots --- utils/segment/plots.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/utils/segment/plots.py b/utils/segment/plots.py index 6303103ed084..4517ff455cba 100644 --- a/utils/segment/plots.py +++ b/utils/segment/plots.py @@ -93,8 +93,8 @@ def plot_images_and_masks(images, targets, masks, paths=None, fname='images.jpg' if paths: annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames if len(targets) > 0: - j = targets[:, 0] == i - ti = targets[j] # image targets + idx = targets[:, 0] == i + ti = targets[idx] # image targets boxes = xywh2xyxy(ti[:, 2:6]).T classes = ti[:, 1].astype('int') @@ -126,13 +126,13 @@ def plot_images_and_masks(images, targets, masks, paths=None, fname='images.jpg' image_masks = np.repeat(image_masks, nl, axis=0) image_masks = np.where(image_masks == index, 1.0, 0.0) else: - image_masks = masks[j] + image_masks = masks[idx] im = np.asarray(annotator.im).copy() for j, box in enumerate(boxes.T.tolist()): if labels or conf[j] > 0.25: # 0.25 conf thresh color = colors(classes[j]) - mh, mw = image_masks[j].shape[:2] + mh, mw = image_masks[j].shape if mh != h or mw != w: mask = image_masks[j].astype(np.uint8) mask = cv2.resize(mask, (w, h))