From 7a565f130a257aed46a0cac77cca945b489696bf Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 12 Jun 2021 13:26:41 +0200 Subject: [PATCH] Update `dataset_stats()` (#3593) @KalenMike this is a PR to add image filenames and labels to our stats dictionary and to save the dictionary to JSON. Save location is next to the train labels.cache file. The single JSON contains all stats for entire dataset. Usage example: ```python from utils.datasets import * dataset_stats('coco128.yaml', verbose=True) ``` --- utils/datasets.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/utils/datasets.py b/utils/datasets.py index 444b3ff..f18569a 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -2,6 +2,7 @@ import glob import hashlib +import json import logging import math import os @@ -1105,12 +1106,20 @@ def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False): continue x = [] dataset = LoadImagesAndLabels(data[split], augment=False, rect=True) # load dataset + if split == 'train': + cache_path = Path(dataset.label_files[0]).parent.with_suffix('.cache') # *.cache path for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'): x.append(np.bincount(label[:, 0].astype(int), minlength=nc)) x = np.array(x) # shape(128x80) - stats[split] = {'instances': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()}, - 'images': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()), - 'per_class': (x > 0).sum(0).tolist()}} + stats[split] = {'instance_stats': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()}, + 'image_stats': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()), + 'per_class': (x > 0).sum(0).tolist()}, + 'labels': {str(Path(k).name): v.tolist() for k, v in zip(dataset.img_files, dataset.labels)}} + + # Save, print and return + with open(cache_path.with_suffix('.json'), 'w') as f: + json.dump(stats, f) # save stats *.json if verbose: print(yaml.dump([stats], sort_keys=False, default_flow_style=False)) + # print(json.dumps(stats, indent=2, sort_keys=False)) return stats