Skip to content

Commit

Permalink
Support for image with no annotations in RetinaNet (#3032)
Browse files Browse the repository at this point in the history
* Enable support for images without annotations

* Ensuring gradient propagates to RegressionHead.

* Rewriting losses to remove branching.

* Fix the seed on DeformConv autocast test.
  • Loading branch information
datumbox authored Nov 27, 2020
1 parent 9e71fda commit 4ab46e5
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 30 deletions.
9 changes: 9 additions & 0 deletions test/test_models_detection_negative_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ def test_forward_negative_sample_krcnn(self):
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_keypoint"], torch.tensor(0.))

def test_forward_negative_sample_retinanet(self):
model = torchvision.models.detection.retinanet_resnet50_fpn(
num_classes=2, min_size=100, max_size=100)

images, targets = self._make_empty_sample()
loss_dict = model(images, targets)

self.assertEqual(loss_dict["bbox_regression"], torch.tensor(0.))


if __name__ == '__main__':
unittest.main()
2 changes: 2 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from common_utils import set_rng_seed
import math
import unittest

Expand Down Expand Up @@ -655,6 +656,7 @@ def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_autocast(self):
set_rng_seed(0)
for dtype in (torch.float, torch.half):
with torch.cuda.amp.autocast():
self._test_forward(torch.device("cuda"), False, dtype=dtype)
Expand Down
44 changes: 14 additions & 30 deletions torchvision/models/detection/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,21 +107,16 @@ def compute_loss(self, targets, head_outputs, matched_idxs):
# determine only the foreground
foreground_idxs_per_image = matched_idxs_per_image >= 0
num_foreground = foreground_idxs_per_image.sum()
# no matched_idxs means there were no annotations in this image
# TODO: enable support for images without annotations that works on distributed
if False: # matched_idxs_per_image.numel() == 0:
gt_classes_target = torch.zeros_like(cls_logits_per_image)
valid_idxs_per_image = torch.arange(cls_logits_per_image.shape[0])
else:
# create the target classification
gt_classes_target = torch.zeros_like(cls_logits_per_image)
gt_classes_target[
foreground_idxs_per_image,
targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
] = 1.0

# find indices for which anchors should be ignored
valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS

# create the target classification
gt_classes_target = torch.zeros_like(cls_logits_per_image)
gt_classes_target[
foreground_idxs_per_image,
targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
] = 1.0

# find indices for which anchors should be ignored
valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS

# compute the classification loss
losses.append(sigmoid_focal_loss(
Expand Down Expand Up @@ -191,23 +186,12 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs):

for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in \
zip(targets, bbox_regression, anchors, matched_idxs):
# no matched_idxs means there were no annotations in this image
# TODO enable support for images without annotations with distributed support
# if matched_idxs_per_image.numel() == 0:
# continue

# get the targets corresponding GT for each proposal
# NB: need to clamp the indices because we can have a single
# GT in the image, and matched_idxs can be -2, which goes
# out of bounds
matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image.clamp(min=0)]

# determine only the foreground indices, ignore the rest
foreground_idxs_per_image = matched_idxs_per_image >= 0
num_foreground = foreground_idxs_per_image.sum()
foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
num_foreground = foreground_idxs_per_image.numel()

# select only the foreground boxes
matched_gt_boxes_per_image = matched_gt_boxes_per_image[foreground_idxs_per_image, :]
matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image[foreground_idxs_per_image]]
bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]

Expand Down Expand Up @@ -403,7 +387,7 @@ def compute_loss(self, targets, head_outputs, anchors):
matched_idxs = []
for anchors_per_image, targets_per_image in zip(anchors, targets):
if targets_per_image['boxes'].numel() == 0:
matched_idxs.append(torch.empty((0,), dtype=torch.int32))
matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64))
continue

match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image)
Expand Down

0 comments on commit 4ab46e5

Please sign in to comment.