From 09b16c68e9429f4c2462e8edef97a5c7d93ce37a Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 11 Jun 2021 23:26:55 +0200 Subject: [PATCH] Update `dataset_stats()` @KalenMike this is a PR to add image filenames and labels to our stats dictionary and to save the dictionary to JSON. Save location is next to the train labels.cache file. The single JSON contains all stats for entire dataset. Usage example: ```python from utils.datasets import * dataset_stats('coco128.yaml', verbose=True) ``` --- utils/datasets.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/utils/datasets.py b/utils/datasets.py index 444b3ff2f60c..f18569a7665b 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -2,6 +2,7 @@ import glob import hashlib +import json import logging import math import os @@ -1105,12 +1106,20 @@ def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False): continue x = [] dataset = LoadImagesAndLabels(data[split], augment=False, rect=True) # load dataset + if split == 'train': + cache_path = Path(dataset.label_files[0]).parent.with_suffix('.cache') # *.cache path for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'): x.append(np.bincount(label[:, 0].astype(int), minlength=nc)) x = np.array(x) # shape(128x80) - stats[split] = {'instances': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()}, - 'images': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()), - 'per_class': (x > 0).sum(0).tolist()}} + stats[split] = {'instance_stats': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()}, + 'image_stats': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()), + 'per_class': (x > 0).sum(0).tolist()}, + 'labels': {str(Path(k).name): v.tolist() for k, v in zip(dataset.img_files, dataset.labels)}} + + # Save, print and return + with open(cache_path.with_suffix('.json'), 'w') as f: + json.dump(stats, f) # save stats *.json if verbose: print(yaml.dump([stats], sort_keys=False, default_flow_style=False)) + # print(json.dumps(stats, indent=2, sort_keys=False)) return stats