You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1039 lines
44KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Dataloaders and dataset utils
  4. """
  5. import glob
  6. import hashlib
  7. import json
  8. import math
  9. import os
  10. import random
  11. import shutil
  12. import time
  13. from itertools import repeat
  14. from multiprocessing.pool import Pool, ThreadPool
  15. from pathlib import Path
  16. from threading import Thread
  17. from zipfile import ZipFile
  18. import cv2
  19. import numpy as np
  20. import torch
  21. import torch.nn.functional as F
  22. import yaml
  23. from PIL import ExifTags, Image, ImageOps
  24. from torch.utils.data import DataLoader, Dataset, dataloader, distributed
  25. from tqdm import tqdm
  26. from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
  27. from utils.general import (LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
  28. segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
  29. from utils.torch_utils import torch_distributed_zero_first
  30. # Parameters
  31. HELP_URL = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
  32. IMG_FORMATS = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo'] # acceptable image suffixes
  33. VID_FORMATS = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv'] # acceptable video suffixes
  34. WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) # DPP
  35. # Get orientation exif tag
  36. for orientation in ExifTags.TAGS.keys():
  37. if ExifTags.TAGS[orientation] == 'Orientation':
  38. break
  39. def get_hash(paths):
  40. # Returns a single hash value of a list of paths (files or dirs)
  41. size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
  42. h = hashlib.md5(str(size).encode()) # hash sizes
  43. h.update(''.join(paths).encode()) # hash paths
  44. return h.hexdigest() # return hash
  45. def exif_size(img):
  46. # Returns exif-corrected PIL size
  47. s = img.size # (width, height)
  48. try:
  49. rotation = dict(img._getexif().items())[orientation]
  50. if rotation == 6: # rotation 270
  51. s = (s[1], s[0])
  52. elif rotation == 8: # rotation 90
  53. s = (s[1], s[0])
  54. except:
  55. pass
  56. return s
  57. def exif_transpose(image):
  58. """
  59. Transpose a PIL image accordingly if it has an EXIF Orientation tag.
  60. Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()
  61. :param image: The image to transpose.
  62. :return: An image.
  63. """
  64. exif = image.getexif()
  65. orientation = exif.get(0x0112, 1) # default 1
  66. if orientation > 1:
  67. method = {2: Image.FLIP_LEFT_RIGHT,
  68. 3: Image.ROTATE_180,
  69. 4: Image.FLIP_TOP_BOTTOM,
  70. 5: Image.TRANSPOSE,
  71. 6: Image.ROTATE_270,
  72. 7: Image.TRANSVERSE,
  73. 8: Image.ROTATE_90,
  74. }.get(orientation)
  75. if method is not None:
  76. image = image.transpose(method)
  77. del exif[0x0112]
  78. image.info["exif"] = exif.tobytes()
  79. return image
  80. def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
  81. rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', shuffle=False):
  82. if rect and shuffle:
  83. LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
  84. shuffle = False
  85. with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
  86. dataset = LoadImagesAndLabels(path, imgsz, batch_size,
  87. augment=augment, # augmentation
  88. hyp=hyp, # hyperparameters
  89. rect=rect, # rectangular batches
  90. cache_images=cache,
  91. single_cls=single_cls,
  92. stride=int(stride),
  93. pad=pad,
  94. image_weights=image_weights,
  95. prefix=prefix)
  96. batch_size = min(batch_size, len(dataset))
  97. nw = min([os.cpu_count() // WORLD_SIZE, batch_size if batch_size > 1 else 0, workers]) # number of workers
  98. sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
  99. loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
  100. return loader(dataset,
  101. batch_size=batch_size,
  102. shuffle=shuffle and sampler is None,
  103. num_workers=nw,
  104. sampler=sampler,
  105. pin_memory=True,
  106. collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
  107. class InfiniteDataLoader(dataloader.DataLoader):
  108. """ Dataloader that reuses workers
  109. Uses same syntax as vanilla DataLoader
  110. """
  111. def __init__(self, *args, **kwargs):
  112. super().__init__(*args, **kwargs)
  113. object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
  114. self.iterator = super().__iter__()
  115. def __len__(self):
  116. return len(self.batch_sampler.sampler)
  117. def __iter__(self):
  118. for i in range(len(self)):
  119. yield next(self.iterator)
  120. class _RepeatSampler:
  121. """ Sampler that repeats forever
  122. Args:
  123. sampler (Sampler)
  124. """
  125. def __init__(self, sampler):
  126. self.sampler = sampler
  127. def __iter__(self):
  128. while True:
  129. yield from iter(self.sampler)
class LoadImages:
    """YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`.

    Iterates over the supported images and videos found at `path`: all images first (one item
    each), then every video frame by frame. Each item is (path, img, img0, cap, s):
      - img0: the image/frame as read by OpenCV (BGR)
      - img: img0 letterboxed to img_size/stride, HWC->CHW, BGR->RGB, contiguous
      - cap: self.cap, the current cv2.VideoCapture (None when `path` contains no videos)
      - s: a log prefix string
    """

    def __init__(self, path, img_size=640, stride=32, auto=True):
        """Collect files from a glob pattern (contains '*'), a directory ('*.*', non-recursive) or a file.

        Raises Exception if `path` does not exist, and AssertionError if no supported
        image/video file is found. `auto` is forwarded to letterbox().
        """
        p = str(Path(path).resolve())  # os-agnostic absolute path
        if '*' in p:
            files = sorted(glob.glob(p, recursive=True))  # glob
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
        elif os.path.isfile(p):
            files = [p]  # files
        else:
            raise Exception(f'ERROR: {p} does not exist')

        # split by (case-insensitive) suffix; files with other suffixes are ignored
        images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
        videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos  # images are served before videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        self.auto = auto
        if any(videos):
            self.new_video(videos[0])  # new video (opened up front so self.cap is ready)
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' \
                            f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'

    def __iter__(self):
        # restart at the first file (note: self.cap is not re-opened here)
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            # A failed read means the current video is exhausted: release it and open the next
            # file. For videos, self.count only advances here, so one video yields many items.
            while not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()

            self.frame += 1
            s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '

        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, f'Image Not Found {path}'
            s = f'image {self.count}/{self.nf} {path}: '

        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]

        # Convert
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)

        return path, img, img0, self.cap, s

    def new_video(self, path):
        """Open video `path` with OpenCV, reset the frame counter and read its frame count."""
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nf  # number of files
  198. class LoadWebcam: # for inference
  199. # YOLOv5 local webcam dataloader, i.e. `python detect.py --source 0`
  200. def __init__(self, pipe='0', img_size=640, stride=32):
  201. self.img_size = img_size
  202. self.stride = stride
  203. self.pipe = eval(pipe) if pipe.isnumeric() else pipe
  204. self.cap = cv2.VideoCapture(self.pipe) # video capture object
  205. self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3) # set buffer size
  206. def __iter__(self):
  207. self.count = -1
  208. return self
  209. def __next__(self):
  210. self.count += 1
  211. if cv2.waitKey(1) == ord('q'): # q to quit
  212. self.cap.release()
  213. cv2.destroyAllWindows()
  214. raise StopIteration
  215. # Read frame
  216. ret_val, img0 = self.cap.read()
  217. img0 = cv2.flip(img0, 1) # flip left-right
  218. # Print
  219. assert ret_val, f'Camera Error {self.pipe}'
  220. img_path = 'webcam.jpg'
  221. s = f'webcam {self.count}: '
  222. # Padded resize
  223. img = letterbox(img0, self.img_size, stride=self.stride)[0]
  224. # Convert
  225. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  226. img = np.ascontiguousarray(img)
  227. return img_path, img, img0, None, s
  228. def __len__(self):
  229. return 0
  230. class LoadStreams:
  231. # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
  232. def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
  233. self.mode = 'stream'
  234. self.img_size = img_size
  235. self.stride = stride
  236. if os.path.isfile(sources):
  237. with open(sources) as f:
  238. sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
  239. else:
  240. sources = [sources]
  241. n = len(sources)
  242. self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
  243. self.sources = [clean_str(x) for x in sources] # clean source names for later
  244. self.auto = auto
  245. for i, s in enumerate(sources): # index, source
  246. # Start thread to read frames from video stream
  247. st = f'{i + 1}/{n}: {s}... '
  248. if 'youtube.com/' in s or 'youtu.be/' in s: # if source is YouTube video
  249. check_requirements(('pafy', 'youtube_dl'))
  250. import pafy
  251. s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
  252. s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
  253. cap = cv2.VideoCapture(s)
  254. assert cap.isOpened(), f'{st}Failed to open {s}'
  255. w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  256. h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  257. fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
  258. self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
  259. self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
  260. _, self.imgs[i] = cap.read() # guarantee first frame
  261. self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
  262. LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
  263. self.threads[i].start()
  264. LOGGER.info('') # newline
  265. # check for common shapes
  266. s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs])
  267. self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
  268. if not self.rect:
  269. LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')
  270. def update(self, i, cap, stream):
  271. # Read stream `i` frames in daemon thread
  272. n, f, read = 0, self.frames[i], 1 # frame number, frame array, inference every 'read' frame
  273. while cap.isOpened() and n < f:
  274. n += 1
  275. # _, self.imgs[index] = cap.read()
  276. cap.grab()
  277. if n % read == 0:
  278. success, im = cap.retrieve()
  279. if success:
  280. self.imgs[i] = im
  281. else:
  282. LOGGER.warning('WARNING: Video stream unresponsive, please check your IP camera connection.')
  283. self.imgs[i] = np.zeros_like(self.imgs[i])
  284. cap.open(stream) # re-open stream if signal was lost
  285. time.sleep(1 / self.fps[i]) # wait time
  286. def __iter__(self):
  287. self.count = -1
  288. return self
  289. def __next__(self):
  290. self.count += 1
  291. if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
  292. cv2.destroyAllWindows()
  293. raise StopIteration
  294. # Letterbox
  295. img0 = self.imgs.copy()
  296. img = [letterbox(x, self.img_size, stride=self.stride, auto=self.rect and self.auto)[0] for x in img0]
  297. # Stack
  298. img = np.stack(img, 0)
  299. # Convert
  300. img = img[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
  301. img = np.ascontiguousarray(img)
  302. return self.sources, img, img0, None, ''
  303. def __len__(self):
  304. return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
  305. def img2label_paths(img_paths):
  306. # Define label paths as a function of image paths
  307. sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep # /images/, /labels/ substrings
  308. return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
  309. class LoadImagesAndLabels(Dataset):
  310. # YOLOv5 train_loader/val_loader, loads images and labels for training and validation
  311. cache_version = 0.6 # dataset labels *.cache version
  312. def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
  313. cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
  314. self.img_size = img_size
  315. self.augment = augment
  316. self.hyp = hyp
  317. self.image_weights = image_weights
  318. self.rect = False if image_weights else rect
  319. self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
  320. self.mosaic_border = [-img_size // 2, -img_size // 2]
  321. self.stride = stride
  322. self.path = path
  323. self.albumentations = Albumentations() if augment else None
  324. try:
  325. f = [] # image files
  326. for p in path if isinstance(path, list) else [path]:
  327. p = Path(p) # os-agnostic
  328. if p.is_dir(): # dir
  329. f += glob.glob(str(p / '**' / '*.*'), recursive=True)
  330. # f = list(p.rglob('*.*')) # pathlib
  331. elif p.is_file(): # file
  332. with open(p) as t:
  333. t = t.read().strip().splitlines()
  334. parent = str(p.parent) + os.sep
  335. f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
  336. # f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
  337. else:
  338. raise Exception(f'{prefix}{p} does not exist')
  339. self.img_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
  340. # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
  341. assert self.img_files, f'{prefix}No images found'
  342. except Exception as e:
  343. raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}')
  344. # Check cache
  345. self.label_files = img2label_paths(self.img_files) # labels
  346. cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
  347. try:
  348. cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
  349. assert cache['version'] == self.cache_version # same version
  350. assert cache['hash'] == get_hash(self.label_files + self.img_files) # same hash
  351. except:
  352. cache, exists = self.cache_labels(cache_path, prefix), False # cache
  353. # Display cache
  354. nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupted, total
  355. if exists:
  356. d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
  357. tqdm(None, desc=prefix + d, total=n, initial=n) # display cache results
  358. if cache['msgs']:
  359. LOGGER.info('\n'.join(cache['msgs'])) # display warnings
  360. assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {HELP_URL}'
  361. # Read cache
  362. [cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
  363. labels, shapes, self.segments = zip(*cache.values())
  364. self.labels = list(labels)
  365. self.shapes = np.array(shapes, dtype=np.float64)
  366. self.img_files = list(cache.keys()) # update
  367. self.label_files = img2label_paths(cache.keys()) # update
  368. n = len(shapes) # number of images
  369. bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
  370. nb = bi[-1] + 1 # number of batches
  371. self.batch = bi # batch index of image
  372. self.n = n
  373. self.indices = range(n)
  374. # Update labels
  375. include_class = [] # filter labels to include only these classes (optional)
  376. include_class_array = np.array(include_class).reshape(1, -1)
  377. for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
  378. if include_class:
  379. j = (label[:, 0:1] == include_class_array).any(1)
  380. self.labels[i] = label[j]
  381. if segment:
  382. self.segments[i] = segment[j]
  383. if single_cls: # single-class training, merge all classes into 0
  384. self.labels[i][:, 0] = 0
  385. if segment:
  386. self.segments[i][:, 0] = 0
  387. # Rectangular Training
  388. if self.rect:
  389. # Sort by aspect ratio
  390. s = self.shapes # wh
  391. ar = s[:, 1] / s[:, 0] # aspect ratio
  392. irect = ar.argsort()
  393. self.img_files = [self.img_files[i] for i in irect]
  394. self.label_files = [self.label_files[i] for i in irect]
  395. self.labels = [self.labels[i] for i in irect]
  396. self.shapes = s[irect] # wh
  397. ar = ar[irect]
  398. # Set training image shapes
  399. shapes = [[1, 1]] * nb
  400. for i in range(nb):
  401. ari = ar[bi == i]
  402. mini, maxi = ari.min(), ari.max()
  403. if maxi < 1:
  404. shapes[i] = [maxi, 1]
  405. elif mini > 1:
  406. shapes[i] = [1, 1 / mini]
  407. self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
  408. # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
  409. self.imgs, self.img_npy = [None] * n, [None] * n
  410. if cache_images:
  411. if cache_images == 'disk':
  412. self.im_cache_dir = Path(Path(self.img_files[0]).parent.as_posix() + '_npy')
  413. self.img_npy = [self.im_cache_dir / Path(f).with_suffix('.npy').name for f in self.img_files]
  414. self.im_cache_dir.mkdir(parents=True, exist_ok=True)
  415. gb = 0 # Gigabytes of cached images
  416. self.img_hw0, self.img_hw = [None] * n, [None] * n
  417. results = ThreadPool(NUM_THREADS).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
  418. pbar = tqdm(enumerate(results), total=n)
  419. for i, x in pbar:
  420. if cache_images == 'disk':
  421. if not self.img_npy[i].exists():
  422. np.save(self.img_npy[i].as_posix(), x[0])
  423. gb += self.img_npy[i].stat().st_size
  424. else:
  425. self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
  426. gb += self.imgs[i].nbytes
  427. pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB {cache_images})'
  428. pbar.close()
  429. def cache_labels(self, path=Path('./labels.cache'), prefix=''):
  430. # Cache dataset labels, check images and read shapes
  431. x = {} # dict
  432. nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
  433. desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
  434. with Pool(NUM_THREADS) as pool:
  435. pbar = tqdm(pool.imap(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))),
  436. desc=desc, total=len(self.img_files))
  437. for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
  438. nm += nm_f
  439. nf += nf_f
  440. ne += ne_f
  441. nc += nc_f
  442. if im_file:
  443. x[im_file] = [l, shape, segments]
  444. if msg:
  445. msgs.append(msg)
  446. pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
  447. pbar.close()
  448. if msgs:
  449. LOGGER.info('\n'.join(msgs))
  450. if nf == 0:
  451. LOGGER.warning(f'{prefix}WARNING: No labels found in {path}. See {HELP_URL}')
  452. x['hash'] = get_hash(self.label_files + self.img_files)
  453. x['results'] = nf, nm, ne, nc, len(self.img_files)
  454. x['msgs'] = msgs # warnings
  455. x['version'] = self.cache_version # cache version
  456. try:
  457. np.save(path, x) # save cache for next time
  458. path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
  459. LOGGER.info(f'{prefix}New cache created: {path}')
  460. except Exception as e:
  461. LOGGER.warning(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}') # not writeable
  462. return x
  463. def __len__(self):
  464. return len(self.img_files)
  465. # def __iter__(self):
  466. # self.count = -1
  467. # print('ran dataset iter')
  468. # #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
  469. # return self
  470. def __getitem__(self, index):
  471. index = self.indices[index] # linear, shuffled, or image_weights
  472. hyp = self.hyp
  473. mosaic = self.mosaic and random.random() < hyp['mosaic']
  474. if mosaic:
  475. # Load mosaic
  476. img, labels = load_mosaic(self, index)
  477. shapes = None
  478. # MixUp augmentation
  479. if random.random() < hyp['mixup']:
  480. img, labels = mixup(img, labels, *load_mosaic(self, random.randint(0, self.n - 1)))
  481. else:
  482. # Load image
  483. img, (h0, w0), (h, w) = load_image(self, index)
  484. # Letterbox
  485. shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
  486. img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
  487. shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
  488. labels = self.labels[index].copy()
  489. if labels.size: # normalized xywh to pixel xyxy format
  490. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
  491. if self.augment:
  492. img, labels = random_perspective(img, labels,
  493. degrees=hyp['degrees'],
  494. translate=hyp['translate'],
  495. scale=hyp['scale'],
  496. shear=hyp['shear'],
  497. perspective=hyp['perspective'])
  498. nl = len(labels) # number of labels
  499. if nl:
  500. labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)
  501. if self.augment:
  502. # Albumentations
  503. img, labels = self.albumentations(img, labels)
  504. nl = len(labels) # update after albumentations
  505. # HSV color-space
  506. augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
  507. # Flip up-down
  508. if random.random() < hyp['flipud']:
  509. img = np.flipud(img)
  510. if nl:
  511. labels[:, 2] = 1 - labels[:, 2]
  512. # Flip left-right
  513. if random.random() < hyp['fliplr']:
  514. img = np.fliplr(img)
  515. if nl:
  516. labels[:, 1] = 1 - labels[:, 1]
  517. # Cutouts
  518. # labels = cutout(img, labels, p=0.5)
  519. # nl = len(labels) # update after cutout
  520. labels_out = torch.zeros((nl, 6))
  521. if nl:
  522. labels_out[:, 1:] = torch.from_numpy(labels)
  523. # Convert
  524. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  525. img = np.ascontiguousarray(img)
  526. return torch.from_numpy(img), labels_out, self.img_files[index], shapes
  527. @staticmethod
  528. def collate_fn(batch):
  529. img, label, path, shapes = zip(*batch) # transposed
  530. for i, l in enumerate(label):
  531. l[:, 0] = i # add target image index for build_targets()
  532. return torch.stack(img, 0), torch.cat(label, 0), path, shapes
  533. @staticmethod
  534. def collate_fn4(batch):
  535. img, label, path, shapes = zip(*batch) # transposed
  536. n = len(shapes) // 4
  537. img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
  538. ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
  539. wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
  540. s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]]) # scale
  541. for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
  542. i *= 4
  543. if random.random() < 0.5:
  544. im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear', align_corners=False)[
  545. 0].type(img[i].type())
  546. l = label[i]
  547. else:
  548. im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
  549. l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
  550. img4.append(im)
  551. label4.append(l)
  552. for i, l in enumerate(label4):
  553. l[:, 0] = i # add target image index for build_targets()
  554. return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4
  555. # Ancillary functions --------------------------------------------------------------------------------------------------
  556. def load_image(self, i):
  557. # loads 1 image from dataset index 'i', returns im, original hw, resized hw
  558. im = self.imgs[i]
  559. if im is None: # not cached in ram
  560. npy = self.img_npy[i]
  561. if npy and npy.exists(): # load npy
  562. im = np.load(npy)
  563. else: # read image
  564. path = self.img_files[i]
  565. im = cv2.imread(path) # BGR
  566. assert im is not None, f'Image Not Found {path}'
  567. h0, w0 = im.shape[:2] # orig hw
  568. r = self.img_size / max(h0, w0) # ratio
  569. if r != 1: # if sizes are not equal
  570. im = cv2.resize(im, (int(w0 * r), int(h0 * r)),
  571. interpolation=cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR)
  572. return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
  573. else:
  574. return self.imgs[i], self.img_hw0[i], self.img_hw[i] # im, hw_original, hw_resized
def load_mosaic(self, index):
    """YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic.

    Places the four (shuffled) images around a random center on a 2s x 2s canvas filled with 114
    (s = img_size), shifts their labels/segments into canvas pixel coordinates, and clips them to the
    canvas. It then applies copy_paste and random_perspective (the latter also removes
    mosaic_border). Returns (img4, labels4).
    """
    labels4, segments4 = [], []
    s = self.img_size
    # mosaic center y, x; drawn from [-b, 2s + b] per border value b (b = -s//2 by default)
    yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border)  # mosaic center x, y
    indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
    random.shuffle(indices)
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4: (x1a..y2a) is the target region on the canvas, (x1b..y2b) the source crop
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        # offset from this image's pixel coordinates to canvas coordinates
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
        labels4.append(labels)
        segments4.extend(segments)

    # Concat/clip labels
    labels4 = np.concatenate(labels4, 0)
    for x in (labels4[:, 1:], *segments4):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()  (in place; labels4[:, 1:] is a view)
    # img4, labels4 = replicate(img4, labels4)  # replicate

    # Augment
    img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
    img4, labels4 = random_perspective(img4, labels4, segments4,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img4, labels4
def load_mosaic9(self, index):
    """YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic.

    - Tiles the nine (shuffled) images on a 3s x 3s canvas filled with 114 (s = img_size): the
      first image goes in the center and the rest clockwise from the top.
    - Crops a random 2s x 2s window and shifts the labels/segments into that window, clipping them.
    - Applies random_perspective, which also removes mosaic_border.
    Returns (img9, labels9).
    """
    labels9, segments9 = [], []
    s = self.img_size
    indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
    random.shuffle(indices)
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img9; c = tile (xmin, ymin, xmax, ymax) on the canvas, built from the center
        # tile's size (h0, w0) and the previous tile's size (hp, wp, set at the end of each iteration)
        if i == 0:  # center
            img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 9 tiles
            h0, w0 = h, w
            c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
        elif i == 1:  # top
            c = s, s - h, s + w, s
        elif i == 2:  # top right
            c = s + wp, s - h, s + wp + w, s
        elif i == 3:  # right
            c = s + w0, s, s + w0 + w, s + h
        elif i == 4:  # bottom right
            c = s + w0, s + hp, s + w0 + w, s + hp + h
        elif i == 5:  # bottom
            c = s + w0 - w, s + h0, s + w0, s + h0 + h
        elif i == 6:  # bottom left
            c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
        elif i == 7:  # left
            c = s - w, s + h0 - h, s, s + h0
        elif i == 8:  # top left
            c = s - w, s + h0 - hp - h, s, s + h0 - hp

        padx, pady = c[:2]
        x1, y1, x2, y2 = (max(x, 0) for x in c)  # allocate coords (clamped to the canvas' top/left edge)

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
        labels9.append(labels)
        segments9.extend(segments)

        # Image (source offset skips any part of the tile that fell above/left of the canvas)
        img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
        hp, wp = h, w  # height, width previous

    # Offset: top-left (xc, yc) of a random 2s x 2s crop of the 3s x 3s canvas
    yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border)  # mosaic center x, y
    img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

    # Concat/clip labels (shift into the cropped window's coordinates)
    labels9 = np.concatenate(labels9, 0)
    labels9[:, [1, 3]] -= xc
    labels9[:, [2, 4]] -= yc
    c = np.array([xc, yc])  # centers
    segments9 = [x - c for x in segments9]

    for x in (labels9[:, 1:], *segments9):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img9, labels9 = replicate(img9, labels9)  # replicate

    # Augment
    img9, labels9 = random_perspective(img9, labels9, segments9,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img9, labels9
  687. def create_folder(path='./new'):
  688. # Create folder
  689. if os.path.exists(path):
  690. shutil.rmtree(path) # delete output folder
  691. os.makedirs(path) # make new output folder
  692. def flatten_recursive(path='../datasets/coco128'):
  693. # Flatten a recursive directory by bringing all files to top level
  694. new_path = Path(path + '_flat')
  695. create_folder(new_path)
  696. for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
  697. shutil.copyfile(file, new_path / Path(file).name)
  698. def extract_boxes(path='../datasets/coco128'): # from utils.datasets import *; extract_boxes()
  699. # Convert detection dataset into classification dataset, with one directory per class
  700. path = Path(path) # images dir
  701. shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None # remove existing
  702. files = list(path.rglob('*.*'))
  703. n = len(files) # number of files
  704. for im_file in tqdm(files, total=n):
  705. if im_file.suffix[1:] in IMG_FORMATS:
  706. # image
  707. im = cv2.imread(str(im_file))[..., ::-1] # BGR to RGB
  708. h, w = im.shape[:2]
  709. # labels
  710. lb_file = Path(img2label_paths([str(im_file)])[0])
  711. if Path(lb_file).exists():
  712. with open(lb_file) as f:
  713. lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32) # labels
  714. for j, x in enumerate(lb):
  715. c = int(x[0]) # class
  716. f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg' # new filename
  717. if not f.parent.is_dir():
  718. f.parent.mkdir(parents=True)
  719. b = x[1:] * [w, h, w, h] # box
  720. # b[2:] = b[2:].max() # rectangle to square
  721. b[2:] = b[2:] * 1.2 + 3 # pad
  722. b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(np.int)
  723. b[[0, 2]] = np.clip(b[[0, 2]], 0, w) # clip boxes outside of image
  724. b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
  725. assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'
  726. def autosplit(path='../datasets/coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
  727. """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
  728. Usage: from utils.datasets import *; autosplit()
  729. Arguments
  730. path: Path to images directory
  731. weights: Train, val, test weights (list, tuple)
  732. annotated_only: Only use images with an annotated txt file
  733. """
  734. path = Path(path) # images dir
  735. files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS) # image files only
  736. n = len(files) # number of files
  737. random.seed(0) # for reproducibility
  738. indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
  739. txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
  740. [(path.parent / x).unlink(missing_ok=True) for x in txt] # remove existing
  741. print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
  742. for i, img in tqdm(zip(indices, files), total=n):
  743. if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
  744. with open(path.parent / txt[i], 'a') as f:
  745. f.write('./' + img.relative_to(path.parent).as_posix() + '\n') # add image to txt file
  746. def verify_image_label(args):
  747. # Verify one image-label pair
  748. im_file, lb_file, prefix = args
  749. nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, '', [] # number (missing, found, empty, corrupt), message, segments
  750. try:
  751. # verify images
  752. im = Image.open(im_file)
  753. im.verify() # PIL verify
  754. shape = exif_size(im) # image size
  755. assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
  756. assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
  757. if im.format.lower() in ('jpg', 'jpeg'):
  758. with open(im_file, 'rb') as f:
  759. f.seek(-2, 2)
  760. if f.read() != b'\xff\xd9': # corrupt JPEG
  761. ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
  762. msg = f'{prefix}WARNING: {im_file}: corrupt JPEG restored and saved'
  763. # verify labels
  764. if os.path.isfile(lb_file):
  765. nf = 1 # label found
  766. with open(lb_file) as f:
  767. l = [x.split() for x in f.read().strip().splitlines() if len(x)]
  768. if any([len(x) > 8 for x in l]): # is segment
  769. classes = np.array([x[0] for x in l], dtype=np.float32)
  770. segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l] # (cls, xy1...)
  771. l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
  772. l = np.array(l, dtype=np.float32)
  773. nl = len(l)
  774. if nl:
  775. assert l.shape[1] == 5, f'labels require 5 columns, {l.shape[1]} columns detected'
  776. assert (l >= 0).all(), f'negative label values {l[l < 0]}'
  777. assert (l[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {l[:, 1:][l[:, 1:] > 1]}'
  778. _, i = np.unique(l, axis=0, return_index=True)
  779. if len(i) < nl: # duplicate row check
  780. l = l[i] # remove duplicates
  781. if segments:
  782. segments = segments[i]
  783. msg = f'{prefix}WARNING: {im_file}: {nl - len(i)} duplicate labels removed'
  784. else:
  785. ne = 1 # label empty
  786. l = np.zeros((0, 5), dtype=np.float32)
  787. else:
  788. nm = 1 # label missing
  789. l = np.zeros((0, 5), dtype=np.float32)
  790. return im_file, l, shape, segments, nm, nf, ne, nc, msg
  791. except Exception as e:
  792. nc = 1
  793. msg = f'{prefix}WARNING: {im_file}: ignoring corrupt image/label: {e}'
  794. return [None, None, None, None, nm, nf, ne, nc, msg]
  795. def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False, profile=False, hub=False):
  796. """ Return dataset statistics dictionary with images and instances counts per split per class
  797. To run in parent directory: export PYTHONPATH="$PWD/yolov5"
  798. Usage1: from utils.datasets import *; dataset_stats('coco128.yaml', autodownload=True)
  799. Usage2: from utils.datasets import *; dataset_stats('../datasets/coco128_with_yaml.zip')
  800. Arguments
  801. path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
  802. autodownload: Attempt to download dataset if not found locally
  803. verbose: Print stats dictionary
  804. """
  805. def round_labels(labels):
  806. # Update labels to integer class and 6 decimal place floats
  807. return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]
  808. def unzip(path):
  809. # Unzip data.zip TODO: CONSTRAINT: path/to/abc.zip MUST unzip to 'path/to/abc/'
  810. if str(path).endswith('.zip'): # path is data.zip
  811. assert Path(path).is_file(), f'Error unzipping {path}, file not found'
  812. ZipFile(path).extractall(path=path.parent) # unzip
  813. dir = path.with_suffix('') # dataset directory == zip name
  814. return True, str(dir), next(dir.rglob('*.yaml')) # zipped, data_dir, yaml_path
  815. else: # path is data.yaml
  816. return False, None, path
  817. def hub_ops(f, max_dim=1920):
  818. # HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
  819. f_new = im_dir / Path(f).name # dataset-hub image filename
  820. try: # use PIL
  821. im = Image.open(f)
  822. r = max_dim / max(im.height, im.width) # ratio
  823. if r < 1.0: # image too large
  824. im = im.resize((int(im.width * r), int(im.height * r)))
  825. im.save(f_new, 'JPEG', quality=75, optimize=True) # save
  826. except Exception as e: # use OpenCV
  827. print(f'WARNING: HUB ops PIL failure {f}: {e}')
  828. im = cv2.imread(f)
  829. im_height, im_width = im.shape[:2]
  830. r = max_dim / max(im_height, im_width) # ratio
  831. if r < 1.0: # image too large
  832. im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
  833. cv2.imwrite(str(f_new), im)
  834. zipped, data_dir, yaml_path = unzip(Path(path))
  835. with open(check_yaml(yaml_path), errors='ignore') as f:
  836. data = yaml.safe_load(f) # data dict
  837. if zipped:
  838. data['path'] = data_dir # TODO: should this be dir.resolve()?
  839. check_dataset(data, autodownload) # download dataset if missing
  840. hub_dir = Path(data['path'] + ('-hub' if hub else ''))
  841. stats = {'nc': data['nc'], 'names': data['names']} # statistics dictionary
  842. for split in 'train', 'val', 'test':
  843. if data.get(split) is None:
  844. stats[split] = None # i.e. no test set
  845. continue
  846. x = []
  847. dataset = LoadImagesAndLabels(data[split]) # load dataset
  848. for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
  849. x.append(np.bincount(label[:, 0].astype(int), minlength=data['nc']))
  850. x = np.array(x) # shape(128x80)
  851. stats[split] = {'instance_stats': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
  852. 'image_stats': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
  853. 'per_class': (x > 0).sum(0).tolist()},
  854. 'labels': [{str(Path(k).name): round_labels(v.tolist())} for k, v in
  855. zip(dataset.img_files, dataset.labels)]}
  856. if hub:
  857. im_dir = hub_dir / 'images'
  858. im_dir.mkdir(parents=True, exist_ok=True)
  859. for _ in tqdm(ThreadPool(NUM_THREADS).imap(hub_ops, dataset.img_files), total=dataset.n, desc='HUB Ops'):
  860. pass
  861. # Profile
  862. stats_path = hub_dir / 'stats.json'
  863. if profile:
  864. for _ in range(1):
  865. file = stats_path.with_suffix('.npy')
  866. t1 = time.time()
  867. np.save(file, stats)
  868. t2 = time.time()
  869. x = np.load(file, allow_pickle=True)
  870. print(f'stats.npy times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')
  871. file = stats_path.with_suffix('.json')
  872. t1 = time.time()
  873. with open(file, 'w') as f:
  874. json.dump(stats, f) # save stats *.json
  875. t2 = time.time()
  876. with open(file) as f:
  877. x = json.load(f) # load hyps dict
  878. print(f'stats.json times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')
  879. # Save, print and return
  880. if hub:
  881. print(f'Saving {stats_path.resolve()}...')
  882. with open(stats_path, 'w') as f:
  883. json.dump(stats, f) # save stats.json
  884. if verbose:
  885. print(json.dumps(stats, indent=2, sort_keys=False))
  886. return stats