Skip to content

Commit

Permalink
Fix Detections class tolist() method (#5945)
Browse files Browse the repository at this point in the history
* Fix tolist() to add the file for each Detection

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix PEP8 requirement for 2 spaces before an inline comment

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cleanup

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <[email protected]>
  • Loading branch information
3 people authored Dec 11, 2021
1 parent 8f875d9 commit 8f35436
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,14 +525,15 @@ def forward(self, imgs, size=640, augment=False, profile=False):

class Detections:
# YOLOv5 detections class for inference results
def __init__(self, imgs, pred, files, times=None, names=None, shape=None):
def __init__(self, imgs, pred, files, times=(0, 0, 0, 0), names=None, shape=None):
super().__init__()
d = pred[0].device # device
gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in imgs] # normalizations
self.imgs = imgs # list of images as numpy arrays
self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
self.names = names # class names
self.files = files # image filenames
self.times = times # profiling times
self.xyxy = pred # xyxy pixels
self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
Expand Down Expand Up @@ -612,10 +613,11 @@ def pandas(self):

def tolist(self):
# return a list of Detections objects, i.e. 'for result in results.tolist():'
x = [Detections([self.imgs[i]], [self.pred[i]], names=self.names, shape=self.s) for i in range(self.n)]
for d in x:
for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
setattr(d, k, getattr(d, k)[0]) # pop out of list
r = range(self.n) # iterable
x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
# for d in x:
# for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
# setattr(d, k, getattr(d, k)[0]) # pop out of list
return x

def __len__(self):
Expand Down

0 comments on commit 8f35436

Please sign in to comment.