Skip to content

Commit

Permalink
Use multi-threading in cache_labels (#3505)
Browse files Browse the repository at this point in the history
* Use multi threading in cache_labels

* PEP8 reformat

* Add num_threads

* changed ThreadPool.imap_unordered to Pool.imap_unordered

* Remove inplace additions

* Update datasets.py

refactor initial desc

Co-authored-by: Glenn Jocher <[email protected]>
  • Loading branch information
deanmark and glenn-jocher authored Jun 8, 2021
1 parent c058a61 commit 28bff22
Showing 1 changed file with 56 additions and 43 deletions.
99 changes: 56 additions & 43 deletions utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import shutil
import time
from itertools import repeat
from multiprocessing.pool import ThreadPool
from multiprocessing.pool import ThreadPool, Pool
from pathlib import Path
from threading import Thread

Expand All @@ -29,6 +29,7 @@
help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo'] # acceptable image suffixes
vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv'] # acceptable video suffixes
num_threads = min(8, os.cpu_count()) # number of multiprocessing threads
logger = logging.getLogger(__name__)

# Get orientation exif tag
Expand Down Expand Up @@ -447,7 +448,7 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
if cache_images:
gb = 0 # Gigabytes of cached images
self.img_hw0, self.img_hw = [None] * n, [None] * n
results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n))) # 8 threads
results = ThreadPool(num_threads).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
pbar = tqdm(enumerate(results), total=n)
for i, x in pbar:
self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # img, hw_original, hw_resized = load_image(self, i)
Expand All @@ -458,53 +459,24 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
def cache_labels(self, path=Path('./labels.cache'), prefix=''):
# Cache dataset labels, check images and read shapes
x = {} # dict
nm, nf, ne, nc = 0, 0, 0, 0 # number missing, found, empty, duplicate
pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
for i, (im_file, lb_file) in enumerate(pbar):
try:
# verify images
im = Image.open(im_file)
im.verify() # PIL verify
shape = exif_size(im) # image size
segments = [] # instance segments
assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
assert im.format.lower() in img_formats, f'invalid image format {im.format}'

# verify labels
if os.path.isfile(lb_file):
nf += 1 # label found
with open(lb_file, 'r') as f:
l = [x.split() for x in f.read().strip().splitlines() if len(x)]
if any([len(x) > 8 for x in l]): # is segment
classes = np.array([x[0] for x in l], dtype=np.float32)
segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l] # (cls, xy1...)
l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
l = np.array(l, dtype=np.float32)
if len(l):
assert l.shape[1] == 5, 'labels require 5 columns each'
assert (l >= 0).all(), 'negative labels'
assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
else:
ne += 1 # label empty
l = np.zeros((0, 5), dtype=np.float32)
else:
nm += 1 # label missing
l = np.zeros((0, 5), dtype=np.float32)
x[im_file] = [l, shape, segments]
except Exception as e:
nc += 1
logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')

pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
nm, nf, ne, nc = 0, 0, 0, 0 # number missing, found, empty, corrupt
desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
with Pool(num_threads) as pool:
pbar = tqdm(pool.imap_unordered(verify_image_label,
zip(self.img_files, self.label_files, repeat(prefix))),
desc=desc, total=len(self.img_files))
for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f in pbar:
if im_file:
x[im_file] = [l, shape, segments]
nm, nf, ne, nc = nm + nm_f, nf + nf_f, ne + ne_f, nc + nc_f
pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
pbar.close()

if nf == 0:
logging.info(f'{prefix}WARNING: No labels found in {path}. See {help_url}')

x['hash'] = get_hash(self.label_files + self.img_files)
x['results'] = nf, nm, ne, nc, i + 1
x['results'] = nf, nm, ne, nc, len(self.img_files)
x['version'] = 0.2 # cache version
try:
torch.save(x, path) # save cache for next time
Expand Down Expand Up @@ -1069,3 +1041,44 @@ def autosplit(path='../coco128', weights=(0.9, 0.1, 0.0), annotated_only=False):
if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
with open(path / txt[i], 'a') as f:
f.write(str(img) + '\n') # add image to txt file


def verify_image_label(params):
# Verify one image-label pair
im_file, lb_file, prefix = params
nm, nf, ne, nc = 0, 0, 0, 0 # number missing, found, empty, corrupt
try:
# verify images
im = Image.open(im_file)
im.verify() # PIL verify
shape = exif_size(im) # image size
segments = [] # instance segments
assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
assert im.format.lower() in img_formats, f'invalid image format {im.format}'

# verify labels
if os.path.isfile(lb_file):
nf = 1 # label found
with open(lb_file, 'r') as f:
l = [x.split() for x in f.read().strip().splitlines() if len(x)]
if any([len(x) > 8 for x in l]): # is segment
classes = np.array([x[0] for x in l], dtype=np.float32)
segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l] # (cls, xy1...)
l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
l = np.array(l, dtype=np.float32)
if len(l):
assert l.shape[1] == 5, 'labels require 5 columns each'
assert (l >= 0).all(), 'negative labels'
assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
else:
ne = 1 # label empty
l = np.zeros((0, 5), dtype=np.float32)
else:
nm = 1 # label missing
l = np.zeros((0, 5), dtype=np.float32)
return im_file, l, shape, segments, nm, nf, ne, nc
except Exception as e:
nc = 1
logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
return [None] * 4 + [nm, nf, ne, nc]

0 comments on commit 28bff22

Please sign in to comment.