Skip to content

Commit

Permalink
torch.split() 1.7.0 compatibility fix (ultralytics#7102)
Browse files Browse the repository at this point in the history
* Update loss.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update loss.py

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
glenn-jocher and pre-commit-ci[bot] authored Mar 22, 2022
1 parent f5a84dc commit e0e4b05
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,15 @@ def __init__(self, model, autobalance=False):
if g > 0:
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

det = de_parallel(model).model[-1] # Detect() module
self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index
m = de_parallel(model).model[-1] # Detect() module
self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
self.ssi = list(m.stride).index(16) if autobalance else 0 # stride 16 index
self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
self.na = m.na # number of anchors
self.nc = m.nc # number of classes
self.nl = m.nl # number of layers
self.anchors = m.anchors
self.device = device
for k in 'na', 'nc', 'nl', 'anchors':
setattr(self, k, getattr(det, k))

def __call__(self, p, targets): # predictions, targets
lcls = torch.zeros(1, device=self.device) # class loss
Expand All @@ -129,7 +131,8 @@ def __call__(self, p, targets): # predictions, targets

n = b.shape[0] # number of targets
if n:
pxy, pwh, _, pcls = pi[b, a, gj, gi].tensor_split((2, 4, 5), dim=1) # target-subset of predictions
# pxy, pwh, _, pcls = pi[b, a, gj, gi].tensor_split((2, 4, 5), dim=1) # faster, requires torch 1.8.0
pxy, pwh, _, pcls = pi[b, a, gj, gi].split((2, 2, 1, self.nc), 1) # target-subset of predictions

# Regression
pxy = pxy.sigmoid() * 2 - 0.5
Expand Down

0 comments on commit e0e4b05

Please sign in to comment.