Skip to content

Commit

Permalink
Update --cache disk deprecate *_npy/ dirs (#6876)
Browse files Browse the repository at this point in the history
* Updates

* Updates

* Updates

* Updates

* Updates

* Updates

* Updates

* Updates

* Updates

* Updates

* Cleanup

* Cleanup
  • Loading branch information
glenn-jocher authored Mar 6, 2022
1 parent 8a66eba commit 4728840
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 40 deletions.
76 changes: 38 additions & 38 deletions utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,19 +407,19 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
# f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
else:
raise Exception(f'{prefix}{p} does not exist')
self.img_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
self.im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
assert self.img_files, f'{prefix}No images found'
assert self.im_files, f'{prefix}No images found'
except Exception as e:
raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}')

# Check cache
self.label_files = img2label_paths(self.img_files) # labels
self.label_files = img2label_paths(self.im_files) # labels
cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
try:
cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
assert cache['version'] == self.cache_version # same version
assert cache['hash'] == get_hash(self.label_files + self.img_files) # same hash
assert cache['hash'] == get_hash(self.label_files + self.im_files) # same hash
except Exception:
cache, exists = self.cache_labels(cache_path, prefix), False # cache

Expand All @@ -437,7 +437,7 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
labels, shapes, self.segments = zip(*cache.values())
self.labels = list(labels)
self.shapes = np.array(shapes, dtype=np.float64)
self.img_files = list(cache.keys()) # update
self.im_files = list(cache.keys()) # update
self.label_files = img2label_paths(cache.keys()) # update
n = len(shapes) # number of images
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
Expand Down Expand Up @@ -466,7 +466,7 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
s = self.shapes # wh
ar = s[:, 1] / s[:, 0] # aspect ratio
irect = ar.argsort()
self.img_files = [self.img_files[i] for i in irect]
self.im_files = [self.im_files[i] for i in irect]
self.label_files = [self.label_files[i] for i in irect]
self.labels = [self.labels[i] for i in irect]
self.shapes = s[irect] # wh
Expand All @@ -485,24 +485,20 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride

# Cache images into RAM/disk for faster training (WARNING: large datasets may exceed system resources)
self.imgs, self.img_npy = [None] * n, [None] * n
self.ims = [None] * n
self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
if cache_images:
if cache_images == 'disk':
self.im_cache_dir = Path(Path(self.img_files[0]).parent.as_posix() + '_npy')
self.img_npy = [self.im_cache_dir / Path(f).with_suffix('.npy').name for f in self.img_files]
self.im_cache_dir.mkdir(parents=True, exist_ok=True)
gb = 0 # Gigabytes of cached images
self.img_hw0, self.img_hw = [None] * n, [None] * n
results = ThreadPool(NUM_THREADS).imap(self.load_image, range(n))
self.im_hw0, self.im_hw = [None] * n, [None] * n
fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
pbar = tqdm(enumerate(results), total=n)
for i, x in pbar:
if cache_images == 'disk':
if not self.img_npy[i].exists():
np.save(self.img_npy[i].as_posix(), x[0])
gb += self.img_npy[i].stat().st_size
gb += self.npy_files[i].stat().st_size
else: # 'ram'
self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
gb += self.imgs[i].nbytes
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
gb += self.ims[i].nbytes
pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB {cache_images})'
pbar.close()

Expand All @@ -512,8 +508,8 @@ def cache_labels(self, path=Path('./labels.cache'), prefix=''):
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
with Pool(NUM_THREADS) as pool:
pbar = tqdm(pool.imap(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))),
desc=desc, total=len(self.img_files))
pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
desc=desc, total=len(self.im_files))
for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
nm += nm_f
nf += nf_f
Expand All @@ -530,8 +526,8 @@ def cache_labels(self, path=Path('./labels.cache'), prefix=''):
LOGGER.info('\n'.join(msgs))
if nf == 0:
LOGGER.warning(f'{prefix}WARNING: No labels found in {path}. See {HELP_URL}')
x['hash'] = get_hash(self.label_files + self.img_files)
x['results'] = nf, nm, ne, nc, len(self.img_files)
x['hash'] = get_hash(self.label_files + self.im_files)
x['results'] = nf, nm, ne, nc, len(self.im_files)
x['msgs'] = msgs # warnings
x['version'] = self.cache_version # cache version
try:
Expand All @@ -543,7 +539,7 @@ def cache_labels(self, path=Path('./labels.cache'), prefix=''):
return x

def __len__(self):
return len(self.img_files)
return len(self.im_files)

# def __iter__(self):
# self.count = -1
Expand Down Expand Up @@ -622,17 +618,15 @@ def __getitem__(self, index):
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
img = np.ascontiguousarray(img)

return torch.from_numpy(img), labels_out, self.img_files[index], shapes
return torch.from_numpy(img), labels_out, self.im_files[index], shapes

def load_image(self, i):
# loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
im = self.imgs[i]
# Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i],
if im is None: # not cached in RAM
npy = self.img_npy[i]
if npy and npy.exists(): # load npy
im = np.load(npy)
if fn.exists(): # load npy
im = np.load(fn)
else: # read image
f = self.img_files[i]
im = cv2.imread(f) # BGR
assert im is not None, f'Image Not Found {f}'
h0, w0 = im.shape[:2] # orig hw
Expand All @@ -643,7 +637,13 @@ def load_image(self, i):
interpolation=cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA)
return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
else:
return self.imgs[i], self.img_hw0[i], self.img_hw[i] # im, hw_original, hw_resized
return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized

def cache_images_to_disk(self, i):
# Saves an image as an *.npy file for faster loading
f = self.npy_files[i]
if not f.exists():
np.save(f.as_posix(), cv2.imread(self.im_files[i]))

def load_mosaic(self, index):
# YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
Expand Down Expand Up @@ -777,16 +777,16 @@ def load_mosaic9(self, index):

@staticmethod
def collate_fn(batch):
img, label, path, shapes = zip(*batch) # transposed
im, label, path, shapes = zip(*batch) # transposed
for i, lb in enumerate(label):
lb[:, 0] = i # add target image index for build_targets()
return torch.stack(img, 0), torch.cat(label, 0), path, shapes
return torch.stack(im, 0), torch.cat(label, 0), path, shapes

@staticmethod
def collate_fn4(batch):
img, label, path, shapes = zip(*batch) # transposed
n = len(shapes) // 4
img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
Expand All @@ -800,13 +800,13 @@ def collate_fn4(batch):
else:
im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
img4.append(im)
im4.append(im)
label4.append(lb)

for i, lb in enumerate(label4):
lb[:, 0] = i # add target image index for build_targets()

return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4
return torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4


# Ancillary functions --------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -999,12 +999,12 @@ def hub_ops(f, max_dim=1920):
'image_stats': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
'per_class': (x > 0).sum(0).tolist()},
'labels': [{str(Path(k).name): round_labels(v.tolist())} for k, v in
zip(dataset.img_files, dataset.labels)]}
zip(dataset.im_files, dataset.labels)]}

if hub:
im_dir = hub_dir / 'images'
im_dir.mkdir(parents=True, exist_ok=True)
for _ in tqdm(ThreadPool(NUM_THREADS).imap(hub_ops, dataset.img_files), total=dataset.n, desc='HUB Ops'):
for _ in tqdm(ThreadPool(NUM_THREADS).imap(hub_ops, dataset.im_files), total=dataset.n, desc='HUB Ops'):
pass

# Profile
Expand Down
2 changes: 1 addition & 1 deletion utils/loggers/wandb/wandb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def create_dataset_table(self, dataset: LoadImagesAndLabels, class_to_id: Dict[i
# TODO: Explore multiprocessing to slpit this loop parallely| This is essential for speeding up the the logging
artifact = wandb.Artifact(name=name, type="dataset")
img_files = tqdm([dataset.path]) if isinstance(dataset.path, str) and Path(dataset.path).is_dir() else None
img_files = tqdm(dataset.img_files) if not img_files else img_files
img_files = tqdm(dataset.im_files) if not img_files else img_files
for img_file in img_files:
if Path(img_file).is_dir():
artifact.add_dir(img_file, name='data/images')
Expand Down
2 changes: 1 addition & 1 deletion val.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def run(data,
pred = anno.loadRes(pred_json) # init predictions api
eval = COCOeval(anno, pred, 'bbox')
if is_coco:
eval.params.imgIds = [int(Path(x).stem) for x in dataloader.dataset.img_files] # image IDs to evaluate
eval.params.imgIds = [int(Path(x).stem) for x in dataloader.dataset.im_files] # image IDs to evaluate
eval.evaluate()
eval.accumulate()
eval.summarize()
Expand Down

0 comments on commit 4728840

Please sign in to comment.