Skip to content

Commit

Permalink
Cache v0.3: improved corrupt image/label reporting (#3676)
Browse files Browse the repository at this point in the history
* Cache v0.3: improved corrupt image/label reporting

Fix for #3656 (comment)

* cleanup
  • Loading branch information
glenn-jocher authored Jun 18, 2021
1 parent 2296f15 commit f527704
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache') # cached labels
if cache_path.is_file():
cache, exists = torch.load(cache_path), True # load
if cache['hash'] != get_hash(self.label_files + self.img_files): # changed
if cache['hash'] != get_hash(self.label_files + self.img_files) or cache['version'] != 0.3:

This comment has been minimized.

Copy link
@thanhminhmr

thanhminhmr Jun 18, 2021

Contributor

It shoud be written like this:

if cache['version'] < 0.3 or cache['hash'] != get_hash(self.label_files + self.img_files):

The cache version check is much faster than the cache hash, so check the version first make sense. And also change the version check to be future proof (?) (I'm assume that older version have smaller version value, revert this if it is not the case.)

This comment has been minimized.

Copy link
@glenn-jocher

glenn-jocher Jun 19, 2021

Author Member

@thanhminhmr yes good points! Implemented in PR #3691

cache, exists = self.cache_labels(cache_path, prefix), False # re-cache
else:
cache, exists = self.cache_labels(cache_path, prefix), False # cache
Expand All @@ -400,11 +400,12 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
if exists:
d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
tqdm(None, desc=prefix + d, total=n, initial=n) # display cache results
if cache['msgs']:
logging.info('\n'.join(cache['msgs'])) # display warnings
assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'

# Read cache
cache.pop('hash') # remove hash
cache.pop('version') # remove version
[cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
labels, shapes, self.segments = zip(*cache.values())
self.labels = list(labels)
self.shapes = np.array(shapes, dtype=np.float64)
Expand Down Expand Up @@ -461,26 +462,31 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
def cache_labels(self, path=Path('./labels.cache'), prefix=''):
# Cache dataset labels, check images and read shapes
x = {} # dict
nm, nf, ne, nc = 0, 0, 0, 0 # number missing, found, empty, corrupt
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
with Pool(num_threads) as pool:
pbar = tqdm(pool.imap_unordered(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))),
desc=desc, total=len(self.img_files))
for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f in pbar:
for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
nm += nm_f
nf += nf_f
ne += ne_f
nc += nc_f
if im_file:
x[im_file] = [l, shape, segments]
if msg:
msgs.append(msg)
pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupted"

pbar.close()
if msgs:
logging.info('\n'.join(msgs))
if nf == 0:
logging.info(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
x['hash'] = get_hash(self.label_files + self.img_files)
x['results'] = nf, nm, ne, nc, len(self.img_files)
x['version'] = 0.2 # cache version
x['msgs'] = msgs # warnings
x['version'] = 0.3 # cache version
try:
torch.save(x, path) # save cache for next time
logging.info(f'{prefix}New cache created: {path}')
Expand Down Expand Up @@ -1084,11 +1090,11 @@ def verify_image_label(args):
else:
nm = 1 # label missing
l = np.zeros((0, 5), dtype=np.float32)
return im_file, l, shape, segments, nm, nf, ne, nc
return im_file, l, shape, segments, nm, nf, ne, nc, ''
except Exception as e:
nc = 1
logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
return [None, None, None, None, nm, nf, ne, nc]
msg = f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}'
return [None, None, None, None, nm, nf, ne, nc, msg]


def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False):
Expand Down

0 comments on commit f527704

Please sign in to comment.