Skip to content

Commit

Permalink
Fix the CustomNonLinearClsHead when the batch_size is set to 1 (#2571)
Browse files Browse the repository at this point in the history
Fix bn1d issue

Co-authored-by: sungmanc <[email protected]>
  • Loading branch information
sungmanc and sungmanc authored Oct 24, 2023
1 parent a5193b1 commit 65ddbfa
Showing 1 changed file with 4 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def forward(self, x):

def forward_train(self, cls_score, gt_label):
"""Forward_train fuction of CustomNonLinearHead class."""
bs = cls_score.shape[0]
if bs == 1:
cls_score = torch.cat([cls_score, cls_score], dim=0)
gt_label = torch.cat([gt_label, gt_label], dim=0)
logit = self.classifier(cls_score)
losses = self.loss(logit, gt_label, feature=cls_score)
return losses
Expand Down

0 comments on commit 65ddbfa

Please sign in to comment.