Update `dataset_stats()` (#3593)

@KalenMike this is a PR to add image filenames and labels to our stats dictionary and to save the dictionary to JSON. The save location is next to the train labels.cache file. The single JSON file contains all stats for the entire dataset.

Usage example:
```python
from utils.datasets import *

dataset_stats('coco128.yaml', verbose=True)
```
This commit is contained in:
Glenn Jocher 2021-06-12 13:26:41 +02:00 committed by GitHub
parent 4984cf54be
commit 7a565f130a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 12 additions and 3 deletions

View File

@ -2,6 +2,7 @@
import glob import glob
import hashlib import hashlib
import json
import logging import logging
import math import math
import os import os
@ -1105,12 +1106,20 @@ def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False):
continue continue
x = [] x = []
dataset = LoadImagesAndLabels(data[split], augment=False, rect=True) # load dataset dataset = LoadImagesAndLabels(data[split], augment=False, rect=True) # load dataset
if split == 'train':
cache_path = Path(dataset.label_files[0]).parent.with_suffix('.cache') # *.cache path
for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'): for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
x.append(np.bincount(label[:, 0].astype(int), minlength=nc)) x.append(np.bincount(label[:, 0].astype(int), minlength=nc))
x = np.array(x) # shape(128x80) x = np.array(x) # shape(128x80)
stats[split] = {'instances': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()}, stats[split] = {'instance_stats': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
'images': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()), 'image_stats': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
'per_class': (x > 0).sum(0).tolist()}} 'per_class': (x > 0).sum(0).tolist()},
'labels': {str(Path(k).name): v.tolist() for k, v in zip(dataset.img_files, dataset.labels)}}
# Save, print and return
with open(cache_path.with_suffix('.json'), 'w') as f:
json.dump(stats, f) # save stats *.json
if verbose: if verbose:
print(yaml.dump([stats], sort_keys=False, default_flow_style=False)) print(yaml.dump([stats], sort_keys=False, default_flow_style=False))
# print(json.dumps(stats, indent=2, sort_keys=False))
return stats return stats