Skip to content

Commit

Permalink
use topk instead of sort
Browse files Browse the repository at this point in the history
Summary: now pytorch/pytorch#22812 is fixed

Reviewed By: zhanghang1989

Differential Revision: D32610251

fbshipit-source-id: e099a2c53f71cca95af35aafc26ab59f9613c07b
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed Nov 27, 2021
1 parent 4606450 commit 1ad5759
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 14 deletions.
6 changes: 2 additions & 4 deletions detectron2/modeling/meta_arch/dense_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,8 @@ def _decode_per_level_predictions(

# 2. Keep top k top scoring boxes only
num_topk = min(topk_candidates, topk_idxs.size(0))
# torch.sort is actually faster than .topk (https://github.com/pytorch/pytorch/issues/22812)
pred_scores, idxs = pred_scores.sort(descending=True)
pred_scores = pred_scores[:num_topk]
topk_idxs = topk_idxs[idxs[:num_topk]]
pred_scores, idxs = pred_scores.topk(num_topk)
topk_idxs = topk_idxs[idxs]

anchor_idxs, classes_idxs = topk_idxs.unbind(dim=1)

Expand Down
6 changes: 1 addition & 5 deletions detectron2/modeling/proposal_generator/proposal_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,7 @@ def find_top_rpn_proposals(
else:
num_proposals_i = min(Hi_Wi_A, pre_nms_topk)

# sort is faster than topk: https://github.com/pytorch/pytorch/issues/22812
# topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1)
logits_i, idx = logits_i.sort(descending=True, dim=1)
topk_scores_i = logits_i.narrow(1, 0, num_proposals_i)
topk_idx = idx.narrow(1, 0, num_proposals_i)
topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1)

# each is N x topk
topk_proposals_i = proposals_i[batch_idx[:, None], topk_idx] # N x topk x 4
Expand Down
6 changes: 1 addition & 5 deletions detectron2/modeling/proposal_generator/rrpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,7 @@ def find_top_rrpn_proposals(
else:
num_proposals_i = min(Hi_Wi_A, pre_nms_topk)

# sort is faster than topk (https://github.com/pytorch/pytorch/issues/22812)
# topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1)
logits_i, idx = logits_i.sort(descending=True, dim=1)
topk_scores_i = logits_i[batch_idx, :num_proposals_i]
topk_idx = idx[batch_idx, :num_proposals_i]
topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1)

# each is N x topk
topk_proposals_i = proposals_i[batch_idx[:, None], topk_idx] # N x topk x 5
Expand Down

0 comments on commit 1ad5759

Please sign in to comment.