Skip to content

Commit

Permalink
Merge pull request #327 from hellock/bboxes-ignore
Browse files Browse the repository at this point in the history
Allow gt_bboxes_ignore for RPN and single-stage detectors
  • Loading branch information
hellock authored Mar 5, 2019
2 parents 85c30cc + f1d06cd commit 70383d4
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 18 deletions.
11 changes: 8 additions & 3 deletions mmdet/core/anchor/anchor_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def anchor_target(anchor_list,
target_means,
target_stds,
cfg,
gt_bboxes_ignore_list=None,
gt_labels_list=None,
label_channels=1,
sampling=True,
Expand Down Expand Up @@ -41,6 +42,8 @@ def anchor_target(anchor_list,
valid_flag_list[i] = torch.cat(valid_flag_list[i])

# compute targets for each image
if gt_bboxes_ignore_list is None:
gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
if gt_labels_list is None:
gt_labels_list = [None for _ in range(num_imgs)]
(all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
Expand All @@ -49,6 +52,7 @@ def anchor_target(anchor_list,
anchor_list,
valid_flag_list,
gt_bboxes_list,
gt_bboxes_ignore_list,
gt_labels_list,
img_metas,
target_means=target_means,
Expand Down Expand Up @@ -90,6 +94,7 @@ def images_to_levels(target, num_level_anchors):
def anchor_target_single(flat_anchors,
valid_flags,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
img_meta,
target_means,
Expand All @@ -108,11 +113,11 @@ def anchor_target_single(flat_anchors,

if sampling:
assign_result, sampling_result = assign_and_sample(
anchors, gt_bboxes, None, None, cfg)
anchors, gt_bboxes, gt_bboxes_ignore, None, cfg)
else:
bbox_assigner = build_assigner(cfg.assigner)
assign_result = bbox_assigner.assign(anchors, gt_bboxes, None,
gt_labels)
assign_result = bbox_assigner.assign(anchors, gt_bboxes,
gt_bboxes_ignore, gt_labels)
bbox_sampler = PseudoSampler()
sampling_result = bbox_sampler.sample(assign_result, anchors,
gt_bboxes)
Expand Down
11 changes: 9 additions & 2 deletions mmdet/models/anchor_heads/anchor_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,14 @@ def loss_single(self, cls_score, bbox_pred, labels, label_weights,
avg_factor=num_total_samples)
return loss_cls, loss_reg

def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
cfg):
def loss(self,
cls_scores,
bbox_preds,
gt_bboxes,
gt_labels,
img_metas,
cfg,
gt_bboxes_ignore=None):
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == len(self.anchor_generators)

Expand All @@ -186,6 +192,7 @@ def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
self.target_means,
self.target_stds,
cfg,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels,
sampling=sampling)
Expand Down
18 changes: 15 additions & 3 deletions mmdet/models/anchor_heads/rpn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,21 @@ def forward_single(self, x):
rpn_bbox_pred = self.rpn_reg(x)
return rpn_cls_score, rpn_bbox_pred

def loss(self, cls_scores, bbox_preds, gt_bboxes, img_metas, cfg):
losses = super(RPNHead, self).loss(cls_scores, bbox_preds, gt_bboxes,
None, img_metas, cfg)
def loss(self,
cls_scores,
bbox_preds,
gt_bboxes,
img_metas,
cfg,
gt_bboxes_ignore=None):
losses = super(RPNHead, self).loss(
cls_scores,
bbox_preds,
gt_bboxes,
None,
img_metas,
cfg,
gt_bboxes_ignore=gt_bboxes_ignore)
return dict(
loss_rpn_cls=losses['loss_cls'], loss_rpn_reg=losses['loss_reg'])

Expand Down
11 changes: 9 additions & 2 deletions mmdet/models/anchor_heads/ssd_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,14 @@ def loss_single(self, cls_score, bbox_pred, labels, label_weights,
avg_factor=num_total_samples)
return loss_cls[None], loss_reg

def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
cfg):
def loss(self,
cls_scores,
bbox_preds,
gt_bboxes,
gt_labels,
img_metas,
cfg,
gt_bboxes_ignore=None):
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == len(self.anchor_generators)

Expand All @@ -145,6 +151,7 @@ def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas,
self.target_means,
self.target_stds,
cfg,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=1,
sampling=False,
Expand Down
5 changes: 3 additions & 2 deletions mmdet/models/detectors/cascade_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def forward_train(self,
img,
img_meta,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
gt_bboxes_ignore=None,
gt_masks=None,
proposals=None):
x = self.extract_feat(img)
Expand All @@ -121,7 +121,8 @@ def forward_train(self,
rpn_outs = self.rpn_head(x)
rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
self.train_cfg.rpn)
rpn_losses = self.rpn_head.loss(*rpn_loss_inputs)
rpn_losses = self.rpn_head.loss(
*rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
losses.update(rpn_losses)

proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
Expand Down
9 changes: 7 additions & 2 deletions mmdet/models/detectors/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,20 @@ def extract_feat(self, img):
x = self.neck(x)
return x

def forward_train(self, img, img_meta, gt_bboxes=None):
def forward_train(self,
img,
img_meta,
gt_bboxes=None,
gt_bboxes_ignore=None):
if self.train_cfg.rpn.get('debug', False):
self.rpn_head.debug_imgs = tensor2imgs(img)

x = self.extract_feat(img)
rpn_outs = self.rpn_head(x)

rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta, self.train_cfg.rpn)
losses = self.rpn_head.loss(*rpn_loss_inputs)
losses = self.rpn_head.loss(
*rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
return losses

def simple_test(self, img, img_meta, rescale=False):
Expand Down
10 changes: 8 additions & 2 deletions mmdet/models/detectors/single_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,17 @@ def extract_feat(self, img):
x = self.neck(x)
return x

def forward_train(self, img, img_metas, gt_bboxes, gt_labels):
def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None):
x = self.extract_feat(img)
outs = self.bbox_head(x)
loss_inputs = outs + (gt_bboxes, gt_labels, img_metas, self.train_cfg)
losses = self.bbox_head.loss(*loss_inputs)
losses = self.bbox_head.loss(
*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
return losses

def simple_test(self, img, img_meta, rescale=False):
Expand Down
7 changes: 5 additions & 2 deletions mmdet/models/detectors/two_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def forward_train(self,
img,
img_meta,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
gt_bboxes_ignore=None,
gt_masks=None,
proposals=None):
x = self.extract_feat(img)
Expand All @@ -94,7 +94,8 @@ def forward_train(self,
rpn_outs = self.rpn_head(x)
rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
self.train_cfg.rpn)
rpn_losses = self.rpn_head.loss(*rpn_loss_inputs)
rpn_losses = self.rpn_head.loss(
*rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
losses.update(rpn_losses)

proposal_inputs = rpn_outs + (img_meta, self.test_cfg.rpn)
Expand All @@ -108,6 +109,8 @@ def forward_train(self,
bbox_sampler = build_sampler(
self.train_cfg.rcnn.sampler, context=self)
num_imgs = img.size(0)
if gt_bboxes_ignore is None:
gt_bboxes_ignore = [None for _ in range(num_imgs)]
sampling_results = []
for i in range(num_imgs):
assign_result = bbox_assigner.assign(
Expand Down

0 comments on commit 70383d4

Please sign in to comment.