From 46902430738a0bdc19d0a84dd2f2dfb11353868e Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 27 Nov 2020 18:13:43 +0000
Subject: [PATCH] Support for image with no annotations in RetinaNet (#3032)

* Enable support for images without annotations

* Ensuring gradient propagates to RegressionHead.

* Rewriting losses to remove branching.

* Fix the seed on DeformConv autocast test.
---
 .../test_models_detection_negative_samples.py |  9 ++++
 test/test_ops.py                              |  2 +
 torchvision/models/detection/retinanet.py     | 44 ++++++-------------
 3 files changed, 25 insertions(+), 30 deletions(-)

diff --git a/test/test_models_detection_negative_samples.py b/test/test_models_detection_negative_samples.py
index ed0cc515940..6d767971f72 100644
--- a/test/test_models_detection_negative_samples.py
+++ b/test/test_models_detection_negative_samples.py
@@ -128,6 +128,15 @@ def test_forward_negative_sample_krcnn(self):
         self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
         self.assertEqual(loss_dict["loss_keypoint"], torch.tensor(0.))
 
+    def test_forward_negative_sample_retinanet(self):
+        model = torchvision.models.detection.retinanet_resnet50_fpn(
+            num_classes=2, min_size=100, max_size=100)
+
+        images, targets = self._make_empty_sample()
+        loss_dict = model(images, targets)
+
+        self.assertEqual(loss_dict["bbox_regression"], torch.tensor(0.))
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/test_ops.py b/test/test_ops.py
index 1ba40d0da5f..68e6a5d2825 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1,3 +1,4 @@
+from common_utils import set_rng_seed
 import math
 import unittest
 
@@ -655,6 +656,7 @@ def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
     def test_autocast(self):
+        set_rng_seed(0)
         for dtype in (torch.float, torch.half):
             with torch.cuda.amp.autocast():
                 self._test_forward(torch.device("cuda"), False, dtype=dtype)
diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py
index fc05106a807..d46d39543f8 100644
--- a/torchvision/models/detection/retinanet.py
+++ b/torchvision/models/detection/retinanet.py
@@ -107,21 +107,16 @@ def compute_loss(self, targets, head_outputs, matched_idxs):
             # determine only the foreground
             foreground_idxs_per_image = matched_idxs_per_image >= 0
             num_foreground = foreground_idxs_per_image.sum()
-            # no matched_idxs means there were no annotations in this image
-            # TODO: enable support for images without annotations that works on distributed
-            if False:  # matched_idxs_per_image.numel() == 0:
-                gt_classes_target = torch.zeros_like(cls_logits_per_image)
-                valid_idxs_per_image = torch.arange(cls_logits_per_image.shape[0])
-            else:
-                # create the target classification
-                gt_classes_target = torch.zeros_like(cls_logits_per_image)
-                gt_classes_target[
-                    foreground_idxs_per_image,
-                    targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
-                ] = 1.0
-
-                # find indices for which anchors should be ignored
-                valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
+
+            # create the target classification
+            gt_classes_target = torch.zeros_like(cls_logits_per_image)
+            gt_classes_target[
+                foreground_idxs_per_image,
+                targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
+            ] = 1.0
+
+            # find indices for which anchors should be ignored
+            valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
 
             # compute the classification loss
             losses.append(sigmoid_focal_loss(
@@ -191,23 +186,12 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
 
         for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in \
                 zip(targets, bbox_regression, anchors, matched_idxs):
-            # no matched_idxs means there were no annotations in this image
-            # TODO enable support for images without annotations with distributed support
-            # if matched_idxs_per_image.numel() == 0:
-            #     continue
-
-            # get the targets corresponding GT for each proposal
-            # NB: need to clamp the indices because we can have a single
-            # GT in the image, and matched_idxs can be -2, which goes
-            # out of bounds
-            matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image.clamp(min=0)]
-
             # determine only the foreground indices, ignore the rest
-            foreground_idxs_per_image = matched_idxs_per_image >= 0
-            num_foreground = foreground_idxs_per_image.sum()
+            foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
+            num_foreground = foreground_idxs_per_image.numel()
 
             # select only the foreground boxes
-            matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
+            matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image[foreground_idxs_per_image]]
             bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
             anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
 
@@ -403,7 +387,7 @@ def compute_loss(self, targets, head_outputs, anchors):
         matched_idxs = []
         for anchors_per_image, targets_per_image in zip(anchors, targets):
             if targets_per_image['boxes'].numel() == 0:
-                matched_idxs.append(torch.empty((0,), dtype=torch.int32))
+                matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64))
                 continue
 
             match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image)
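
For context, a minimal usage sketch (not part of the patch) of what the change enables: a training-mode forward pass on a negative sample, i.e. an image whose target contains zero boxes. The constructor arguments mirror the new test; `pretrained=False` and `pretrained_backbone=False` are assumptions added only to keep the snippet self-contained and download-free.

    import torch
    import torchvision

    # pretrained weights are irrelevant here; disabling them avoids a download
    model = torchvision.models.detection.retinanet_resnet50_fpn(
        pretrained=False, pretrained_backbone=False,
        num_classes=2, min_size=100, max_size=100)
    model.train()  # training mode, so the forward pass returns the loss dict

    images = [torch.rand((3, 100, 100))]
    targets = [{
        'boxes': torch.zeros((0, 4), dtype=torch.float32),  # no annotations
        'labels': torch.zeros((0,), dtype=torch.int64),
    }]

    loss_dict = model(images, targets)
    # Every anchor is matched to -1 (background), so the regression loss is
    # exactly zero while the classification loss treats all anchors as negatives.
    print(loss_dict['bbox_regression'])  # a zero tensor, per the new test
    print(loss_dict['classification'])

Replacing the empty matched_idxs tensor with a full-length tensor of -1s is what removes the per-image branching: both loss computations now run unconditionally on every image, which also keeps all replicas executing the same graph under distributed training, the problem the removed TODO comments called out.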