Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

1095 lines
47KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Dataloaders and dataset utils
  4. """
  5. import glob
  6. import hashlib
  7. import json
  8. import math
  9. import os
  10. import random
  11. import shutil
  12. import time
  13. from itertools import repeat
  14. from multiprocessing.pool import Pool, ThreadPool
  15. from pathlib import Path
  16. from threading import Thread
  17. from urllib.parse import urlparse
  18. from zipfile import ZipFile
  19. import numpy as np
  20. import torch
  21. import torch.nn.functional as F
  22. import yaml
  23. from PIL import ExifTags, Image, ImageOps
  24. from torch.utils.data import DataLoader, Dataset, dataloader, distributed
  25. from tqdm import tqdm
  26. from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
  27. from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
  28. cv2, is_colab, is_kaggle, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
  29. from utils.torch_utils import torch_distributed_zero_first
  30. # Parameters
  31. HELP_URL = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
  32. IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp' # include image suffixes
  33. VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
  34. BAR_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}' # tqdm bar format
  35. LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
  36. # Get orientation exif tag
  37. for orientation in ExifTags.TAGS.keys():
  38. if ExifTags.TAGS[orientation] == 'Orientation':
  39. break
  40. def get_hash(paths):
  41. # Returns a single hash value of a list of paths (files or dirs)
  42. size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
  43. h = hashlib.md5(str(size).encode()) # hash sizes
  44. h.update(''.join(paths).encode()) # hash paths
  45. return h.hexdigest() # return hash
  46. def exif_size(img):
  47. # Returns exif-corrected PIL size
  48. s = img.size # (width, height)
  49. try:
  50. rotation = dict(img._getexif().items())[orientation]
  51. if rotation in [6, 8]: # rotation 270 or 90
  52. s = (s[1], s[0])
  53. except Exception:
  54. pass
  55. return s
  56. def exif_transpose(image):
  57. """
  58. Transpose a PIL image accordingly if it has an EXIF Orientation tag.
  59. Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()
  60. :param image: The image to transpose.
  61. :return: An image.
  62. """
  63. exif = image.getexif()
  64. orientation = exif.get(0x0112, 1) # default 1
  65. if orientation > 1:
  66. method = {
  67. 2: Image.FLIP_LEFT_RIGHT,
  68. 3: Image.ROTATE_180,
  69. 4: Image.FLIP_TOP_BOTTOM,
  70. 5: Image.TRANSPOSE,
  71. 6: Image.ROTATE_270,
  72. 7: Image.TRANSVERSE,
  73. 8: Image.ROTATE_90,}.get(orientation)
  74. if method is not None:
  75. image = image.transpose(method)
  76. del exif[0x0112]
  77. image.info["exif"] = exif.tobytes()
  78. return image
  79. def create_dataloader(path,
  80. imgsz,
  81. batch_size,
  82. stride,
  83. single_cls=False,
  84. hyp=None,
  85. augment=False,
  86. cache=False,
  87. pad=0.0,
  88. rect=False,
  89. rank=-1,
  90. workers=8,
  91. image_weights=False,
  92. quad=False,
  93. prefix='',
  94. shuffle=False):
  95. if rect and shuffle:
  96. LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
  97. shuffle = False
  98. with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
  99. dataset = LoadImagesAndLabels(
  100. path,
  101. imgsz,
  102. batch_size,
  103. augment=augment, # augmentation
  104. hyp=hyp, # hyperparameters
  105. rect=rect, # rectangular batches
  106. cache_images=cache,
  107. single_cls=single_cls,
  108. stride=int(stride),
  109. pad=pad,
  110. image_weights=image_weights,
  111. prefix=prefix)
  112. batch_size = min(batch_size, len(dataset))
  113. nd = torch.cuda.device_count() # number of CUDA devices
  114. nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
  115. sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
  116. loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
  117. return loader(dataset,
  118. batch_size=batch_size,
  119. shuffle=shuffle and sampler is None,
  120. num_workers=nw,
  121. sampler=sampler,
  122. pin_memory=True,
  123. collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
  124. class InfiniteDataLoader(dataloader.DataLoader):
  125. """ Dataloader that reuses workers
  126. Uses same syntax as vanilla DataLoader
  127. """
  128. def __init__(self, *args, **kwargs):
  129. super().__init__(*args, **kwargs)
  130. object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
  131. self.iterator = super().__iter__()
  132. def __len__(self):
  133. return len(self.batch_sampler.sampler)
  134. def __iter__(self):
  135. for _ in range(len(self)):
  136. yield next(self.iterator)
  137. class _RepeatSampler:
  138. """ Sampler that repeats forever
  139. Args:
  140. sampler (Sampler)
  141. """
  142. def __init__(self, sampler):
  143. self.sampler = sampler
  144. def __iter__(self):
  145. while True:
  146. yield from iter(self.sampler)
  147. class LoadImages:
  148. # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
  149. def __init__(self, path, img_size=640, stride=32, auto=True):
  150. p = str(Path(path).resolve()) # os-agnostic absolute path
  151. if '*' in p:
  152. files = sorted(glob.glob(p, recursive=True)) # glob
  153. elif os.path.isdir(p):
  154. files = sorted(glob.glob(os.path.join(p, '*.*'))) # dir
  155. elif os.path.isfile(p):
  156. files = [p] # files
  157. else:
  158. raise Exception(f'ERROR: {p} does not exist')
  159. images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
  160. videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
  161. ni, nv = len(images), len(videos)
  162. self.img_size = img_size
  163. self.stride = stride
  164. self.files = images + videos
  165. self.nf = ni + nv # number of files
  166. self.video_flag = [False] * ni + [True] * nv
  167. self.mode = 'image'
  168. self.auto = auto
  169. if any(videos):
  170. self.new_video(videos[0]) # new video
  171. else:
  172. self.cap = None
  173. assert self.nf > 0, f'No images or videos found in {p}. ' \
  174. f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
  175. def __iter__(self):
  176. self.count = 0
  177. return self
  178. def __next__(self):
  179. if self.count == self.nf:
  180. raise StopIteration
  181. path = self.files[self.count]
  182. if self.video_flag[self.count]:
  183. # Read video
  184. self.mode = 'video'
  185. ret_val, img0 = self.cap.read()
  186. while not ret_val:
  187. self.count += 1
  188. self.cap.release()
  189. if self.count == self.nf: # last video
  190. raise StopIteration
  191. path = self.files[self.count]
  192. self.new_video(path)
  193. ret_val, img0 = self.cap.read()
  194. self.frame += 1
  195. s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
  196. else:
  197. # Read image
  198. self.count += 1
  199. img0 = cv2.imread(path) # BGR
  200. assert img0 is not None, f'Image Not Found {path}'
  201. s = f'image {self.count}/{self.nf} {path}: '
  202. # Padded resize
  203. img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]
  204. # Convert
  205. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  206. img = np.ascontiguousarray(img)
  207. return path, img, img0, self.cap, s
  208. def new_video(self, path):
  209. self.frame = 0
  210. self.cap = cv2.VideoCapture(path)
  211. self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
  212. def __len__(self):
  213. return self.nf # number of files
  214. class LoadWebcam: # for inference
  215. # YOLOv5 local webcam dataloader, i.e. `python detect.py --source 0`
  216. def __init__(self, pipe='0', img_size=640, stride=32):
  217. self.img_size = img_size
  218. self.stride = stride
  219. self.pipe = eval(pipe) if pipe.isnumeric() else pipe
  220. self.cap = cv2.VideoCapture(self.pipe) # video capture object
  221. self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3) # set buffer size
  222. def __iter__(self):
  223. self.count = -1
  224. return self
  225. def __next__(self):
  226. self.count += 1
  227. if cv2.waitKey(1) == ord('q'): # q to quit
  228. self.cap.release()
  229. cv2.destroyAllWindows()
  230. raise StopIteration
  231. # Read frame
  232. ret_val, img0 = self.cap.read()
  233. img0 = cv2.flip(img0, 1) # flip left-right
  234. # Print
  235. assert ret_val, f'Camera Error {self.pipe}'
  236. img_path = 'webcam.jpg'
  237. s = f'webcam {self.count}: '
  238. # Padded resize
  239. img = letterbox(img0, self.img_size, stride=self.stride)[0]
  240. # Convert
  241. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  242. img = np.ascontiguousarray(img)
  243. return img_path, img, img0, None, s
  244. def __len__(self):
  245. return 0
  246. class LoadStreams:
  247. # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
  248. def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
  249. self.mode = 'stream'
  250. self.img_size = img_size
  251. self.stride = stride
  252. if os.path.isfile(sources):
  253. with open(sources) as f:
  254. sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
  255. else:
  256. sources = [sources]
  257. n = len(sources)
  258. self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
  259. self.sources = [clean_str(x) for x in sources] # clean source names for later
  260. self.auto = auto
  261. for i, s in enumerate(sources): # index, source
  262. # Start thread to read frames from video stream
  263. st = f'{i + 1}/{n}: {s}... '
  264. if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video
  265. check_requirements(('pafy', 'youtube_dl==2020.12.2'))
  266. import pafy
  267. s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
  268. s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
  269. if s == 0:
  270. assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
  271. assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.'
  272. cap = cv2.VideoCapture(s)
  273. assert cap.isOpened(), f'{st}Failed to open {s}'
  274. w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  275. h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  276. fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
  277. self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
  278. self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
  279. _, self.imgs[i] = cap.read() # guarantee first frame
  280. self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
  281. LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
  282. self.threads[i].start()
  283. LOGGER.info('') # newline
  284. # check for common shapes
  285. s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs])
  286. self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
  287. if not self.rect:
  288. LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')
  289. def update(self, i, cap, stream):
  290. # Read stream `i` frames in daemon thread
  291. n, f, read = 0, self.frames[i], 1 # frame number, frame array, inference every 'read' frame
  292. while cap.isOpened() and n < f:
  293. n += 1
  294. # _, self.imgs[index] = cap.read()
  295. cap.grab()
  296. if n % read == 0:
  297. success, im = cap.retrieve()
  298. if success:
  299. self.imgs[i] = im
  300. else:
  301. LOGGER.warning('WARNING: Video stream unresponsive, please check your IP camera connection.')
  302. self.imgs[i] = np.zeros_like(self.imgs[i])
  303. cap.open(stream) # re-open stream if signal was lost
  304. time.sleep(0.0) # wait time
  305. def __iter__(self):
  306. self.count = -1
  307. return self
  308. def __next__(self):
  309. self.count += 1
  310. if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
  311. cv2.destroyAllWindows()
  312. raise StopIteration
  313. # Letterbox
  314. img0 = self.imgs.copy()
  315. img = [letterbox(x, self.img_size, stride=self.stride, auto=self.rect and self.auto)[0] for x in img0]
  316. # Stack
  317. img = np.stack(img, 0)
  318. # Convert
  319. img = img[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
  320. img = np.ascontiguousarray(img)
  321. return self.sources, img, img0, None, ''
  322. def __len__(self):
  323. return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
  324. def img2label_paths(img_paths):
  325. # Define label paths as a function of image paths
  326. sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}' # /images/, /labels/ substrings
  327. return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
  328. class LoadImagesAndLabels(Dataset):
  329. # YOLOv5 train_loader/val_loader, loads images and labels for training and validation
  330. cache_version = 0.6 # dataset labels *.cache version
  331. rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
  332. def __init__(self,
  333. path,
  334. img_size=640,
  335. batch_size=16,
  336. augment=False,
  337. hyp=None,
  338. rect=False,
  339. image_weights=False,
  340. cache_images=False,
  341. single_cls=False,
  342. stride=32,
  343. pad=0.0,
  344. prefix=''):
  345. self.img_size = img_size
  346. self.augment = augment
  347. self.hyp = hyp
  348. self.image_weights = image_weights
  349. self.rect = False if image_weights else rect
  350. self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
  351. self.mosaic_border = [-img_size // 2, -img_size // 2]
  352. self.stride = stride
  353. self.path = path
  354. self.albumentations = Albumentations() if augment else None
  355. try:
  356. f = [] # image files
  357. for p in path if isinstance(path, list) else [path]:
  358. p = Path(p) # os-agnostic
  359. if p.is_dir(): # dir
  360. f += glob.glob(str(p / '**' / '*.*'), recursive=True)
  361. # f = list(p.rglob('*.*')) # pathlib
  362. elif p.is_file(): # file
  363. with open(p) as t:
  364. t = t.read().strip().splitlines()
  365. parent = str(p.parent) + os.sep
  366. f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
  367. # f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
  368. else:
  369. raise Exception(f'{prefix}{p} does not exist')
  370. self.im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
  371. # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
  372. assert self.im_files, f'{prefix}No images found'
  373. except Exception as e:
  374. raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}')
  375. # Check cache
  376. self.label_files = img2label_paths(self.im_files) # labels
  377. cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
  378. try:
  379. cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
  380. assert cache['version'] == self.cache_version # same version
  381. assert cache['hash'] == get_hash(self.label_files + self.im_files) # same hash
  382. except Exception:
  383. cache, exists = self.cache_labels(cache_path, prefix), False # cache
  384. # Display cache
  385. nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
  386. if exists and LOCAL_RANK in {-1, 0}:
  387. d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"
  388. tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=BAR_FORMAT) # display cache results
  389. if cache['msgs']:
  390. LOGGER.info('\n'.join(cache['msgs'])) # display warnings
  391. assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {HELP_URL}'
  392. # Read cache
  393. [cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
  394. labels, shapes, self.segments = zip(*cache.values())
  395. self.labels = list(labels)
  396. self.shapes = np.array(shapes, dtype=np.float64)
  397. self.im_files = list(cache.keys()) # update
  398. self.label_files = img2label_paths(cache.keys()) # update
  399. n = len(shapes) # number of images
  400. bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
  401. nb = bi[-1] + 1 # number of batches
  402. self.batch = bi # batch index of image
  403. self.n = n
  404. self.indices = range(n)
  405. # Update labels
  406. include_class = [] # filter labels to include only these classes (optional)
  407. include_class_array = np.array(include_class).reshape(1, -1)
  408. for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
  409. if include_class:
  410. j = (label[:, 0:1] == include_class_array).any(1)
  411. self.labels[i] = label[j]
  412. if segment:
  413. self.segments[i] = segment[j]
  414. if single_cls: # single-class training, merge all classes into 0
  415. self.labels[i][:, 0] = 0
  416. if segment:
  417. self.segments[i][:, 0] = 0
  418. # Rectangular Training
  419. if self.rect:
  420. # Sort by aspect ratio
  421. s = self.shapes # wh
  422. ar = s[:, 1] / s[:, 0] # aspect ratio
  423. irect = ar.argsort()
  424. self.im_files = [self.im_files[i] for i in irect]
  425. self.label_files = [self.label_files[i] for i in irect]
  426. self.labels = [self.labels[i] for i in irect]
  427. self.shapes = s[irect] # wh
  428. ar = ar[irect]
  429. # Set training image shapes
  430. shapes = [[1, 1]] * nb
  431. for i in range(nb):
  432. ari = ar[bi == i]
  433. mini, maxi = ari.min(), ari.max()
  434. if maxi < 1:
  435. shapes[i] = [maxi, 1]
  436. elif mini > 1:
  437. shapes[i] = [1, 1 / mini]
  438. self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
  439. # Cache images into RAM/disk for faster training (WARNING: large datasets may exceed system resources)
  440. self.ims = [None] * n
  441. self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
  442. if cache_images:
  443. gb = 0 # Gigabytes of cached images
  444. self.im_hw0, self.im_hw = [None] * n, [None] * n
  445. fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
  446. results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
  447. pbar = tqdm(enumerate(results), total=n, bar_format=BAR_FORMAT, disable=LOCAL_RANK > 0)
  448. for i, x in pbar:
  449. if cache_images == 'disk':
  450. gb += self.npy_files[i].stat().st_size
  451. else: # 'ram'
  452. self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
  453. gb += self.ims[i].nbytes
  454. pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB {cache_images})'
  455. pbar.close()
  456. def cache_labels(self, path=Path('./labels.cache'), prefix=''):
  457. # Cache dataset labels, check images and read shapes
  458. x = {} # dict
  459. nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
  460. desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
  461. with Pool(NUM_THREADS) as pool:
  462. pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
  463. desc=desc,
  464. total=len(self.im_files),
  465. bar_format=BAR_FORMAT)
  466. for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
  467. nm += nm_f
  468. nf += nf_f
  469. ne += ne_f
  470. nc += nc_f
  471. if im_file:
  472. x[im_file] = [lb, shape, segments]
  473. if msg:
  474. msgs.append(msg)
  475. pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt"
  476. pbar.close()
  477. if msgs:
  478. LOGGER.info('\n'.join(msgs))
  479. if nf == 0:
  480. LOGGER.warning(f'{prefix}WARNING: No labels found in {path}. See {HELP_URL}')
  481. x['hash'] = get_hash(self.label_files + self.im_files)
  482. x['results'] = nf, nm, ne, nc, len(self.im_files)
  483. x['msgs'] = msgs # warnings
  484. x['version'] = self.cache_version # cache version
  485. try:
  486. np.save(path, x) # save cache for next time
  487. path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
  488. LOGGER.info(f'{prefix}New cache created: {path}')
  489. except Exception as e:
  490. LOGGER.warning(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}') # not writeable
  491. return x
  492. def __len__(self):
  493. return len(self.im_files)
  494. # def __iter__(self):
  495. # self.count = -1
  496. # print('ran dataset iter')
  497. # #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
  498. # return self
  499. def __getitem__(self, index):
  500. index = self.indices[index] # linear, shuffled, or image_weights
  501. hyp = self.hyp
  502. mosaic = self.mosaic and random.random() < hyp['mosaic']
  503. if mosaic:
  504. # Load mosaic
  505. img, labels = self.load_mosaic(index)
  506. shapes = None
  507. # MixUp augmentation
  508. if random.random() < hyp['mixup']:
  509. img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1)))
  510. else:
  511. # Load image
  512. img, (h0, w0), (h, w) = self.load_image(index)
  513. # Letterbox
  514. shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
  515. img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
  516. shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
  517. labels = self.labels[index].copy()
  518. if labels.size: # normalized xywh to pixel xyxy format
  519. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
  520. if self.augment:
  521. img, labels = random_perspective(img,
  522. labels,
  523. degrees=hyp['degrees'],
  524. translate=hyp['translate'],
  525. scale=hyp['scale'],
  526. shear=hyp['shear'],
  527. perspective=hyp['perspective'])
  528. nl = len(labels) # number of labels
  529. if nl:
  530. labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)
  531. if self.augment:
  532. # Albumentations
  533. img, labels = self.albumentations(img, labels)
  534. nl = len(labels) # update after albumentations
  535. # HSV color-space
  536. augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
  537. # Flip up-down
  538. if random.random() < hyp['flipud']:
  539. img = np.flipud(img)
  540. if nl:
  541. labels[:, 2] = 1 - labels[:, 2]
  542. # Flip left-right
  543. if random.random() < hyp['fliplr']:
  544. img = np.fliplr(img)
  545. if nl:
  546. labels[:, 1] = 1 - labels[:, 1]
  547. # Cutouts
  548. # labels = cutout(img, labels, p=0.5)
  549. # nl = len(labels) # update after cutout
  550. labels_out = torch.zeros((nl, 6))
  551. if nl:
  552. labels_out[:, 1:] = torch.from_numpy(labels)
  553. # Convert
  554. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  555. img = np.ascontiguousarray(img)
  556. return torch.from_numpy(img), labels_out, self.im_files[index], shapes
  557. def load_image(self, i):
  558. # Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
  559. im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i],
  560. if im is None: # not cached in RAM
  561. if fn.exists(): # load npy
  562. im = np.load(fn)
  563. else: # read image
  564. im = cv2.imread(f) # BGR
  565. assert im is not None, f'Image Not Found {f}'
  566. h0, w0 = im.shape[:2] # orig hw
  567. r = self.img_size / max(h0, w0) # ratio
  568. if r != 1: # if sizes are not equal
  569. interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
  570. im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=interp)
  571. return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
  572. else:
  573. return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
  574. def cache_images_to_disk(self, i):
  575. # Saves an image as an *.npy file for faster loading
  576. f = self.npy_files[i]
  577. if not f.exists():
  578. np.save(f.as_posix(), cv2.imread(self.im_files[i]))
  579. def load_mosaic(self, index):
  580. # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
  581. labels4, segments4 = [], []
  582. s = self.img_size
  583. yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border) # mosaic center x, y
  584. indices = [index] + random.choices(self.indices, k=3) # 3 additional image indices
  585. random.shuffle(indices)
  586. for i, index in enumerate(indices):
  587. # Load image
  588. img, _, (h, w) = self.load_image(index)
  589. # place img in img4
  590. if i == 0: # top left
  591. img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
  592. x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
  593. x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
  594. elif i == 1: # top right
  595. x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
  596. x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
  597. elif i == 2: # bottom left
  598. x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
  599. x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
  600. elif i == 3: # bottom right
  601. x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
  602. x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
  603. img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
  604. padw = x1a - x1b
  605. padh = y1a - y1b
  606. # Labels
  607. labels, segments = self.labels[index].copy(), self.segments[index].copy()
  608. if labels.size:
  609. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh) # normalized xywh to pixel xyxy format
  610. segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
  611. labels4.append(labels)
  612. segments4.extend(segments)
  613. # Concat/clip labels
  614. labels4 = np.concatenate(labels4, 0)
  615. for x in (labels4[:, 1:], *segments4):
  616. np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
  617. # img4, labels4 = replicate(img4, labels4) # replicate
  618. # Augment
  619. img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
  620. img4, labels4 = random_perspective(img4,
  621. labels4,
  622. segments4,
  623. degrees=self.hyp['degrees'],
  624. translate=self.hyp['translate'],
  625. scale=self.hyp['scale'],
  626. shear=self.hyp['shear'],
  627. perspective=self.hyp['perspective'],
  628. border=self.mosaic_border) # border to remove
  629. return img4, labels4
  630. def load_mosaic9(self, index):
  631. # YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic
  632. labels9, segments9 = [], []
  633. s = self.img_size
  634. indices = [index] + random.choices(self.indices, k=8) # 8 additional image indices
  635. random.shuffle(indices)
  636. hp, wp = -1, -1 # height, width previous
  637. for i, index in enumerate(indices):
  638. # Load image
  639. img, _, (h, w) = self.load_image(index)
  640. # place img in img9
  641. if i == 0: # center
  642. img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
  643. h0, w0 = h, w
  644. c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates
  645. elif i == 1: # top
  646. c = s, s - h, s + w, s
  647. elif i == 2: # top right
  648. c = s + wp, s - h, s + wp + w, s
  649. elif i == 3: # right
  650. c = s + w0, s, s + w0 + w, s + h
  651. elif i == 4: # bottom right
  652. c = s + w0, s + hp, s + w0 + w, s + hp + h
  653. elif i == 5: # bottom
  654. c = s + w0 - w, s + h0, s + w0, s + h0 + h
  655. elif i == 6: # bottom left
  656. c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
  657. elif i == 7: # left
  658. c = s - w, s + h0 - h, s, s + h0
  659. elif i == 8: # top left
  660. c = s - w, s + h0 - hp - h, s, s + h0 - hp
  661. padx, pady = c[:2]
  662. x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
  663. # Labels
  664. labels, segments = self.labels[index].copy(), self.segments[index].copy()
  665. if labels.size:
  666. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady) # normalized xywh to pixel xyxy format
  667. segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
  668. labels9.append(labels)
  669. segments9.extend(segments)
  670. # Image
  671. img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:] # img9[ymin:ymax, xmin:xmax]
  672. hp, wp = h, w # height, width previous
  673. # Offset
  674. yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border) # mosaic center x, y
  675. img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]
  676. # Concat/clip labels
  677. labels9 = np.concatenate(labels9, 0)
  678. labels9[:, [1, 3]] -= xc
  679. labels9[:, [2, 4]] -= yc
  680. c = np.array([xc, yc]) # centers
  681. segments9 = [x - c for x in segments9]
  682. for x in (labels9[:, 1:], *segments9):
  683. np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
  684. # img9, labels9 = replicate(img9, labels9) # replicate
  685. # Augment
  686. img9, labels9 = random_perspective(img9,
  687. labels9,
  688. segments9,
  689. degrees=self.hyp['degrees'],
  690. translate=self.hyp['translate'],
  691. scale=self.hyp['scale'],
  692. shear=self.hyp['shear'],
  693. perspective=self.hyp['perspective'],
  694. border=self.mosaic_border) # border to remove
  695. return img9, labels9
  696. @staticmethod
  697. def collate_fn(batch):
  698. im, label, path, shapes = zip(*batch) # transposed
  699. for i, lb in enumerate(label):
  700. lb[:, 0] = i # add target image index for build_targets()
  701. return torch.stack(im, 0), torch.cat(label, 0), path, shapes
  702. @staticmethod
  703. def collate_fn4(batch):
  704. img, label, path, shapes = zip(*batch) # transposed
  705. n = len(shapes) // 4
  706. im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
  707. ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
  708. wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
  709. s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]]) # scale
  710. for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
  711. i *= 4
  712. if random.random() < 0.5:
  713. im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
  714. align_corners=False)[0].type(img[i].type())
  715. lb = label[i]
  716. else:
  717. im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
  718. lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
  719. im4.append(im)
  720. label4.append(lb)
  721. for i, lb in enumerate(label4):
  722. lb[:, 0] = i # add target image index for build_targets()
  723. return torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4
  724. # Ancillary functions --------------------------------------------------------------------------------------------------
  725. def create_folder(path='./new'):
  726. # Create folder
  727. if os.path.exists(path):
  728. shutil.rmtree(path) # delete output folder
  729. os.makedirs(path) # make new output folder
  730. def flatten_recursive(path=DATASETS_DIR / 'coco128'):
  731. # Flatten a recursive directory by bringing all files to top level
  732. new_path = Path(str(path) + '_flat')
  733. create_folder(new_path)
  734. for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
  735. shutil.copyfile(file, new_path / Path(file).name)
  736. def extract_boxes(path=DATASETS_DIR / 'coco128'): # from utils.dataloaders import *; extract_boxes()
  737. # Convert detection dataset into classification dataset, with one directory per class
  738. path = Path(path) # images dir
  739. shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None # remove existing
  740. files = list(path.rglob('*.*'))
  741. n = len(files) # number of files
  742. for im_file in tqdm(files, total=n):
  743. if im_file.suffix[1:] in IMG_FORMATS:
  744. # image
  745. im = cv2.imread(str(im_file))[..., ::-1] # BGR to RGB
  746. h, w = im.shape[:2]
  747. # labels
  748. lb_file = Path(img2label_paths([str(im_file)])[0])
  749. if Path(lb_file).exists():
  750. with open(lb_file) as f:
  751. lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32) # labels
  752. for j, x in enumerate(lb):
  753. c = int(x[0]) # class
  754. f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg' # new filename
  755. if not f.parent.is_dir():
  756. f.parent.mkdir(parents=True)
  757. b = x[1:] * [w, h, w, h] # box
  758. # b[2:] = b[2:].max() # rectangle to square
  759. b[2:] = b[2:] * 1.2 + 3 # pad
  760. b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(np.int)
  761. b[[0, 2]] = np.clip(b[[0, 2]], 0, w) # clip boxes outside of image
  762. b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
  763. assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'
  764. def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
  765. """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
  766. Usage: from utils.dataloaders import *; autosplit()
  767. Arguments
  768. path: Path to images directory
  769. weights: Train, val, test weights (list, tuple)
  770. annotated_only: Only use images with an annotated txt file
  771. """
  772. path = Path(path) # images dir
  773. files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS) # image files only
  774. n = len(files) # number of files
  775. random.seed(0) # for reproducibility
  776. indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
  777. txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
  778. [(path.parent / x).unlink(missing_ok=True) for x in txt] # remove existing
  779. print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
  780. for i, img in tqdm(zip(indices, files), total=n):
  781. if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
  782. with open(path.parent / txt[i], 'a') as f:
  783. f.write('./' + img.relative_to(path.parent).as_posix() + '\n') # add image to txt file
  784. def verify_image_label(args):
  785. # Verify one image-label pair
  786. im_file, lb_file, prefix = args
  787. nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, '', [] # number (missing, found, empty, corrupt), message, segments
  788. try:
  789. # verify images
  790. im = Image.open(im_file)
  791. im.verify() # PIL verify
  792. shape = exif_size(im) # image size
  793. assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
  794. assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
  795. if im.format.lower() in ('jpg', 'jpeg'):
  796. with open(im_file, 'rb') as f:
  797. f.seek(-2, 2)
  798. if f.read() != b'\xff\xd9': # corrupt JPEG
  799. ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
  800. msg = f'{prefix}WARNING: {im_file}: corrupt JPEG restored and saved'
  801. # verify labels
  802. if os.path.isfile(lb_file):
  803. nf = 1 # label found
  804. with open(lb_file) as f:
  805. lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
  806. if any(len(x) > 6 for x in lb): # is segment
  807. classes = np.array([x[0] for x in lb], dtype=np.float32)
  808. segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)
  809. lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
  810. lb = np.array(lb, dtype=np.float32)
  811. nl = len(lb)
  812. if nl:
  813. assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
  814. assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
  815. assert (lb[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
  816. _, i = np.unique(lb, axis=0, return_index=True)
  817. if len(i) < nl: # duplicate row check
  818. lb = lb[i] # remove duplicates
  819. if segments:
  820. segments = segments[i]
  821. msg = f'{prefix}WARNING: {im_file}: {nl - len(i)} duplicate labels removed'
  822. else:
  823. ne = 1 # label empty
  824. lb = np.zeros((0, 5), dtype=np.float32)
  825. else:
  826. nm = 1 # label missing
  827. lb = np.zeros((0, 5), dtype=np.float32)
  828. return im_file, lb, shape, segments, nm, nf, ne, nc, msg
  829. except Exception as e:
  830. nc = 1
  831. msg = f'{prefix}WARNING: {im_file}: ignoring corrupt image/label: {e}'
  832. return [None, None, None, None, nm, nf, ne, nc, msg]
  833. def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False, profile=False, hub=False):
  834. """ Return dataset statistics dictionary with images and instances counts per split per class
  835. To run in parent directory: export PYTHONPATH="$PWD/yolov5"
  836. Usage1: from utils.dataloaders import *; dataset_stats('coco128.yaml', autodownload=True)
  837. Usage2: from utils.dataloaders import *; dataset_stats('path/to/coco128_with_yaml.zip')
  838. Arguments
  839. path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
  840. autodownload: Attempt to download dataset if not found locally
  841. verbose: Print stats dictionary
  842. """
  843. def _round_labels(labels):
  844. # Update labels to integer class and 6 decimal place floats
  845. return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]
  846. def _find_yaml(dir):
  847. # Return data.yaml file
  848. files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive
  849. assert files, f'No *.yaml file found in {dir}'
  850. if len(files) > 1:
  851. files = [f for f in files if f.stem == dir.stem] # prefer *.yaml files that match dir name
  852. assert files, f'Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed'
  853. assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}'
  854. return files[0]
  855. def _unzip(path):
  856. # Unzip data.zip
  857. if str(path).endswith('.zip'): # path is data.zip
  858. assert Path(path).is_file(), f'Error unzipping {path}, file not found'
  859. ZipFile(path).extractall(path=path.parent) # unzip
  860. dir = path.with_suffix('') # dataset directory == zip name
  861. assert dir.is_dir(), f'Error unzipping {path}, {dir} not found. path/to/abc.zip MUST unzip to path/to/abc/'
  862. return True, str(dir), _find_yaml(dir) # zipped, data_dir, yaml_path
  863. else: # path is data.yaml
  864. return False, None, path
  865. def _hub_ops(f, max_dim=1920):
  866. # HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
  867. f_new = im_dir / Path(f).name # dataset-hub image filename
  868. try: # use PIL
  869. im = Image.open(f)
  870. r = max_dim / max(im.height, im.width) # ratio
  871. if r < 1.0: # image too large
  872. im = im.resize((int(im.width * r), int(im.height * r)))
  873. im.save(f_new, 'JPEG', quality=75, optimize=True) # save
  874. except Exception as e: # use OpenCV
  875. print(f'WARNING: HUB ops PIL failure {f}: {e}')
  876. im = cv2.imread(f)
  877. im_height, im_width = im.shape[:2]
  878. r = max_dim / max(im_height, im_width) # ratio
  879. if r < 1.0: # image too large
  880. im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
  881. cv2.imwrite(str(f_new), im)
  882. zipped, data_dir, yaml_path = _unzip(Path(path))
  883. try:
  884. with open(check_yaml(yaml_path), errors='ignore') as f:
  885. data = yaml.safe_load(f) # data dict
  886. if zipped:
  887. data['path'] = data_dir # TODO: should this be dir.resolve()?`
  888. except Exception:
  889. raise Exception("error/HUB/dataset_stats/yaml_load")
  890. check_dataset(data, autodownload) # download dataset if missing
  891. hub_dir = Path(data['path'] + ('-hub' if hub else ''))
  892. stats = {'nc': data['nc'], 'names': data['names']} # statistics dictionary
  893. for split in 'train', 'val', 'test':
  894. if data.get(split) is None:
  895. stats[split] = None # i.e. no test set
  896. continue
  897. x = []
  898. dataset = LoadImagesAndLabels(data[split]) # load dataset
  899. for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
  900. x.append(np.bincount(label[:, 0].astype(int), minlength=data['nc']))
  901. x = np.array(x) # shape(128x80)
  902. stats[split] = {
  903. 'instance_stats': {
  904. 'total': int(x.sum()),
  905. 'per_class': x.sum(0).tolist()},
  906. 'image_stats': {
  907. 'total': dataset.n,
  908. 'unlabelled': int(np.all(x == 0, 1).sum()),
  909. 'per_class': (x > 0).sum(0).tolist()},
  910. 'labels': [{
  911. str(Path(k).name): _round_labels(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}
  912. if hub:
  913. im_dir = hub_dir / 'images'
  914. im_dir.mkdir(parents=True, exist_ok=True)
  915. for _ in tqdm(ThreadPool(NUM_THREADS).imap(_hub_ops, dataset.im_files), total=dataset.n, desc='HUB Ops'):
  916. pass
  917. # Profile
  918. stats_path = hub_dir / 'stats.json'
  919. if profile:
  920. for _ in range(1):
  921. file = stats_path.with_suffix('.npy')
  922. t1 = time.time()
  923. np.save(file, stats)
  924. t2 = time.time()
  925. x = np.load(file, allow_pickle=True)
  926. print(f'stats.npy times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')
  927. file = stats_path.with_suffix('.json')
  928. t1 = time.time()
  929. with open(file, 'w') as f:
  930. json.dump(stats, f) # save stats *.json
  931. t2 = time.time()
  932. with open(file) as f:
  933. x = json.load(f) # load hyps dict
  934. print(f'stats.json times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')
  935. # Save, print and return
  936. if hub:
  937. print(f'Saving {stats_path.resolve()}...')
  938. with open(stats_path, 'w') as f:
  939. json.dump(stats, f) # save stats.json
  940. if verbose:
  941. print(json.dumps(stats, indent=2, sort_keys=False))
  942. return stats