diff --git a/utils/loss.py b/utils/loss.py index 12ec230..a414ba2 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -501,7 +501,7 @@ def build_targets(self, p, targets): # Build targets for compute_loss(), input targets(image,class,x,y,w,h) na, nt = self.na, targets.shape[0] # number of anchors, targets tcls, tbox, indices, anch = [], [], [], [] - gain = torch.ones(7, device=targets.device) # normalized to gridspace gain + gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt) targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices @@ -775,7 +775,7 @@ def build_targets(self, p, targets, imgs): matching_anchs[i].append(all_anch[layer_idx]) for i in range(nl): - if matching_gis[i] != []: + if matching_targets[i] != []: matching_bs[i] = torch.cat(matching_bs[i], dim=0) matching_as[i] = torch.cat(matching_as[i], dim=0) matching_gjs[i] = torch.cat(matching_gjs[i], dim=0) @@ -796,7 +796,7 @@ def find_3_positive(self, p, targets): # Build targets for compute_loss(), input targets(image,class,x,y,w,h) na, nt = self.na, targets.shape[0] # number of anchors, targets indices, anch = [], [] - gain = torch.ones(7, device=targets.device) # normalized to gridspace gain + gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt) targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices