From 6ad613955d5d7093af28a1c3dc1eae52c89532fc Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 12 Jun 2021 13:26:41 +0200 Subject: [PATCH] Update `dataset_stats()` (#3593) @KalenMike this is a PR to add image filenames and labels to our stats dictionary and to save the dictionary to JSON. Save location is next to the train labels.cache file. The single JSON contains all stats for entire dataset. Usage example: ```python from utils.datasets import * dataset_stats('coco128.yaml', verbose=True) ``` (cherry picked from commit 7a565f130a257aed46a0cac77cca945b489696bf) --- utils/datasets.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/utils/datasets.py b/utils/datasets.py index 444b3ff2f60c..f18569a7665b 100644 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -2,6 +2,7 @@ import glob import hashlib +import json import logging import math import os @@ -1105,12 +1106,20 @@ def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False): continue x = [] dataset = LoadImagesAndLabels(data[split], augment=False, rect=True) # load dataset + if split == 'train': + cache_path = Path(dataset.label_files[0]).parent.with_suffix('.cache') # *.cache path for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'): x.append(np.bincount(label[:, 0].astype(int), minlength=nc)) x = np.array(x) # shape(128x80) - stats[split] = {'instances': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()}, - 'images': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()), - 'per_class': (x > 0).sum(0).tolist()}} + stats[split] = {'instance_stats': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()}, + 'image_stats': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()), + 'per_class': (x > 0).sum(0).tolist()}, + 'labels': {str(Path(k).name): v.tolist() for k, v in zip(dataset.img_files, dataset.labels)}} + + # Save, print and return + with open(cache_path.with_suffix('.json'), 'w') as f: + json.dump(stats, f) # save stats *.json if verbose: print(yaml.dump([stats], sort_keys=False, default_flow_style=False)) + # print(json.dumps(stats, indent=2, sort_keys=False)) return stats