Skip to content

Commit

Permalink
main code
Browse files Browse the repository at this point in the history
  • Loading branch information
williamjones4you authored Jul 10, 2022
1 parent 0f6f88d commit fba9a7f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def build_targets(self, p, targets):
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
na, nt = self.na, targets.shape[0] # number of anchors, targets
tcls, tbox, indices, anch = [], [], [], []
gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices

Expand Down Expand Up @@ -775,7 +775,7 @@ def build_targets(self, p, targets, imgs):
matching_anchs[i].append(all_anch[layer_idx])

for i in range(nl):
if matching_gis[i] != []:
if matching_targets[i] != []:
matching_bs[i] = torch.cat(matching_bs[i], dim=0)
matching_as[i] = torch.cat(matching_as[i], dim=0)
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
Expand All @@ -796,7 +796,7 @@ def find_3_positive(self, p, targets):
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
na, nt = self.na, targets.shape[0] # number of anchors, targets
indices, anch = [], []
gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
gain = torch.ones(7, device=targets.device).long() # normalized to gridspace gain
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices

Expand Down

0 comments on commit fba9a7f

Please sign in to comment.