Skip to content

Commit

Permalink
[Enhance] Speed up SimOTA matching. (open-mmlab#7098)
Browse files Browse the repository at this point in the history
  • Loading branch information
RangiLyu authored and ZwwWayne committed Jul 19, 2022
1 parent 787307c commit d67a39e
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions mmdet/core/bbox/assigners/sim_ota_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,27 +225,27 @@ def get_in_gt_and_in_center_info(self, priors, gt_bboxes):
return is_in_gts_or_centers, is_in_boxes_and_centers

def dynamic_k_matching(self, cost, pairwise_ious, num_gt, valid_mask):
matching_matrix = torch.zeros_like(cost)
matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
# select candidate topk ious for dynamic-k calculation
candidate_topk = min(self.candidate_topk, pairwise_ious.size(0))
topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)
# calculate dynamic k for each gt
dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
for gt_idx in range(num_gt):
_, pos_idx = torch.topk(
cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
matching_matrix[:, gt_idx][pos_idx] = 1.0
cost[:, gt_idx], k=dynamic_ks[gt_idx], largest=False)
matching_matrix[:, gt_idx][pos_idx] = 1

del topk_ious, dynamic_ks, pos_idx

prior_match_gt_mask = matching_matrix.sum(1) > 1
if prior_match_gt_mask.sum() > 0:
cost_min, cost_argmin = torch.min(
cost[prior_match_gt_mask, :], dim=1)
matching_matrix[prior_match_gt_mask, :] *= 0.0
matching_matrix[prior_match_gt_mask, cost_argmin] = 1.0
matching_matrix[prior_match_gt_mask, :] *= 0
matching_matrix[prior_match_gt_mask, cost_argmin] = 1
# get foreground mask inside box and center prior
fg_mask_inboxes = matching_matrix.sum(1) > 0.0
fg_mask_inboxes = matching_matrix.sum(1) > 0
valid_mask[valid_mask.clone()] = fg_mask_inboxes

matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
Expand Down

0 comments on commit d67a39e

Please sign in to comment.