You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1079 lines
46KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Dataloaders and dataset utils
  4. """
  5. import glob
  6. import hashlib
  7. import json
  8. import math
  9. import os
  10. import random
  11. import shutil
  12. import time
  13. from itertools import repeat
  14. from multiprocessing.pool import Pool, ThreadPool
  15. from pathlib import Path
  16. from threading import Thread
  17. from urllib.parse import urlparse
  18. from zipfile import ZipFile
  19. import numpy as np
  20. import torch
  21. import torch.nn.functional as F
  22. import yaml
  23. from PIL import ExifTags, Image, ImageOps
  24. from torch.utils.data import DataLoader, Dataset, dataloader, distributed
  25. from tqdm.auto import tqdm
  26. from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
  27. from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
  28. cv2, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
  29. from utils.torch_utils import torch_distributed_zero_first
  30. # Parameters
  31. HELP_URL = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
  32. IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp' # include image suffixes
  33. VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
  34. BAR_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}' # tqdm bar format
  35. LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
  36. # Get orientation exif tag
  37. for orientation in ExifTags.TAGS.keys():
  38. if ExifTags.TAGS[orientation] == 'Orientation':
  39. break
  40. def get_hash(paths):
  41. # Returns a single hash value of a list of paths (files or dirs)
  42. size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
  43. h = hashlib.md5(str(size).encode()) # hash sizes
  44. h.update(''.join(paths).encode()) # hash paths
  45. return h.hexdigest() # return hash
  46. def exif_size(img):
  47. # Returns exif-corrected PIL size
  48. s = img.size # (width, height)
  49. try:
  50. rotation = dict(img._getexif().items())[orientation]
  51. if rotation == 6: # rotation 270
  52. s = (s[1], s[0])
  53. elif rotation == 8: # rotation 90
  54. s = (s[1], s[0])
  55. except Exception:
  56. pass
  57. return s
  58. def exif_transpose(image):
  59. """
  60. Transpose a PIL image accordingly if it has an EXIF Orientation tag.
  61. Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()
  62. :param image: The image to transpose.
  63. :return: An image.
  64. """
  65. exif = image.getexif()
  66. orientation = exif.get(0x0112, 1) # default 1
  67. if orientation > 1:
  68. method = {
  69. 2: Image.FLIP_LEFT_RIGHT,
  70. 3: Image.ROTATE_180,
  71. 4: Image.FLIP_TOP_BOTTOM,
  72. 5: Image.TRANSPOSE,
  73. 6: Image.ROTATE_270,
  74. 7: Image.TRANSVERSE,
  75. 8: Image.ROTATE_90,}.get(orientation)
  76. if method is not None:
  77. image = image.transpose(method)
  78. del exif[0x0112]
  79. image.info["exif"] = exif.tobytes()
  80. return image
  81. def create_dataloader(path,
  82. imgsz,
  83. batch_size,
  84. stride,
  85. single_cls=False,
  86. hyp=None,
  87. augment=False,
  88. cache=False,
  89. pad=0.0,
  90. rect=False,
  91. rank=-1,
  92. workers=8,
  93. image_weights=False,
  94. quad=False,
  95. prefix='',
  96. shuffle=False):
  97. if rect and shuffle:
  98. LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
  99. shuffle = False
  100. with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
  101. dataset = LoadImagesAndLabels(
  102. path,
  103. imgsz,
  104. batch_size,
  105. augment=augment, # augmentation
  106. hyp=hyp, # hyperparameters
  107. rect=rect, # rectangular batches
  108. cache_images=cache,
  109. single_cls=single_cls,
  110. stride=int(stride),
  111. pad=pad,
  112. image_weights=image_weights,
  113. prefix=prefix)
  114. batch_size = min(batch_size, len(dataset))
  115. nd = torch.cuda.device_count() # number of CUDA devices
  116. nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
  117. sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
  118. loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
  119. return loader(dataset,
  120. batch_size=batch_size,
  121. shuffle=shuffle and sampler is None,
  122. num_workers=nw,
  123. sampler=sampler,
  124. pin_memory=True,
  125. collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
  126. class InfiniteDataLoader(dataloader.DataLoader):
  127. """ Dataloader that reuses workers
  128. Uses same syntax as vanilla DataLoader
  129. """
  130. def __init__(self, *args, **kwargs):
  131. super().__init__(*args, **kwargs)
  132. object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
  133. self.iterator = super().__iter__()
  134. def __len__(self):
  135. return len(self.batch_sampler.sampler)
  136. def __iter__(self):
  137. for i in range(len(self)):
  138. yield next(self.iterator)
  139. class _RepeatSampler:
  140. """ Sampler that repeats forever
  141. Args:
  142. sampler (Sampler)
  143. """
  144. def __init__(self, sampler):
  145. self.sampler = sampler
  146. def __iter__(self):
  147. while True:
  148. yield from iter(self.sampler)
  149. class LoadImages:
  150. # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
  151. def __init__(self, path, img_size=640, stride=32, auto=True):
  152. p = str(Path(path).resolve()) # os-agnostic absolute path
  153. if '*' in p:
  154. files = sorted(glob.glob(p, recursive=True)) # glob
  155. elif os.path.isdir(p):
  156. files = sorted(glob.glob(os.path.join(p, '*.*'))) # dir
  157. elif os.path.isfile(p):
  158. files = [p] # files
  159. else:
  160. raise Exception(f'ERROR: {p} does not exist')
  161. images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
  162. videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
  163. ni, nv = len(images), len(videos)
  164. self.img_size = img_size
  165. self.stride = stride
  166. self.files = images + videos
  167. self.nf = ni + nv # number of files
  168. self.video_flag = [False] * ni + [True] * nv
  169. self.mode = 'image'
  170. self.auto = auto
  171. if any(videos):
  172. self.new_video(videos[0]) # new video
  173. else:
  174. self.cap = None
  175. assert self.nf > 0, f'No images or videos found in {p}. ' \
  176. f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
  177. def __iter__(self):
  178. self.count = 0
  179. return self
  180. def __next__(self):
  181. if self.count == self.nf:
  182. raise StopIteration
  183. path = self.files[self.count]
  184. if self.video_flag[self.count]:
  185. # Read video
  186. self.mode = 'video'
  187. ret_val, img0 = self.cap.read()
  188. while not ret_val:
  189. self.count += 1
  190. self.cap.release()
  191. if self.count == self.nf: # last video
  192. raise StopIteration
  193. else:
  194. path = self.files[self.count]
  195. self.new_video(path)
  196. ret_val, img0 = self.cap.read()
  197. self.frame += 1
  198. s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
  199. else:
  200. # Read image
  201. self.count += 1
  202. img0 = cv2.imread(path) # BGR
  203. assert img0 is not None, f'Image Not Found {path}'
  204. s = f'image {self.count}/{self.nf} {path}: '
  205. # Padded resize
  206. img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]
  207. # Convert
  208. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  209. img = np.ascontiguousarray(img)
  210. return path, img, img0, self.cap, s
  211. def new_video(self, path):
  212. self.frame = 0
  213. self.cap = cv2.VideoCapture(path)
  214. self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
  215. def __len__(self):
  216. return self.nf # number of files
  217. class LoadWebcam: # for inference
  218. # YOLOv5 local webcam dataloader, i.e. `python detect.py --source 0`
  219. def __init__(self, pipe='0', img_size=640, stride=32):
  220. self.img_size = img_size
  221. self.stride = stride
  222. self.pipe = eval(pipe) if pipe.isnumeric() else pipe
  223. self.cap = cv2.VideoCapture(self.pipe) # video capture object
  224. self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3) # set buffer size
  225. def __iter__(self):
  226. self.count = -1
  227. return self
  228. def __next__(self):
  229. self.count += 1
  230. if cv2.waitKey(1) == ord('q'): # q to quit
  231. self.cap.release()
  232. cv2.destroyAllWindows()
  233. raise StopIteration
  234. # Read frame
  235. ret_val, img0 = self.cap.read()
  236. img0 = cv2.flip(img0, 1) # flip left-right
  237. # Print
  238. assert ret_val, f'Camera Error {self.pipe}'
  239. img_path = 'webcam.jpg'
  240. s = f'webcam {self.count}: '
  241. # Padded resize
  242. img = letterbox(img0, self.img_size, stride=self.stride)[0]
  243. # Convert
  244. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  245. img = np.ascontiguousarray(img)
  246. return img_path, img, img0, None, s
  247. def __len__(self):
  248. return 0
  249. class LoadStreams:
  250. # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
  251. def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
  252. self.mode = 'stream'
  253. self.img_size = img_size
  254. self.stride = stride
  255. if os.path.isfile(sources):
  256. with open(sources) as f:
  257. sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
  258. else:
  259. sources = [sources]
  260. n = len(sources)
  261. self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
  262. self.sources = [clean_str(x) for x in sources] # clean source names for later
  263. self.auto = auto
  264. for i, s in enumerate(sources): # index, source
  265. # Start thread to read frames from video stream
  266. st = f'{i + 1}/{n}: {s}... '
  267. if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video
  268. check_requirements(('pafy', 'youtube_dl==2020.12.2'))
  269. import pafy
  270. s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
  271. s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
  272. cap = cv2.VideoCapture(s)
  273. assert cap.isOpened(), f'{st}Failed to open {s}'
  274. w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  275. h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  276. fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
  277. self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
  278. self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
  279. _, self.imgs[i] = cap.read() # guarantee first frame
  280. self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
  281. LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
  282. self.threads[i].start()
  283. LOGGER.info('') # newline
  284. # check for common shapes
  285. s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs])
  286. self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
  287. if not self.rect:
  288. LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')
  289. def update(self, i, cap, stream):
  290. # Read stream `i` frames in daemon thread
  291. n, f, read = 0, self.frames[i], 1 # frame number, frame array, inference every 'read' frame
  292. while cap.isOpened() and n < f:
  293. n += 1
  294. # _, self.imgs[index] = cap.read()
  295. cap.grab()
  296. if n % read == 0:
  297. success, im = cap.retrieve()
  298. if success:
  299. self.imgs[i] = im
  300. else:
  301. LOGGER.warning('WARNING: Video stream unresponsive, please check your IP camera connection.')
  302. self.imgs[i] = np.zeros_like(self.imgs[i])
  303. cap.open(stream) # re-open stream if signal was lost
  304. time.sleep(1 / self.fps[i]) # wait time
  305. def __iter__(self):
  306. self.count = -1
  307. return self
  308. def __next__(self):
  309. self.count += 1
  310. if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
  311. cv2.destroyAllWindows()
  312. raise StopIteration
  313. # Letterbox
  314. img0 = self.imgs.copy()
  315. img = [letterbox(x, self.img_size, stride=self.stride, auto=self.rect and self.auto)[0] for x in img0]
  316. # Stack
  317. img = np.stack(img, 0)
  318. # Convert
  319. img = img[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
  320. img = np.ascontiguousarray(img)
  321. return self.sources, img, img0, None, ''
  322. def __len__(self):
  323. return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
  324. def img2label_paths(img_paths):
  325. # Define label paths as a function of image paths
  326. sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep # /images/, /labels/ substrings
  327. return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
  328. class LoadImagesAndLabels(Dataset):
  329. # YOLOv5 train_loader/val_loader, loads images and labels for training and validation
  330. cache_version = 0.6 # dataset labels *.cache version
  331. def __init__(self,
  332. path,
  333. img_size=640,
  334. batch_size=16,
  335. augment=False,
  336. hyp=None,
  337. rect=False,
  338. image_weights=False,
  339. cache_images=False,
  340. single_cls=False,
  341. stride=32,
  342. pad=0.0,
  343. prefix=''):
  344. self.img_size = img_size
  345. self.augment = augment
  346. self.hyp = hyp
  347. self.image_weights = image_weights
  348. self.rect = False if image_weights else rect
  349. self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
  350. self.mosaic_border = [-img_size // 2, -img_size // 2]
  351. self.stride = stride
  352. self.path = path
  353. self.albumentations = Albumentations() if augment else None
  354. try:
  355. f = [] # image files
  356. for p in path if isinstance(path, list) else [path]:
  357. p = Path(p) # os-agnostic
  358. if p.is_dir(): # dir
  359. f += glob.glob(str(p / '**' / '*.*'), recursive=True)
  360. # f = list(p.rglob('*.*')) # pathlib
  361. elif p.is_file(): # file
  362. with open(p) as t:
  363. t = t.read().strip().splitlines()
  364. parent = str(p.parent) + os.sep
  365. f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
  366. # f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
  367. else:
  368. raise Exception(f'{prefix}{p} does not exist')
  369. self.im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
  370. # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
  371. assert self.im_files, f'{prefix}No images found'
  372. except Exception as e:
  373. raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}')
  374. # Check cache
  375. self.label_files = img2label_paths(self.im_files) # labels
  376. cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
  377. try:
  378. cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
  379. assert cache['version'] == self.cache_version # same version
  380. assert cache['hash'] == get_hash(self.label_files + self.im_files) # same hash
  381. except Exception:
  382. cache, exists = self.cache_labels(cache_path, prefix), False # cache
  383. # Display cache
  384. nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
  385. if exists and LOCAL_RANK in (-1, 0):
  386. d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"
  387. tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=BAR_FORMAT) # display cache results
  388. if cache['msgs']:
  389. LOGGER.info('\n'.join(cache['msgs'])) # display warnings
  390. assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {HELP_URL}'
  391. # Read cache
  392. [cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
  393. labels, shapes, self.segments = zip(*cache.values())
  394. self.labels = list(labels)
  395. self.shapes = np.array(shapes, dtype=np.float64)
  396. self.im_files = list(cache.keys()) # update
  397. self.label_files = img2label_paths(cache.keys()) # update
  398. n = len(shapes) # number of images
  399. bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
  400. nb = bi[-1] + 1 # number of batches
  401. self.batch = bi # batch index of image
  402. self.n = n
  403. self.indices = range(n)
  404. # Update labels
  405. include_class = [] # filter labels to include only these classes (optional)
  406. include_class_array = np.array(include_class).reshape(1, -1)
  407. for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
  408. if include_class:
  409. j = (label[:, 0:1] == include_class_array).any(1)
  410. self.labels[i] = label[j]
  411. if segment:
  412. self.segments[i] = segment[j]
  413. if single_cls: # single-class training, merge all classes into 0
  414. self.labels[i][:, 0] = 0
  415. if segment:
  416. self.segments[i][:, 0] = 0
  417. # Rectangular Training
  418. if self.rect:
  419. # Sort by aspect ratio
  420. s = self.shapes # wh
  421. ar = s[:, 1] / s[:, 0] # aspect ratio
  422. irect = ar.argsort()
  423. self.im_files = [self.im_files[i] for i in irect]
  424. self.label_files = [self.label_files[i] for i in irect]
  425. self.labels = [self.labels[i] for i in irect]
  426. self.shapes = s[irect] # wh
  427. ar = ar[irect]
  428. # Set training image shapes
  429. shapes = [[1, 1]] * nb
  430. for i in range(nb):
  431. ari = ar[bi == i]
  432. mini, maxi = ari.min(), ari.max()
  433. if maxi < 1:
  434. shapes[i] = [maxi, 1]
  435. elif mini > 1:
  436. shapes[i] = [1, 1 / mini]
  437. self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
  438. # Cache images into RAM/disk for faster training (WARNING: large datasets may exceed system resources)
  439. self.ims = [None] * n
  440. self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
  441. if cache_images:
  442. gb = 0 # Gigabytes of cached images
  443. self.im_hw0, self.im_hw = [None] * n, [None] * n
  444. fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
  445. results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
  446. pbar = tqdm(enumerate(results), total=n, bar_format=BAR_FORMAT)
  447. for i, x in pbar:
  448. if cache_images == 'disk':
  449. gb += self.npy_files[i].stat().st_size
  450. else: # 'ram'
  451. self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
  452. gb += self.ims[i].nbytes
  453. pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB {cache_images})'
  454. pbar.close()
  455. def cache_labels(self, path=Path('./labels.cache'), prefix=''):
  456. # Cache dataset labels, check images and read shapes
  457. x = {} # dict
  458. nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
  459. desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
  460. with Pool(NUM_THREADS) as pool:
  461. pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
  462. desc=desc,
  463. total=len(self.im_files),
  464. bar_format=BAR_FORMAT)
  465. for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
  466. nm += nm_f
  467. nf += nf_f
  468. ne += ne_f
  469. nc += nc_f
  470. if im_file:
  471. x[im_file] = [lb, shape, segments]
  472. if msg:
  473. msgs.append(msg)
  474. pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt"
  475. pbar.close()
  476. if msgs:
  477. LOGGER.info('\n'.join(msgs))
  478. if nf == 0:
  479. LOGGER.warning(f'{prefix}WARNING: No labels found in {path}. See {HELP_URL}')
  480. x['hash'] = get_hash(self.label_files + self.im_files)
  481. x['results'] = nf, nm, ne, nc, len(self.im_files)
  482. x['msgs'] = msgs # warnings
  483. x['version'] = self.cache_version # cache version
  484. try:
  485. np.save(path, x) # save cache for next time
  486. path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
  487. LOGGER.info(f'{prefix}New cache created: {path}')
  488. except Exception as e:
  489. LOGGER.warning(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}') # not writeable
  490. return x
  491. def __len__(self):
  492. return len(self.im_files)
  493. # def __iter__(self):
  494. # self.count = -1
  495. # print('ran dataset iter')
  496. # #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
  497. # return self
  498. def __getitem__(self, index):
  499. index = self.indices[index] # linear, shuffled, or image_weights
  500. hyp = self.hyp
  501. mosaic = self.mosaic and random.random() < hyp['mosaic']
  502. if mosaic:
  503. # Load mosaic
  504. img, labels = self.load_mosaic(index)
  505. shapes = None
  506. # MixUp augmentation
  507. if random.random() < hyp['mixup']:
  508. img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1)))
  509. else:
  510. # Load image
  511. img, (h0, w0), (h, w) = self.load_image(index)
  512. # Letterbox
  513. shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
  514. img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
  515. shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
  516. labels = self.labels[index].copy()
  517. if labels.size: # normalized xywh to pixel xyxy format
  518. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
  519. if self.augment:
  520. img, labels = random_perspective(img,
  521. labels,
  522. degrees=hyp['degrees'],
  523. translate=hyp['translate'],
  524. scale=hyp['scale'],
  525. shear=hyp['shear'],
  526. perspective=hyp['perspective'])
  527. nl = len(labels) # number of labels
  528. if nl:
  529. labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)
  530. if self.augment:
  531. # Albumentations
  532. img, labels = self.albumentations(img, labels)
  533. nl = len(labels) # update after albumentations
  534. # HSV color-space
  535. augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
  536. # Flip up-down
  537. if random.random() < hyp['flipud']:
  538. img = np.flipud(img)
  539. if nl:
  540. labels[:, 2] = 1 - labels[:, 2]
  541. # Flip left-right
  542. if random.random() < hyp['fliplr']:
  543. img = np.fliplr(img)
  544. if nl:
  545. labels[:, 1] = 1 - labels[:, 1]
  546. # Cutouts
  547. # labels = cutout(img, labels, p=0.5)
  548. # nl = len(labels) # update after cutout
  549. labels_out = torch.zeros((nl, 6))
  550. if nl:
  551. labels_out[:, 1:] = torch.from_numpy(labels)
  552. # Convert
  553. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  554. img = np.ascontiguousarray(img)
  555. return torch.from_numpy(img), labels_out, self.im_files[index], shapes
  556. def load_image(self, i):
  557. # Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
  558. im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i],
  559. if im is None: # not cached in RAM
  560. if fn.exists(): # load npy
  561. im = np.load(fn)
  562. else: # read image
  563. im = cv2.imread(f) # BGR
  564. assert im is not None, f'Image Not Found {f}'
  565. h0, w0 = im.shape[:2] # orig hw
  566. r = self.img_size / max(h0, w0) # ratio
  567. if r != 1: # if sizes are not equal
  568. im = cv2.resize(im, (int(w0 * r), int(h0 * r)),
  569. interpolation=cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA)
  570. return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
  571. else:
  572. return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
  573. def cache_images_to_disk(self, i):
  574. # Saves an image as an *.npy file for faster loading
  575. f = self.npy_files[i]
  576. if not f.exists():
  577. np.save(f.as_posix(), cv2.imread(self.im_files[i]))
  578. def load_mosaic(self, index):
  579. # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
  580. labels4, segments4 = [], []
  581. s = self.img_size
  582. yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border) # mosaic center x, y
  583. indices = [index] + random.choices(self.indices, k=3) # 3 additional image indices
  584. random.shuffle(indices)
  585. for i, index in enumerate(indices):
  586. # Load image
  587. img, _, (h, w) = self.load_image(index)
  588. # place img in img4
  589. if i == 0: # top left
  590. img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
  591. x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
  592. x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
  593. elif i == 1: # top right
  594. x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
  595. x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
  596. elif i == 2: # bottom left
  597. x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
  598. x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
  599. elif i == 3: # bottom right
  600. x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
  601. x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
  602. img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
  603. padw = x1a - x1b
  604. padh = y1a - y1b
  605. # Labels
  606. labels, segments = self.labels[index].copy(), self.segments[index].copy()
  607. if labels.size:
  608. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh) # normalized xywh to pixel xyxy format
  609. segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
  610. labels4.append(labels)
  611. segments4.extend(segments)
  612. # Concat/clip labels
  613. labels4 = np.concatenate(labels4, 0)
  614. for x in (labels4[:, 1:], *segments4):
  615. np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
  616. # img4, labels4 = replicate(img4, labels4) # replicate
  617. # Augment
  618. img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
  619. img4, labels4 = random_perspective(img4,
  620. labels4,
  621. segments4,
  622. degrees=self.hyp['degrees'],
  623. translate=self.hyp['translate'],
  624. scale=self.hyp['scale'],
  625. shear=self.hyp['shear'],
  626. perspective=self.hyp['perspective'],
  627. border=self.mosaic_border) # border to remove
  628. return img4, labels4
  629. def load_mosaic9(self, index):
  630. # YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic
  631. labels9, segments9 = [], []
  632. s = self.img_size
  633. indices = [index] + random.choices(self.indices, k=8) # 8 additional image indices
  634. random.shuffle(indices)
  635. hp, wp = -1, -1 # height, width previous
  636. for i, index in enumerate(indices):
  637. # Load image
  638. img, _, (h, w) = self.load_image(index)
  639. # place img in img9
  640. if i == 0: # center
  641. img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
  642. h0, w0 = h, w
  643. c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates
  644. elif i == 1: # top
  645. c = s, s - h, s + w, s
  646. elif i == 2: # top right
  647. c = s + wp, s - h, s + wp + w, s
  648. elif i == 3: # right
  649. c = s + w0, s, s + w0 + w, s + h
  650. elif i == 4: # bottom right
  651. c = s + w0, s + hp, s + w0 + w, s + hp + h
  652. elif i == 5: # bottom
  653. c = s + w0 - w, s + h0, s + w0, s + h0 + h
  654. elif i == 6: # bottom left
  655. c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
  656. elif i == 7: # left
  657. c = s - w, s + h0 - h, s, s + h0
  658. elif i == 8: # top left
  659. c = s - w, s + h0 - hp - h, s, s + h0 - hp
  660. padx, pady = c[:2]
  661. x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
  662. # Labels
  663. labels, segments = self.labels[index].copy(), self.segments[index].copy()
  664. if labels.size:
  665. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady) # normalized xywh to pixel xyxy format
  666. segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
  667. labels9.append(labels)
  668. segments9.extend(segments)
  669. # Image
  670. img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:] # img9[ymin:ymax, xmin:xmax]
  671. hp, wp = h, w # height, width previous
  672. # Offset
  673. yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border) # mosaic center x, y
  674. img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]
  675. # Concat/clip labels
  676. labels9 = np.concatenate(labels9, 0)
  677. labels9[:, [1, 3]] -= xc
  678. labels9[:, [2, 4]] -= yc
  679. c = np.array([xc, yc]) # centers
  680. segments9 = [x - c for x in segments9]
  681. for x in (labels9[:, 1:], *segments9):
  682. np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
  683. # img9, labels9 = replicate(img9, labels9) # replicate
  684. # Augment
  685. img9, labels9 = random_perspective(img9,
  686. labels9,
  687. segments9,
  688. degrees=self.hyp['degrees'],
  689. translate=self.hyp['translate'],
  690. scale=self.hyp['scale'],
  691. shear=self.hyp['shear'],
  692. perspective=self.hyp['perspective'],
  693. border=self.mosaic_border) # border to remove
  694. return img9, labels9
  695. @staticmethod
  696. def collate_fn(batch):
  697. im, label, path, shapes = zip(*batch) # transposed
  698. for i, lb in enumerate(label):
  699. lb[:, 0] = i # add target image index for build_targets()
  700. return torch.stack(im, 0), torch.cat(label, 0), path, shapes
  701. @staticmethod
  702. def collate_fn4(batch):
  703. img, label, path, shapes = zip(*batch) # transposed
  704. n = len(shapes) // 4
  705. im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
  706. ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
  707. wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
  708. s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]]) # scale
  709. for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
  710. i *= 4
  711. if random.random() < 0.5:
  712. im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
  713. align_corners=False)[0].type(img[i].type())
  714. lb = label[i]
  715. else:
  716. im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
  717. lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
  718. im4.append(im)
  719. label4.append(lb)
  720. for i, lb in enumerate(label4):
  721. lb[:, 0] = i # add target image index for build_targets()
  722. return torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4
  723. # Ancillary functions --------------------------------------------------------------------------------------------------
  724. def create_folder(path='./new'):
  725. # Create folder
  726. if os.path.exists(path):
  727. shutil.rmtree(path) # delete output folder
  728. os.makedirs(path) # make new output folder
  729. def flatten_recursive(path=DATASETS_DIR / 'coco128'):
  730. # Flatten a recursive directory by bringing all files to top level
  731. new_path = Path(str(path) + '_flat')
  732. create_folder(new_path)
  733. for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
  734. shutil.copyfile(file, new_path / Path(file).name)
  735. def extract_boxes(path=DATASETS_DIR / 'coco128'): # from utils.datasets import *; extract_boxes()
  736. # Convert detection dataset into classification dataset, with one directory per class
  737. path = Path(path) # images dir
  738. shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None # remove existing
  739. files = list(path.rglob('*.*'))
  740. n = len(files) # number of files
  741. for im_file in tqdm(files, total=n):
  742. if im_file.suffix[1:] in IMG_FORMATS:
  743. # image
  744. im = cv2.imread(str(im_file))[..., ::-1] # BGR to RGB
  745. h, w = im.shape[:2]
  746. # labels
  747. lb_file = Path(img2label_paths([str(im_file)])[0])
  748. if Path(lb_file).exists():
  749. with open(lb_file) as f:
  750. lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32) # labels
  751. for j, x in enumerate(lb):
  752. c = int(x[0]) # class
  753. f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg' # new filename
  754. if not f.parent.is_dir():
  755. f.parent.mkdir(parents=True)
  756. b = x[1:] * [w, h, w, h] # box
  757. # b[2:] = b[2:].max() # rectangle to square
  758. b[2:] = b[2:] * 1.2 + 3 # pad
  759. b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(np.int)
  760. b[[0, 2]] = np.clip(b[[0, 2]], 0, w) # clip boxes outside of image
  761. b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
  762. assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'
  763. def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
  764. """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
  765. Usage: from utils.datasets import *; autosplit()
  766. Arguments
  767. path: Path to images directory
  768. weights: Train, val, test weights (list, tuple)
  769. annotated_only: Only use images with an annotated txt file
  770. """
  771. path = Path(path) # images dir
  772. files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS) # image files only
  773. n = len(files) # number of files
  774. random.seed(0) # for reproducibility
  775. indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
  776. txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
  777. [(path.parent / x).unlink(missing_ok=True) for x in txt] # remove existing
  778. print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
  779. for i, img in tqdm(zip(indices, files), total=n):
  780. if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
  781. with open(path.parent / txt[i], 'a') as f:
  782. f.write('./' + img.relative_to(path.parent).as_posix() + '\n') # add image to txt file
  783. def verify_image_label(args):
  784. # Verify one image-label pair
  785. im_file, lb_file, prefix = args
  786. nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, '', [] # number (missing, found, empty, corrupt), message, segments
  787. try:
  788. # verify images
  789. im = Image.open(im_file)
  790. im.verify() # PIL verify
  791. shape = exif_size(im) # image size
  792. assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
  793. assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
  794. if im.format.lower() in ('jpg', 'jpeg'):
  795. with open(im_file, 'rb') as f:
  796. f.seek(-2, 2)
  797. if f.read() != b'\xff\xd9': # corrupt JPEG
  798. ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
  799. msg = f'{prefix}WARNING: {im_file}: corrupt JPEG restored and saved'
  800. # verify labels
  801. if os.path.isfile(lb_file):
  802. nf = 1 # label found
  803. with open(lb_file) as f:
  804. lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
  805. if any(len(x) > 6 for x in lb): # is segment
  806. classes = np.array([x[0] for x in lb], dtype=np.float32)
  807. segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)
  808. lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
  809. lb = np.array(lb, dtype=np.float32)
  810. nl = len(lb)
  811. if nl:
  812. assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
  813. assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
  814. assert (lb[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
  815. _, i = np.unique(lb, axis=0, return_index=True)
  816. if len(i) < nl: # duplicate row check
  817. lb = lb[i] # remove duplicates
  818. if segments:
  819. segments = segments[i]
  820. msg = f'{prefix}WARNING: {im_file}: {nl - len(i)} duplicate labels removed'
  821. else:
  822. ne = 1 # label empty
  823. lb = np.zeros((0, 5), dtype=np.float32)
  824. else:
  825. nm = 1 # label missing
  826. lb = np.zeros((0, 5), dtype=np.float32)
  827. return im_file, lb, shape, segments, nm, nf, ne, nc, msg
  828. except Exception as e:
  829. nc = 1
  830. msg = f'{prefix}WARNING: {im_file}: ignoring corrupt image/label: {e}'
  831. return [None, None, None, None, nm, nf, ne, nc, msg]
  832. def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False, profile=False, hub=False):
  833. """ Return dataset statistics dictionary with images and instances counts per split per class
  834. To run in parent directory: export PYTHONPATH="$PWD/yolov5"
  835. Usage1: from utils.datasets import *; dataset_stats('coco128.yaml', autodownload=True)
  836. Usage2: from utils.datasets import *; dataset_stats('path/to/coco128_with_yaml.zip')
  837. Arguments
  838. path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
  839. autodownload: Attempt to download dataset if not found locally
  840. verbose: Print stats dictionary
  841. """
  842. def round_labels(labels):
  843. # Update labels to integer class and 6 decimal place floats
  844. return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]
  845. def unzip(path):
  846. # Unzip data.zip TODO: CONSTRAINT: path/to/abc.zip MUST unzip to 'path/to/abc/'
  847. if str(path).endswith('.zip'): # path is data.zip
  848. assert Path(path).is_file(), f'Error unzipping {path}, file not found'
  849. ZipFile(path).extractall(path=path.parent) # unzip
  850. dir = path.with_suffix('') # dataset directory == zip name
  851. return True, str(dir), next(dir.rglob('*.yaml')) # zipped, data_dir, yaml_path
  852. else: # path is data.yaml
  853. return False, None, path
  854. def hub_ops(f, max_dim=1920):
  855. # HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
  856. f_new = im_dir / Path(f).name # dataset-hub image filename
  857. try: # use PIL
  858. im = Image.open(f)
  859. r = max_dim / max(im.height, im.width) # ratio
  860. if r < 1.0: # image too large
  861. im = im.resize((int(im.width * r), int(im.height * r)))
  862. im.save(f_new, 'JPEG', quality=75, optimize=True) # save
  863. except Exception as e: # use OpenCV
  864. print(f'WARNING: HUB ops PIL failure {f}: {e}')
  865. im = cv2.imread(f)
  866. im_height, im_width = im.shape[:2]
  867. r = max_dim / max(im_height, im_width) # ratio
  868. if r < 1.0: # image too large
  869. im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
  870. cv2.imwrite(str(f_new), im)
  871. zipped, data_dir, yaml_path = unzip(Path(path))
  872. with open(check_yaml(yaml_path), errors='ignore') as f:
  873. data = yaml.safe_load(f) # data dict
  874. if zipped:
  875. data['path'] = data_dir # TODO: should this be dir.resolve()?
  876. check_dataset(data, autodownload) # download dataset if missing
  877. hub_dir = Path(data['path'] + ('-hub' if hub else ''))
  878. stats = {'nc': data['nc'], 'names': data['names']} # statistics dictionary
  879. for split in 'train', 'val', 'test':
  880. if data.get(split) is None:
  881. stats[split] = None # i.e. no test set
  882. continue
  883. x = []
  884. dataset = LoadImagesAndLabels(data[split]) # load dataset
  885. for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
  886. x.append(np.bincount(label[:, 0].astype(int), minlength=data['nc']))
  887. x = np.array(x) # shape(128x80)
  888. stats[split] = {
  889. 'instance_stats': {
  890. 'total': int(x.sum()),
  891. 'per_class': x.sum(0).tolist()},
  892. 'image_stats': {
  893. 'total': dataset.n,
  894. 'unlabelled': int(np.all(x == 0, 1).sum()),
  895. 'per_class': (x > 0).sum(0).tolist()},
  896. 'labels': [{
  897. str(Path(k).name): round_labels(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}
  898. if hub:
  899. im_dir = hub_dir / 'images'
  900. im_dir.mkdir(parents=True, exist_ok=True)
  901. for _ in tqdm(ThreadPool(NUM_THREADS).imap(hub_ops, dataset.im_files), total=dataset.n, desc='HUB Ops'):
  902. pass
  903. # Profile
  904. stats_path = hub_dir / 'stats.json'
  905. if profile:
  906. for _ in range(1):
  907. file = stats_path.with_suffix('.npy')
  908. t1 = time.time()
  909. np.save(file, stats)
  910. t2 = time.time()
  911. x = np.load(file, allow_pickle=True)
  912. print(f'stats.npy times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')
  913. file = stats_path.with_suffix('.json')
  914. t1 = time.time()
  915. with open(file, 'w') as f:
  916. json.dump(stats, f) # save stats *.json
  917. t2 = time.time()
  918. with open(file) as f:
  919. x = json.load(f) # load hyps dict
  920. print(f'stats.json times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')
  921. # Save, print and return
  922. if hub:
  923. print(f'Saving {stats_path.resolve()}...')
  924. with open(stats_path, 'w') as f:
  925. json.dump(stats, f) # save stats.json
  926. if verbose:
  927. print(json.dumps(stats, indent=2, sort_keys=False))
  928. return stats