You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1039 lines
45KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Dataloaders and dataset utils
  4. """
  5. import glob
  6. import hashlib
  7. import json
  8. import math
  9. import os
  10. import random
  11. import shutil
  12. import time
  13. from itertools import repeat
  14. from multiprocessing.pool import Pool, ThreadPool
  15. from pathlib import Path
  16. from threading import Thread
  17. from zipfile import ZipFile
  18. import cv2
  19. import numpy as np
  20. import torch
  21. import torch.nn.functional as F
  22. import yaml
  23. from PIL import ExifTags, Image, ImageOps
  24. from torch.utils.data import DataLoader, Dataset, dataloader, distributed
  25. from tqdm import tqdm
  26. from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
  27. from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
  28. segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
  29. from utils.torch_utils import torch_distributed_zero_first
  30. # Parameters
  31. HELP_URL = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
  32. IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp' # include image suffixes
  33. VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
  34. BAR_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}' # tqdm bar format
  35. # Get orientation exif tag
  36. for orientation in ExifTags.TAGS.keys():
  37. if ExifTags.TAGS[orientation] == 'Orientation':
  38. break
  39. def get_hash(paths):
  40. # Returns a single hash value of a list of paths (files or dirs)
  41. size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
  42. h = hashlib.md5(str(size).encode()) # hash sizes
  43. h.update(''.join(paths).encode()) # hash paths
  44. return h.hexdigest() # return hash
  45. def exif_size(img):
  46. # Returns exif-corrected PIL size
  47. s = img.size # (width, height)
  48. try:
  49. rotation = dict(img._getexif().items())[orientation]
  50. if rotation == 6: # rotation 270
  51. s = (s[1], s[0])
  52. elif rotation == 8: # rotation 90
  53. s = (s[1], s[0])
  54. except Exception:
  55. pass
  56. return s
  57. def exif_transpose(image):
  58. """
  59. Transpose a PIL image accordingly if it has an EXIF Orientation tag.
  60. Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()
  61. :param image: The image to transpose.
  62. :return: An image.
  63. """
  64. exif = image.getexif()
  65. orientation = exif.get(0x0112, 1) # default 1
  66. if orientation > 1:
  67. method = {2: Image.FLIP_LEFT_RIGHT,
  68. 3: Image.ROTATE_180,
  69. 4: Image.FLIP_TOP_BOTTOM,
  70. 5: Image.TRANSPOSE,
  71. 6: Image.ROTATE_270,
  72. 7: Image.TRANSVERSE,
  73. 8: Image.ROTATE_90,
  74. }.get(orientation)
  75. if method is not None:
  76. image = image.transpose(method)
  77. del exif[0x0112]
  78. image.info["exif"] = exif.tobytes()
  79. return image
  80. def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
  81. rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', shuffle=False):
  82. if rect and shuffle:
  83. LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
  84. shuffle = False
  85. with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
  86. dataset = LoadImagesAndLabels(path, imgsz, batch_size,
  87. augment=augment, # augmentation
  88. hyp=hyp, # hyperparameters
  89. rect=rect, # rectangular batches
  90. cache_images=cache,
  91. single_cls=single_cls,
  92. stride=int(stride),
  93. pad=pad,
  94. image_weights=image_weights,
  95. prefix=prefix)
  96. batch_size = min(batch_size, len(dataset))
  97. nd = torch.cuda.device_count() # number of CUDA devices
  98. nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
  99. sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
  100. loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
  101. return loader(dataset,
  102. batch_size=batch_size,
  103. shuffle=shuffle and sampler is None,
  104. num_workers=nw,
  105. sampler=sampler,
  106. pin_memory=True,
  107. collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
  108. class InfiniteDataLoader(dataloader.DataLoader):
  109. """ Dataloader that reuses workers
  110. Uses same syntax as vanilla DataLoader
  111. """
  112. def __init__(self, *args, **kwargs):
  113. super().__init__(*args, **kwargs)
  114. object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
  115. self.iterator = super().__iter__()
  116. def __len__(self):
  117. return len(self.batch_sampler.sampler)
  118. def __iter__(self):
  119. for i in range(len(self)):
  120. yield next(self.iterator)
  121. class _RepeatSampler:
  122. """ Sampler that repeats forever
  123. Args:
  124. sampler (Sampler)
  125. """
  126. def __init__(self, sampler):
  127. self.sampler = sampler
  128. def __iter__(self):
  129. while True:
  130. yield from iter(self.sampler)
  131. class LoadImages:
  132. # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
  133. def __init__(self, path, img_size=640, stride=32, auto=True):
  134. p = str(Path(path).resolve()) # os-agnostic absolute path
  135. if '*' in p:
  136. files = sorted(glob.glob(p, recursive=True)) # glob
  137. elif os.path.isdir(p):
  138. files = sorted(glob.glob(os.path.join(p, '*.*'))) # dir
  139. elif os.path.isfile(p):
  140. files = [p] # files
  141. else:
  142. raise Exception(f'ERROR: {p} does not exist')
  143. images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
  144. videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
  145. ni, nv = len(images), len(videos)
  146. self.img_size = img_size
  147. self.stride = stride
  148. self.files = images + videos
  149. self.nf = ni + nv # number of files
  150. self.video_flag = [False] * ni + [True] * nv
  151. self.mode = 'image'
  152. self.auto = auto
  153. if any(videos):
  154. self.new_video(videos[0]) # new video
  155. else:
  156. self.cap = None
  157. assert self.nf > 0, f'No images or videos found in {p}. ' \
  158. f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
  159. def __iter__(self):
  160. self.count = 0
  161. return self
  162. def __next__(self):
  163. if self.count == self.nf:
  164. raise StopIteration
  165. path = self.files[self.count]
  166. if self.video_flag[self.count]:
  167. # Read video
  168. self.mode = 'video'
  169. ret_val, img0 = self.cap.read()
  170. while not ret_val:
  171. self.count += 1
  172. self.cap.release()
  173. if self.count == self.nf: # last video
  174. raise StopIteration
  175. else:
  176. path = self.files[self.count]
  177. self.new_video(path)
  178. ret_val, img0 = self.cap.read()
  179. self.frame += 1
  180. s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
  181. else:
  182. # Read image
  183. self.count += 1
  184. img0 = cv2.imread(path) # BGR
  185. assert img0 is not None, f'Image Not Found {path}'
  186. s = f'image {self.count}/{self.nf} {path}: '
  187. # Padded resize
  188. img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]
  189. # Convert
  190. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  191. img = np.ascontiguousarray(img)
  192. return path, img, img0, self.cap, s
  193. def new_video(self, path):
  194. self.frame = 0
  195. self.cap = cv2.VideoCapture(path)
  196. self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
  197. def __len__(self):
  198. return self.nf # number of files
  199. class LoadWebcam: # for inference
  200. # YOLOv5 local webcam dataloader, i.e. `python detect.py --source 0`
  201. def __init__(self, pipe='0', img_size=640, stride=32):
  202. self.img_size = img_size
  203. self.stride = stride
  204. self.pipe = eval(pipe) if pipe.isnumeric() else pipe
  205. self.cap = cv2.VideoCapture(self.pipe) # video capture object
  206. self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3) # set buffer size
  207. def __iter__(self):
  208. self.count = -1
  209. return self
  210. def __next__(self):
  211. self.count += 1
  212. if cv2.waitKey(1) == ord('q'): # q to quit
  213. self.cap.release()
  214. cv2.destroyAllWindows()
  215. raise StopIteration
  216. # Read frame
  217. ret_val, img0 = self.cap.read()
  218. img0 = cv2.flip(img0, 1) # flip left-right
  219. # Print
  220. assert ret_val, f'Camera Error {self.pipe}'
  221. img_path = 'webcam.jpg'
  222. s = f'webcam {self.count}: '
  223. # Padded resize
  224. img = letterbox(img0, self.img_size, stride=self.stride)[0]
  225. # Convert
  226. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  227. img = np.ascontiguousarray(img)
  228. return img_path, img, img0, None, s
  229. def __len__(self):
  230. return 0
  231. class LoadStreams:
  232. # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
  233. def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
  234. self.mode = 'stream'
  235. self.img_size = img_size
  236. self.stride = stride
  237. if os.path.isfile(sources):
  238. with open(sources) as f:
  239. sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
  240. else:
  241. sources = [sources]
  242. n = len(sources)
  243. self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
  244. self.sources = [clean_str(x) for x in sources] # clean source names for later
  245. self.auto = auto
  246. for i, s in enumerate(sources): # index, source
  247. # Start thread to read frames from video stream
  248. st = f'{i + 1}/{n}: {s}... '
  249. if 'youtube.com/' in s or 'youtu.be/' in s: # if source is YouTube video
  250. check_requirements(('pafy', 'youtube_dl==2020.12.2'))
  251. import pafy
  252. s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
  253. s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
  254. cap = cv2.VideoCapture(s)
  255. assert cap.isOpened(), f'{st}Failed to open {s}'
  256. w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  257. h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  258. fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
  259. self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
  260. self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
  261. _, self.imgs[i] = cap.read() # guarantee first frame
  262. self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
  263. LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
  264. self.threads[i].start()
  265. LOGGER.info('') # newline
  266. # check for common shapes
  267. s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs])
  268. self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
  269. if not self.rect:
  270. LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')
  271. def update(self, i, cap, stream):
  272. # Read stream `i` frames in daemon thread
  273. n, f, read = 0, self.frames[i], 1 # frame number, frame array, inference every 'read' frame
  274. while cap.isOpened() and n < f:
  275. n += 1
  276. # _, self.imgs[index] = cap.read()
  277. cap.grab()
  278. if n % read == 0:
  279. success, im = cap.retrieve()
  280. if success:
  281. self.imgs[i] = im
  282. else:
  283. LOGGER.warning('WARNING: Video stream unresponsive, please check your IP camera connection.')
  284. self.imgs[i] = np.zeros_like(self.imgs[i])
  285. cap.open(stream) # re-open stream if signal was lost
  286. time.sleep(1 / self.fps[i]) # wait time
  287. def __iter__(self):
  288. self.count = -1
  289. return self
  290. def __next__(self):
  291. self.count += 1
  292. if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
  293. cv2.destroyAllWindows()
  294. raise StopIteration
  295. # Letterbox
  296. img0 = self.imgs.copy()
  297. img = [letterbox(x, self.img_size, stride=self.stride, auto=self.rect and self.auto)[0] for x in img0]
  298. # Stack
  299. img = np.stack(img, 0)
  300. # Convert
  301. img = img[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
  302. img = np.ascontiguousarray(img)
  303. return self.sources, img, img0, None, ''
  304. def __len__(self):
  305. return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
  306. def img2label_paths(img_paths):
  307. # Define label paths as a function of image paths
  308. sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep # /images/, /labels/ substrings
  309. return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
  310. class LoadImagesAndLabels(Dataset):
  311. # YOLOv5 train_loader/val_loader, loads images and labels for training and validation
  312. cache_version = 0.6 # dataset labels *.cache version
  313. def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
  314. cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
  315. self.img_size = img_size
  316. self.augment = augment
  317. self.hyp = hyp
  318. self.image_weights = image_weights
  319. self.rect = False if image_weights else rect
  320. self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
  321. self.mosaic_border = [-img_size // 2, -img_size // 2]
  322. self.stride = stride
  323. self.path = path
  324. self.albumentations = Albumentations() if augment else None
  325. try:
  326. f = [] # image files
  327. for p in path if isinstance(path, list) else [path]:
  328. p = Path(p) # os-agnostic
  329. if p.is_dir(): # dir
  330. f += glob.glob(str(p / '**' / '*.*'), recursive=True)
  331. # f = list(p.rglob('*.*')) # pathlib
  332. elif p.is_file(): # file
  333. with open(p) as t:
  334. t = t.read().strip().splitlines()
  335. parent = str(p.parent) + os.sep
  336. f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
  337. # f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
  338. else:
  339. raise Exception(f'{prefix}{p} does not exist')
  340. self.im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
  341. # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
  342. assert self.im_files, f'{prefix}No images found'
  343. except Exception as e:
  344. raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}')
  345. # Check cache
  346. self.label_files = img2label_paths(self.im_files) # labels
  347. cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
  348. try:
  349. cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
  350. assert cache['version'] == self.cache_version # same version
  351. assert cache['hash'] == get_hash(self.label_files + self.im_files) # same hash
  352. except Exception:
  353. cache, exists = self.cache_labels(cache_path, prefix), False # cache
  354. # Display cache
  355. nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
  356. if exists:
  357. d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"
  358. tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=BAR_FORMAT) # display cache results
  359. if cache['msgs']:
  360. LOGGER.info('\n'.join(cache['msgs'])) # display warnings
  361. assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {HELP_URL}'
  362. # Read cache
  363. [cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
  364. labels, shapes, self.segments = zip(*cache.values())
  365. self.labels = list(labels)
  366. self.shapes = np.array(shapes, dtype=np.float64)
  367. self.im_files = list(cache.keys()) # update
  368. self.label_files = img2label_paths(cache.keys()) # update
  369. n = len(shapes) # number of images
  370. bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
  371. nb = bi[-1] + 1 # number of batches
  372. self.batch = bi # batch index of image
  373. self.n = n
  374. self.indices = range(n)
  375. # Update labels
  376. include_class = [] # filter labels to include only these classes (optional)
  377. include_class_array = np.array(include_class).reshape(1, -1)
  378. for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
  379. if include_class:
  380. j = (label[:, 0:1] == include_class_array).any(1)
  381. self.labels[i] = label[j]
  382. if segment:
  383. self.segments[i] = segment[j]
  384. if single_cls: # single-class training, merge all classes into 0
  385. self.labels[i][:, 0] = 0
  386. if segment:
  387. self.segments[i][:, 0] = 0
  388. # Rectangular Training
  389. if self.rect:
  390. # Sort by aspect ratio
  391. s = self.shapes # wh
  392. ar = s[:, 1] / s[:, 0] # aspect ratio
  393. irect = ar.argsort()
  394. self.im_files = [self.im_files[i] for i in irect]
  395. self.label_files = [self.label_files[i] for i in irect]
  396. self.labels = [self.labels[i] for i in irect]
  397. self.shapes = s[irect] # wh
  398. ar = ar[irect]
  399. # Set training image shapes
  400. shapes = [[1, 1]] * nb
  401. for i in range(nb):
  402. ari = ar[bi == i]
  403. mini, maxi = ari.min(), ari.max()
  404. if maxi < 1:
  405. shapes[i] = [maxi, 1]
  406. elif mini > 1:
  407. shapes[i] = [1, 1 / mini]
  408. self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
  409. # Cache images into RAM/disk for faster training (WARNING: large datasets may exceed system resources)
  410. self.ims = [None] * n
  411. self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
  412. if cache_images:
  413. gb = 0 # Gigabytes of cached images
  414. self.im_hw0, self.im_hw = [None] * n, [None] * n
  415. fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
  416. results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
  417. pbar = tqdm(enumerate(results), total=n, bar_format=BAR_FORMAT)
  418. for i, x in pbar:
  419. if cache_images == 'disk':
  420. gb += self.npy_files[i].stat().st_size
  421. else: # 'ram'
  422. self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
  423. gb += self.ims[i].nbytes
  424. pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB {cache_images})'
  425. pbar.close()
  426. def cache_labels(self, path=Path('./labels.cache'), prefix=''):
  427. # Cache dataset labels, check images and read shapes
  428. x = {} # dict
  429. nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
  430. desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
  431. with Pool(NUM_THREADS) as pool:
  432. pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
  433. desc=desc, total=len(self.im_files), bar_format=BAR_FORMAT)
  434. for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
  435. nm += nm_f
  436. nf += nf_f
  437. ne += ne_f
  438. nc += nc_f
  439. if im_file:
  440. x[im_file] = [lb, shape, segments]
  441. if msg:
  442. msgs.append(msg)
  443. pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt"
  444. pbar.close()
  445. if msgs:
  446. LOGGER.info('\n'.join(msgs))
  447. if nf == 0:
  448. LOGGER.warning(f'{prefix}WARNING: No labels found in {path}. See {HELP_URL}')
  449. x['hash'] = get_hash(self.label_files + self.im_files)
  450. x['results'] = nf, nm, ne, nc, len(self.im_files)
  451. x['msgs'] = msgs # warnings
  452. x['version'] = self.cache_version # cache version
  453. try:
  454. np.save(path, x) # save cache for next time
  455. path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
  456. LOGGER.info(f'{prefix}New cache created: {path}')
  457. except Exception as e:
  458. LOGGER.warning(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}') # not writeable
  459. return x
  460. def __len__(self):
  461. return len(self.im_files)
  462. # def __iter__(self):
  463. # self.count = -1
  464. # print('ran dataset iter')
  465. # #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
  466. # return self
  467. def __getitem__(self, index):
  468. index = self.indices[index] # linear, shuffled, or image_weights
  469. hyp = self.hyp
  470. mosaic = self.mosaic and random.random() < hyp['mosaic']
  471. if mosaic:
  472. # Load mosaic
  473. img, labels = self.load_mosaic(index)
  474. shapes = None
  475. # MixUp augmentation
  476. if random.random() < hyp['mixup']:
  477. img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1)))
  478. else:
  479. # Load image
  480. img, (h0, w0), (h, w) = self.load_image(index)
  481. # Letterbox
  482. shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
  483. img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
  484. shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
  485. labels = self.labels[index].copy()
  486. if labels.size: # normalized xywh to pixel xyxy format
  487. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
  488. if self.augment:
  489. img, labels = random_perspective(img, labels,
  490. degrees=hyp['degrees'],
  491. translate=hyp['translate'],
  492. scale=hyp['scale'],
  493. shear=hyp['shear'],
  494. perspective=hyp['perspective'])
  495. nl = len(labels) # number of labels
  496. if nl:
  497. labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)
  498. if self.augment:
  499. # Albumentations
  500. img, labels = self.albumentations(img, labels)
  501. nl = len(labels) # update after albumentations
  502. # HSV color-space
  503. augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
  504. # Flip up-down
  505. if random.random() < hyp['flipud']:
  506. img = np.flipud(img)
  507. if nl:
  508. labels[:, 2] = 1 - labels[:, 2]
  509. # Flip left-right
  510. if random.random() < hyp['fliplr']:
  511. img = np.fliplr(img)
  512. if nl:
  513. labels[:, 1] = 1 - labels[:, 1]
  514. # Cutouts
  515. # labels = cutout(img, labels, p=0.5)
  516. # nl = len(labels) # update after cutout
  517. labels_out = torch.zeros((nl, 6))
  518. if nl:
  519. labels_out[:, 1:] = torch.from_numpy(labels)
  520. # Convert
  521. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  522. img = np.ascontiguousarray(img)
  523. return torch.from_numpy(img), labels_out, self.im_files[index], shapes
  524. def load_image(self, i):
  525. # Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
  526. im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i],
  527. if im is None: # not cached in RAM
  528. if fn.exists(): # load npy
  529. im = np.load(fn)
  530. else: # read image
  531. im = cv2.imread(f) # BGR
  532. assert im is not None, f'Image Not Found {f}'
  533. h0, w0 = im.shape[:2] # orig hw
  534. r = self.img_size / max(h0, w0) # ratio
  535. if r != 1: # if sizes are not equal
  536. im = cv2.resize(im,
  537. (int(w0 * r), int(h0 * r)),
  538. interpolation=cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA)
  539. return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
  540. else:
  541. return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
  542. def cache_images_to_disk(self, i):
  543. # Saves an image as an *.npy file for faster loading
  544. f = self.npy_files[i]
  545. if not f.exists():
  546. np.save(f.as_posix(), cv2.imread(self.im_files[i]))
  547. def load_mosaic(self, index):
  548. # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
  549. labels4, segments4 = [], []
  550. s = self.img_size
  551. yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border) # mosaic center x, y
  552. indices = [index] + random.choices(self.indices, k=3) # 3 additional image indices
  553. random.shuffle(indices)
  554. for i, index in enumerate(indices):
  555. # Load image
  556. img, _, (h, w) = self.load_image(index)
  557. # place img in img4
  558. if i == 0: # top left
  559. img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
  560. x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
  561. x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
  562. elif i == 1: # top right
  563. x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
  564. x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
  565. elif i == 2: # bottom left
  566. x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
  567. x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
  568. elif i == 3: # bottom right
  569. x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
  570. x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
  571. img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
  572. padw = x1a - x1b
  573. padh = y1a - y1b
  574. # Labels
  575. labels, segments = self.labels[index].copy(), self.segments[index].copy()
  576. if labels.size:
  577. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh) # normalized xywh to pixel xyxy format
  578. segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
  579. labels4.append(labels)
  580. segments4.extend(segments)
  581. # Concat/clip labels
  582. labels4 = np.concatenate(labels4, 0)
  583. for x in (labels4[:, 1:], *segments4):
  584. np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
  585. # img4, labels4 = replicate(img4, labels4) # replicate
  586. # Augment
  587. img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
  588. img4, labels4 = random_perspective(img4, labels4, segments4,
  589. degrees=self.hyp['degrees'],
  590. translate=self.hyp['translate'],
  591. scale=self.hyp['scale'],
  592. shear=self.hyp['shear'],
  593. perspective=self.hyp['perspective'],
  594. border=self.mosaic_border) # border to remove
  595. return img4, labels4
  596. def load_mosaic9(self, index):
  597. # YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic
  598. labels9, segments9 = [], []
  599. s = self.img_size
  600. indices = [index] + random.choices(self.indices, k=8) # 8 additional image indices
  601. random.shuffle(indices)
  602. hp, wp = -1, -1 # height, width previous
  603. for i, index in enumerate(indices):
  604. # Load image
  605. img, _, (h, w) = self.load_image(index)
  606. # place img in img9
  607. if i == 0: # center
  608. img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
  609. h0, w0 = h, w
  610. c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates
  611. elif i == 1: # top
  612. c = s, s - h, s + w, s
  613. elif i == 2: # top right
  614. c = s + wp, s - h, s + wp + w, s
  615. elif i == 3: # right
  616. c = s + w0, s, s + w0 + w, s + h
  617. elif i == 4: # bottom right
  618. c = s + w0, s + hp, s + w0 + w, s + hp + h
  619. elif i == 5: # bottom
  620. c = s + w0 - w, s + h0, s + w0, s + h0 + h
  621. elif i == 6: # bottom left
  622. c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
  623. elif i == 7: # left
  624. c = s - w, s + h0 - h, s, s + h0
  625. elif i == 8: # top left
  626. c = s - w, s + h0 - hp - h, s, s + h0 - hp
  627. padx, pady = c[:2]
  628. x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
  629. # Labels
  630. labels, segments = self.labels[index].copy(), self.segments[index].copy()
  631. if labels.size:
  632. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady) # normalized xywh to pixel xyxy format
  633. segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
  634. labels9.append(labels)
  635. segments9.extend(segments)
  636. # Image
  637. img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:] # img9[ymin:ymax, xmin:xmax]
  638. hp, wp = h, w # height, width previous
  639. # Offset
  640. yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border) # mosaic center x, y
  641. img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]
  642. # Concat/clip labels
  643. labels9 = np.concatenate(labels9, 0)
  644. labels9[:, [1, 3]] -= xc
  645. labels9[:, [2, 4]] -= yc
  646. c = np.array([xc, yc]) # centers
  647. segments9 = [x - c for x in segments9]
  648. for x in (labels9[:, 1:], *segments9):
  649. np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
  650. # img9, labels9 = replicate(img9, labels9) # replicate
  651. # Augment
  652. img9, labels9 = random_perspective(img9, labels9, segments9,
  653. degrees=self.hyp['degrees'],
  654. translate=self.hyp['translate'],
  655. scale=self.hyp['scale'],
  656. shear=self.hyp['shear'],
  657. perspective=self.hyp['perspective'],
  658. border=self.mosaic_border) # border to remove
  659. return img9, labels9
  660. @staticmethod
  661. def collate_fn(batch):
  662. im, label, path, shapes = zip(*batch) # transposed
  663. for i, lb in enumerate(label):
  664. lb[:, 0] = i # add target image index for build_targets()
  665. return torch.stack(im, 0), torch.cat(label, 0), path, shapes
  666. @staticmethod
  667. def collate_fn4(batch):
  668. img, label, path, shapes = zip(*batch) # transposed
  669. n = len(shapes) // 4
  670. im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
  671. ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
  672. wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
  673. s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]]) # scale
  674. for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
  675. i *= 4
  676. if random.random() < 0.5:
  677. im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear', align_corners=False)[
  678. 0].type(img[i].type())
  679. lb = label[i]
  680. else:
  681. im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
  682. lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
  683. im4.append(im)
  684. label4.append(lb)
  685. for i, lb in enumerate(label4):
  686. lb[:, 0] = i # add target image index for build_targets()
  687. return torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4
  688. # Ancillary functions --------------------------------------------------------------------------------------------------
  689. def create_folder(path='./new'):
  690. # Create folder
  691. if os.path.exists(path):
  692. shutil.rmtree(path) # delete output folder
  693. os.makedirs(path) # make new output folder
  694. def flatten_recursive(path=DATASETS_DIR / 'coco128'):
  695. # Flatten a recursive directory by bringing all files to top level
  696. new_path = Path(str(path) + '_flat')
  697. create_folder(new_path)
  698. for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
  699. shutil.copyfile(file, new_path / Path(file).name)
  700. def extract_boxes(path=DATASETS_DIR / 'coco128'): # from utils.datasets import *; extract_boxes()
  701. # Convert detection dataset into classification dataset, with one directory per class
  702. path = Path(path) # images dir
  703. shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None # remove existing
  704. files = list(path.rglob('*.*'))
  705. n = len(files) # number of files
  706. for im_file in tqdm(files, total=n):
  707. if im_file.suffix[1:] in IMG_FORMATS:
  708. # image
  709. im = cv2.imread(str(im_file))[..., ::-1] # BGR to RGB
  710. h, w = im.shape[:2]
  711. # labels
  712. lb_file = Path(img2label_paths([str(im_file)])[0])
  713. if Path(lb_file).exists():
  714. with open(lb_file) as f:
  715. lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32) # labels
  716. for j, x in enumerate(lb):
  717. c = int(x[0]) # class
  718. f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg' # new filename
  719. if not f.parent.is_dir():
  720. f.parent.mkdir(parents=True)
  721. b = x[1:] * [w, h, w, h] # box
  722. # b[2:] = b[2:].max() # rectangle to square
  723. b[2:] = b[2:] * 1.2 + 3 # pad
  724. b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(np.int)
  725. b[[0, 2]] = np.clip(b[[0, 2]], 0, w) # clip boxes outside of image
  726. b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
  727. assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'
  728. def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
  729. """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
  730. Usage: from utils.datasets import *; autosplit()
  731. Arguments
  732. path: Path to images directory
  733. weights: Train, val, test weights (list, tuple)
  734. annotated_only: Only use images with an annotated txt file
  735. """
  736. path = Path(path) # images dir
  737. files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS) # image files only
  738. n = len(files) # number of files
  739. random.seed(0) # for reproducibility
  740. indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
  741. txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
  742. [(path.parent / x).unlink(missing_ok=True) for x in txt] # remove existing
  743. print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
  744. for i, img in tqdm(zip(indices, files), total=n):
  745. if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
  746. with open(path.parent / txt[i], 'a') as f:
  747. f.write('./' + img.relative_to(path.parent).as_posix() + '\n') # add image to txt file
  748. def verify_image_label(args):
  749. # Verify one image-label pair
  750. im_file, lb_file, prefix = args
  751. nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, '', [] # number (missing, found, empty, corrupt), message, segments
  752. try:
  753. # verify images
  754. im = Image.open(im_file)
  755. im.verify() # PIL verify
  756. shape = exif_size(im) # image size
  757. assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
  758. assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
  759. if im.format.lower() in ('jpg', 'jpeg'):
  760. with open(im_file, 'rb') as f:
  761. f.seek(-2, 2)
  762. if f.read() != b'\xff\xd9': # corrupt JPEG
  763. ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
  764. msg = f'{prefix}WARNING: {im_file}: corrupt JPEG restored and saved'
  765. # verify labels
  766. if os.path.isfile(lb_file):
  767. nf = 1 # label found
  768. with open(lb_file) as f:
  769. lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
  770. if any(len(x) > 6 for x in lb): # is segment
  771. classes = np.array([x[0] for x in lb], dtype=np.float32)
  772. segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)
  773. lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
  774. lb = np.array(lb, dtype=np.float32)
  775. nl = len(lb)
  776. if nl:
  777. assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
  778. assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
  779. assert (lb[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
  780. _, i = np.unique(lb, axis=0, return_index=True)
  781. if len(i) < nl: # duplicate row check
  782. lb = lb[i] # remove duplicates
  783. if segments:
  784. segments = segments[i]
  785. msg = f'{prefix}WARNING: {im_file}: {nl - len(i)} duplicate labels removed'
  786. else:
  787. ne = 1 # label empty
  788. lb = np.zeros((0, 5), dtype=np.float32)
  789. else:
  790. nm = 1 # label missing
  791. lb = np.zeros((0, 5), dtype=np.float32)
  792. return im_file, lb, shape, segments, nm, nf, ne, nc, msg
  793. except Exception as e:
  794. nc = 1
  795. msg = f'{prefix}WARNING: {im_file}: ignoring corrupt image/label: {e}'
  796. return [None, None, None, None, nm, nf, ne, nc, msg]
  797. def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False, profile=False, hub=False):
  798. """ Return dataset statistics dictionary with images and instances counts per split per class
  799. To run in parent directory: export PYTHONPATH="$PWD/yolov5"
  800. Usage1: from utils.datasets import *; dataset_stats('coco128.yaml', autodownload=True)
  801. Usage2: from utils.datasets import *; dataset_stats('path/to/coco128_with_yaml.zip')
  802. Arguments
  803. path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
  804. autodownload: Attempt to download dataset if not found locally
  805. verbose: Print stats dictionary
  806. """
  807. def round_labels(labels):
  808. # Update labels to integer class and 6 decimal place floats
  809. return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]
  810. def unzip(path):
  811. # Unzip data.zip TODO: CONSTRAINT: path/to/abc.zip MUST unzip to 'path/to/abc/'
  812. if str(path).endswith('.zip'): # path is data.zip
  813. assert Path(path).is_file(), f'Error unzipping {path}, file not found'
  814. ZipFile(path).extractall(path=path.parent) # unzip
  815. dir = path.with_suffix('') # dataset directory == zip name
  816. return True, str(dir), next(dir.rglob('*.yaml')) # zipped, data_dir, yaml_path
  817. else: # path is data.yaml
  818. return False, None, path
  819. def hub_ops(f, max_dim=1920):
  820. # HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
  821. f_new = im_dir / Path(f).name # dataset-hub image filename
  822. try: # use PIL
  823. im = Image.open(f)
  824. r = max_dim / max(im.height, im.width) # ratio
  825. if r < 1.0: # image too large
  826. im = im.resize((int(im.width * r), int(im.height * r)))
  827. im.save(f_new, 'JPEG', quality=75, optimize=True) # save
  828. except Exception as e: # use OpenCV
  829. print(f'WARNING: HUB ops PIL failure {f}: {e}')
  830. im = cv2.imread(f)
  831. im_height, im_width = im.shape[:2]
  832. r = max_dim / max(im_height, im_width) # ratio
  833. if r < 1.0: # image too large
  834. im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
  835. cv2.imwrite(str(f_new), im)
  836. zipped, data_dir, yaml_path = unzip(Path(path))
  837. with open(check_yaml(yaml_path), errors='ignore') as f:
  838. data = yaml.safe_load(f) # data dict
  839. if zipped:
  840. data['path'] = data_dir # TODO: should this be dir.resolve()?
  841. check_dataset(data, autodownload) # download dataset if missing
  842. hub_dir = Path(data['path'] + ('-hub' if hub else ''))
  843. stats = {'nc': data['nc'], 'names': data['names']} # statistics dictionary
  844. for split in 'train', 'val', 'test':
  845. if data.get(split) is None:
  846. stats[split] = None # i.e. no test set
  847. continue
  848. x = []
  849. dataset = LoadImagesAndLabels(data[split]) # load dataset
  850. for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
  851. x.append(np.bincount(label[:, 0].astype(int), minlength=data['nc']))
  852. x = np.array(x) # shape(128x80)
  853. stats[split] = {'instance_stats': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
  854. 'image_stats': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
  855. 'per_class': (x > 0).sum(0).tolist()},
  856. 'labels': [{str(Path(k).name): round_labels(v.tolist())} for k, v in
  857. zip(dataset.im_files, dataset.labels)]}
  858. if hub:
  859. im_dir = hub_dir / 'images'
  860. im_dir.mkdir(parents=True, exist_ok=True)
  861. for _ in tqdm(ThreadPool(NUM_THREADS).imap(hub_ops, dataset.im_files), total=dataset.n, desc='HUB Ops'):
  862. pass
  863. # Profile
  864. stats_path = hub_dir / 'stats.json'
  865. if profile:
  866. for _ in range(1):
  867. file = stats_path.with_suffix('.npy')
  868. t1 = time.time()
  869. np.save(file, stats)
  870. t2 = time.time()
  871. x = np.load(file, allow_pickle=True)
  872. print(f'stats.npy times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')
  873. file = stats_path.with_suffix('.json')
  874. t1 = time.time()
  875. with open(file, 'w') as f:
  876. json.dump(stats, f) # save stats *.json
  877. t2 = time.time()
  878. with open(file) as f:
  879. x = json.load(f) # load hyps dict
  880. print(f'stats.json times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')
  881. # Save, print and return
  882. if hub:
  883. print(f'Saving {stats_path.resolve()}...')
  884. with open(stats_path, 'w') as f:
  885. json.dump(stats, f) # save stats.json
  886. if verbose:
  887. print(json.dumps(stats, indent=2, sort_keys=False))
  888. return stats