You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1077 lines
46KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Dataloaders and dataset utils
  4. """
  5. import glob
  6. import hashlib
  7. import json
  8. import math
  9. import os
  10. import random
  11. import shutil
  12. import time
  13. from itertools import repeat
  14. from multiprocessing.pool import Pool, ThreadPool
  15. from pathlib import Path
  16. from threading import Thread
  17. from urllib.parse import urlparse
  18. from zipfile import ZipFile
  19. import numpy as np
  20. import torch
  21. import torch.nn.functional as F
  22. import yaml
  23. from PIL import ExifTags, Image, ImageOps
  24. from torch.utils.data import DataLoader, Dataset, dataloader, distributed
  25. from tqdm import tqdm
  26. from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
  27. from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
  28. cv2, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
  29. from utils.torch_utils import torch_distributed_zero_first
  30. # Parameters
  31. HELP_URL = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
  32. IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp' # include image suffixes
  33. VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
  34. BAR_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}' # tqdm bar format
  35. LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
  36. # Get orientation exif tag
  37. for orientation in ExifTags.TAGS.keys():
  38. if ExifTags.TAGS[orientation] == 'Orientation':
  39. break
  40. def get_hash(paths):
  41. # Returns a single hash value of a list of paths (files or dirs)
  42. size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
  43. h = hashlib.md5(str(size).encode()) # hash sizes
  44. h.update(''.join(paths).encode()) # hash paths
  45. return h.hexdigest() # return hash
  46. def exif_size(img):
  47. # Returns exif-corrected PIL size
  48. s = img.size # (width, height)
  49. try:
  50. rotation = dict(img._getexif().items())[orientation]
  51. if rotation in [6, 8]: # rotation 270 or 90
  52. s = (s[1], s[0])
  53. except Exception:
  54. pass
  55. return s
  56. def exif_transpose(image):
  57. """
  58. Transpose a PIL image accordingly if it has an EXIF Orientation tag.
  59. Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()
  60. :param image: The image to transpose.
  61. :return: An image.
  62. """
  63. exif = image.getexif()
  64. orientation = exif.get(0x0112, 1) # default 1
  65. if orientation > 1:
  66. method = {
  67. 2: Image.FLIP_LEFT_RIGHT,
  68. 3: Image.ROTATE_180,
  69. 4: Image.FLIP_TOP_BOTTOM,
  70. 5: Image.TRANSPOSE,
  71. 6: Image.ROTATE_270,
  72. 7: Image.TRANSVERSE,
  73. 8: Image.ROTATE_90,}.get(orientation)
  74. if method is not None:
  75. image = image.transpose(method)
  76. del exif[0x0112]
  77. image.info["exif"] = exif.tobytes()
  78. return image
  79. def create_dataloader(path,
  80. imgsz,
  81. batch_size,
  82. stride,
  83. single_cls=False,
  84. hyp=None,
  85. augment=False,
  86. cache=False,
  87. pad=0.0,
  88. rect=False,
  89. rank=-1,
  90. workers=8,
  91. image_weights=False,
  92. quad=False,
  93. prefix='',
  94. shuffle=False):
  95. if rect and shuffle:
  96. LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
  97. shuffle = False
  98. with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
  99. dataset = LoadImagesAndLabels(
  100. path,
  101. imgsz,
  102. batch_size,
  103. augment=augment, # augmentation
  104. hyp=hyp, # hyperparameters
  105. rect=rect, # rectangular batches
  106. cache_images=cache,
  107. single_cls=single_cls,
  108. stride=int(stride),
  109. pad=pad,
  110. image_weights=image_weights,
  111. prefix=prefix)
  112. batch_size = min(batch_size, len(dataset))
  113. nd = torch.cuda.device_count() # number of CUDA devices
  114. nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
  115. sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
  116. loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
  117. return loader(dataset,
  118. batch_size=batch_size,
  119. shuffle=shuffle and sampler is None,
  120. num_workers=nw,
  121. sampler=sampler,
  122. pin_memory=True,
  123. collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
  124. class InfiniteDataLoader(dataloader.DataLoader):
  125. """ Dataloader that reuses workers
  126. Uses same syntax as vanilla DataLoader
  127. """
  128. def __init__(self, *args, **kwargs):
  129. super().__init__(*args, **kwargs)
  130. object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
  131. self.iterator = super().__iter__()
  132. def __len__(self):
  133. return len(self.batch_sampler.sampler)
  134. def __iter__(self):
  135. for _ in range(len(self)):
  136. yield next(self.iterator)
  137. class _RepeatSampler:
  138. """ Sampler that repeats forever
  139. Args:
  140. sampler (Sampler)
  141. """
  142. def __init__(self, sampler):
  143. self.sampler = sampler
  144. def __iter__(self):
  145. while True:
  146. yield from iter(self.sampler)
  147. class LoadImages:
  148. # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
  149. def __init__(self, path, img_size=640, stride=32, auto=True):
  150. p = str(Path(path).resolve()) # os-agnostic absolute path
  151. if '*' in p:
  152. files = sorted(glob.glob(p, recursive=True)) # glob
  153. elif os.path.isdir(p):
  154. files = sorted(glob.glob(os.path.join(p, '*.*'))) # dir
  155. elif os.path.isfile(p):
  156. files = [p] # files
  157. else:
  158. raise Exception(f'ERROR: {p} does not exist')
  159. images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
  160. videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
  161. ni, nv = len(images), len(videos)
  162. self.img_size = img_size
  163. self.stride = stride
  164. self.files = images + videos
  165. self.nf = ni + nv # number of files
  166. self.video_flag = [False] * ni + [True] * nv
  167. self.mode = 'image'
  168. self.auto = auto
  169. if any(videos):
  170. self.new_video(videos[0]) # new video
  171. else:
  172. self.cap = None
  173. assert self.nf > 0, f'No images or videos found in {p}. ' \
  174. f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
  175. def __iter__(self):
  176. self.count = 0
  177. return self
  178. def __next__(self):
  179. if self.count == self.nf:
  180. raise StopIteration
  181. path = self.files[self.count]
  182. if self.video_flag[self.count]:
  183. # Read video
  184. self.mode = 'video'
  185. ret_val, img0 = self.cap.read()
  186. while not ret_val:
  187. self.count += 1
  188. self.cap.release()
  189. if self.count == self.nf: # last video
  190. raise StopIteration
  191. path = self.files[self.count]
  192. self.new_video(path)
  193. ret_val, img0 = self.cap.read()
  194. self.frame += 1
  195. s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
  196. else:
  197. # Read image
  198. self.count += 1
  199. img0 = cv2.imread(path) # BGR
  200. assert img0 is not None, f'Image Not Found {path}'
  201. s = f'image {self.count}/{self.nf} {path}: '
  202. # Padded resize
  203. img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]
  204. # Convert
  205. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  206. img = np.ascontiguousarray(img)
  207. return path, img, img0, self.cap, s
  208. def new_video(self, path):
  209. self.frame = 0
  210. self.cap = cv2.VideoCapture(path)
  211. self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
  212. def __len__(self):
  213. return self.nf # number of files
  214. class LoadWebcam: # for inference
  215. # YOLOv5 local webcam dataloader, i.e. `python detect.py --source 0`
  216. def __init__(self, pipe='0', img_size=640, stride=32):
  217. self.img_size = img_size
  218. self.stride = stride
  219. self.pipe = eval(pipe) if pipe.isnumeric() else pipe
  220. self.cap = cv2.VideoCapture(self.pipe) # video capture object
  221. self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3) # set buffer size
  222. def __iter__(self):
  223. self.count = -1
  224. return self
  225. def __next__(self):
  226. self.count += 1
  227. if cv2.waitKey(1) == ord('q'): # q to quit
  228. self.cap.release()
  229. cv2.destroyAllWindows()
  230. raise StopIteration
  231. # Read frame
  232. ret_val, img0 = self.cap.read()
  233. img0 = cv2.flip(img0, 1) # flip left-right
  234. # Print
  235. assert ret_val, f'Camera Error {self.pipe}'
  236. img_path = 'webcam.jpg'
  237. s = f'webcam {self.count}: '
  238. # Padded resize
  239. img = letterbox(img0, self.img_size, stride=self.stride)[0]
  240. # Convert
  241. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  242. img = np.ascontiguousarray(img)
  243. return img_path, img, img0, None, s
  244. def __len__(self):
  245. return 0
  246. class LoadStreams:
  247. # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
  248. def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
  249. self.mode = 'stream'
  250. self.img_size = img_size
  251. self.stride = stride
  252. if os.path.isfile(sources):
  253. with open(sources) as f:
  254. sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
  255. else:
  256. sources = [sources]
  257. n = len(sources)
  258. self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
  259. self.sources = [clean_str(x) for x in sources] # clean source names for later
  260. self.auto = auto
  261. for i, s in enumerate(sources): # index, source
  262. # Start thread to read frames from video stream
  263. st = f'{i + 1}/{n}: {s}... '
  264. if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video
  265. check_requirements(('pafy', 'youtube_dl==2020.12.2'))
  266. import pafy
  267. s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
  268. s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
  269. cap = cv2.VideoCapture(s)
  270. assert cap.isOpened(), f'{st}Failed to open {s}'
  271. w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  272. h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  273. fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
  274. self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
  275. self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
  276. _, self.imgs[i] = cap.read() # guarantee first frame
  277. self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
  278. LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
  279. self.threads[i].start()
  280. LOGGER.info('') # newline
  281. # check for common shapes
  282. s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs])
  283. self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
  284. if not self.rect:
  285. LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')
  286. def update(self, i, cap, stream):
  287. # Read stream `i` frames in daemon thread
  288. n, f, read = 0, self.frames[i], 1 # frame number, frame array, inference every 'read' frame
  289. while cap.isOpened() and n < f:
  290. n += 1
  291. # _, self.imgs[index] = cap.read()
  292. cap.grab()
  293. if n % read == 0:
  294. success, im = cap.retrieve()
  295. if success:
  296. self.imgs[i] = im
  297. else:
  298. LOGGER.warning('WARNING: Video stream unresponsive, please check your IP camera connection.')
  299. self.imgs[i] = np.zeros_like(self.imgs[i])
  300. cap.open(stream) # re-open stream if signal was lost
  301. time.sleep(1 / self.fps[i]) # wait time
  302. def __iter__(self):
  303. self.count = -1
  304. return self
  305. def __next__(self):
  306. self.count += 1
  307. if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
  308. cv2.destroyAllWindows()
  309. raise StopIteration
  310. # Letterbox
  311. img0 = self.imgs.copy()
  312. img = [letterbox(x, self.img_size, stride=self.stride, auto=self.rect and self.auto)[0] for x in img0]
  313. # Stack
  314. img = np.stack(img, 0)
  315. # Convert
  316. img = img[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
  317. img = np.ascontiguousarray(img)
  318. return self.sources, img, img0, None, ''
  319. def __len__(self):
  320. return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
  321. def img2label_paths(img_paths):
  322. # Define label paths as a function of image paths
  323. sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}' # /images/, /labels/ substrings
  324. return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
  325. class LoadImagesAndLabels(Dataset):
  326. # YOLOv5 train_loader/val_loader, loads images and labels for training and validation
  327. cache_version = 0.6 # dataset labels *.cache version
  328. rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
  329. def __init__(self,
  330. path,
  331. img_size=640,
  332. batch_size=16,
  333. augment=False,
  334. hyp=None,
  335. rect=False,
  336. image_weights=False,
  337. cache_images=False,
  338. single_cls=False,
  339. stride=32,
  340. pad=0.0,
  341. prefix=''):
  342. self.img_size = img_size
  343. self.augment = augment
  344. self.hyp = hyp
  345. self.image_weights = image_weights
  346. self.rect = False if image_weights else rect
  347. self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
  348. self.mosaic_border = [-img_size // 2, -img_size // 2]
  349. self.stride = stride
  350. self.path = path
  351. self.albumentations = Albumentations() if augment else None
  352. try:
  353. f = [] # image files
  354. for p in path if isinstance(path, list) else [path]:
  355. p = Path(p) # os-agnostic
  356. if p.is_dir(): # dir
  357. f += glob.glob(str(p / '**' / '*.*'), recursive=True)
  358. # f = list(p.rglob('*.*')) # pathlib
  359. elif p.is_file(): # file
  360. with open(p) as t:
  361. t = t.read().strip().splitlines()
  362. parent = str(p.parent) + os.sep
  363. f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
  364. # f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
  365. else:
  366. raise Exception(f'{prefix}{p} does not exist')
  367. self.im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
  368. # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
  369. assert self.im_files, f'{prefix}No images found'
  370. except Exception as e:
  371. raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}')
  372. # Check cache
  373. self.label_files = img2label_paths(self.im_files) # labels
  374. cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
  375. try:
  376. cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
  377. assert cache['version'] == self.cache_version # same version
  378. assert cache['hash'] == get_hash(self.label_files + self.im_files) # same hash
  379. except Exception:
  380. cache, exists = self.cache_labels(cache_path, prefix), False # cache
  381. # Display cache
  382. nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
  383. if exists and LOCAL_RANK in {-1, 0}:
  384. d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"
  385. tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=BAR_FORMAT) # display cache results
  386. if cache['msgs']:
  387. LOGGER.info('\n'.join(cache['msgs'])) # display warnings
  388. assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {HELP_URL}'
  389. # Read cache
  390. [cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
  391. labels, shapes, self.segments = zip(*cache.values())
  392. self.labels = list(labels)
  393. self.shapes = np.array(shapes, dtype=np.float64)
  394. self.im_files = list(cache.keys()) # update
  395. self.label_files = img2label_paths(cache.keys()) # update
  396. n = len(shapes) # number of images
  397. bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
  398. nb = bi[-1] + 1 # number of batches
  399. self.batch = bi # batch index of image
  400. self.n = n
  401. self.indices = range(n)
  402. # Update labels
  403. include_class = [] # filter labels to include only these classes (optional)
  404. include_class_array = np.array(include_class).reshape(1, -1)
  405. for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
  406. if include_class:
  407. j = (label[:, 0:1] == include_class_array).any(1)
  408. self.labels[i] = label[j]
  409. if segment:
  410. self.segments[i] = segment[j]
  411. if single_cls: # single-class training, merge all classes into 0
  412. self.labels[i][:, 0] = 0
  413. if segment:
  414. self.segments[i][:, 0] = 0
  415. # Rectangular Training
  416. if self.rect:
  417. # Sort by aspect ratio
  418. s = self.shapes # wh
  419. ar = s[:, 1] / s[:, 0] # aspect ratio
  420. irect = ar.argsort()
  421. self.im_files = [self.im_files[i] for i in irect]
  422. self.label_files = [self.label_files[i] for i in irect]
  423. self.labels = [self.labels[i] for i in irect]
  424. self.shapes = s[irect] # wh
  425. ar = ar[irect]
  426. # Set training image shapes
  427. shapes = [[1, 1]] * nb
  428. for i in range(nb):
  429. ari = ar[bi == i]
  430. mini, maxi = ari.min(), ari.max()
  431. if maxi < 1:
  432. shapes[i] = [maxi, 1]
  433. elif mini > 1:
  434. shapes[i] = [1, 1 / mini]
  435. self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
  436. # Cache images into RAM/disk for faster training (WARNING: large datasets may exceed system resources)
  437. self.ims = [None] * n
  438. self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
  439. if cache_images:
  440. gb = 0 # Gigabytes of cached images
  441. self.im_hw0, self.im_hw = [None] * n, [None] * n
  442. fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
  443. results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
  444. pbar = tqdm(enumerate(results), total=n, bar_format=BAR_FORMAT, disable=LOCAL_RANK > 0)
  445. for i, x in pbar:
  446. if cache_images == 'disk':
  447. gb += self.npy_files[i].stat().st_size
  448. else: # 'ram'
  449. self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
  450. gb += self.ims[i].nbytes
  451. pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB {cache_images})'
  452. pbar.close()
  453. def cache_labels(self, path=Path('./labels.cache'), prefix=''):
  454. # Cache dataset labels, check images and read shapes
  455. x = {} # dict
  456. nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
  457. desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
  458. with Pool(NUM_THREADS) as pool:
  459. pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
  460. desc=desc,
  461. total=len(self.im_files),
  462. bar_format=BAR_FORMAT)
  463. for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
  464. nm += nm_f
  465. nf += nf_f
  466. ne += ne_f
  467. nc += nc_f
  468. if im_file:
  469. x[im_file] = [lb, shape, segments]
  470. if msg:
  471. msgs.append(msg)
  472. pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt"
  473. pbar.close()
  474. if msgs:
  475. LOGGER.info('\n'.join(msgs))
  476. if nf == 0:
  477. LOGGER.warning(f'{prefix}WARNING: No labels found in {path}. See {HELP_URL}')
  478. x['hash'] = get_hash(self.label_files + self.im_files)
  479. x['results'] = nf, nm, ne, nc, len(self.im_files)
  480. x['msgs'] = msgs # warnings
  481. x['version'] = self.cache_version # cache version
  482. try:
  483. np.save(path, x) # save cache for next time
  484. path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
  485. LOGGER.info(f'{prefix}New cache created: {path}')
  486. except Exception as e:
  487. LOGGER.warning(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}') # not writeable
  488. return x
  489. def __len__(self):
  490. return len(self.im_files)
  491. # def __iter__(self):
  492. # self.count = -1
  493. # print('ran dataset iter')
  494. # #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
  495. # return self
  496. def __getitem__(self, index):
  497. index = self.indices[index] # linear, shuffled, or image_weights
  498. hyp = self.hyp
  499. mosaic = self.mosaic and random.random() < hyp['mosaic']
  500. if mosaic:
  501. # Load mosaic
  502. img, labels = self.load_mosaic(index)
  503. shapes = None
  504. # MixUp augmentation
  505. if random.random() < hyp['mixup']:
  506. img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1)))
  507. else:
  508. # Load image
  509. img, (h0, w0), (h, w) = self.load_image(index)
  510. # Letterbox
  511. shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
  512. img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
  513. shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
  514. labels = self.labels[index].copy()
  515. if labels.size: # normalized xywh to pixel xyxy format
  516. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
  517. if self.augment:
  518. img, labels = random_perspective(img,
  519. labels,
  520. degrees=hyp['degrees'],
  521. translate=hyp['translate'],
  522. scale=hyp['scale'],
  523. shear=hyp['shear'],
  524. perspective=hyp['perspective'])
  525. nl = len(labels) # number of labels
  526. if nl:
  527. labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)
  528. if self.augment:
  529. # Albumentations
  530. img, labels = self.albumentations(img, labels)
  531. nl = len(labels) # update after albumentations
  532. # HSV color-space
  533. augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
  534. # Flip up-down
  535. if random.random() < hyp['flipud']:
  536. img = np.flipud(img)
  537. if nl:
  538. labels[:, 2] = 1 - labels[:, 2]
  539. # Flip left-right
  540. if random.random() < hyp['fliplr']:
  541. img = np.fliplr(img)
  542. if nl:
  543. labels[:, 1] = 1 - labels[:, 1]
  544. # Cutouts
  545. # labels = cutout(img, labels, p=0.5)
  546. # nl = len(labels) # update after cutout
  547. labels_out = torch.zeros((nl, 6))
  548. if nl:
  549. labels_out[:, 1:] = torch.from_numpy(labels)
  550. # Convert
  551. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  552. img = np.ascontiguousarray(img)
  553. return torch.from_numpy(img), labels_out, self.im_files[index], shapes
  554. def load_image(self, i):
  555. # Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
  556. im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i],
  557. if im is None: # not cached in RAM
  558. if fn.exists(): # load npy
  559. im = np.load(fn)
  560. else: # read image
  561. im = cv2.imread(f) # BGR
  562. assert im is not None, f'Image Not Found {f}'
  563. h0, w0 = im.shape[:2] # orig hw
  564. r = self.img_size / max(h0, w0) # ratio
  565. if r != 1: # if sizes are not equal
  566. interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
  567. im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=interp)
  568. return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
  569. else:
  570. return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
  571. def cache_images_to_disk(self, i):
  572. # Saves an image as an *.npy file for faster loading
  573. f = self.npy_files[i]
  574. if not f.exists():
  575. np.save(f.as_posix(), cv2.imread(self.im_files[i]))
  576. def load_mosaic(self, index):
  577. # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
  578. labels4, segments4 = [], []
  579. s = self.img_size
  580. yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border) # mosaic center x, y
  581. indices = [index] + random.choices(self.indices, k=3) # 3 additional image indices
  582. random.shuffle(indices)
  583. for i, index in enumerate(indices):
  584. # Load image
  585. img, _, (h, w) = self.load_image(index)
  586. # place img in img4
  587. if i == 0: # top left
  588. img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
  589. x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
  590. x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
  591. elif i == 1: # top right
  592. x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
  593. x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
  594. elif i == 2: # bottom left
  595. x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
  596. x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
  597. elif i == 3: # bottom right
  598. x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
  599. x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
  600. img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
  601. padw = x1a - x1b
  602. padh = y1a - y1b
  603. # Labels
  604. labels, segments = self.labels[index].copy(), self.segments[index].copy()
  605. if labels.size:
  606. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh) # normalized xywh to pixel xyxy format
  607. segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
  608. labels4.append(labels)
  609. segments4.extend(segments)
  610. # Concat/clip labels
  611. labels4 = np.concatenate(labels4, 0)
  612. for x in (labels4[:, 1:], *segments4):
  613. np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
  614. # img4, labels4 = replicate(img4, labels4) # replicate
  615. # Augment
  616. img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
  617. img4, labels4 = random_perspective(img4,
  618. labels4,
  619. segments4,
  620. degrees=self.hyp['degrees'],
  621. translate=self.hyp['translate'],
  622. scale=self.hyp['scale'],
  623. shear=self.hyp['shear'],
  624. perspective=self.hyp['perspective'],
  625. border=self.mosaic_border) # border to remove
  626. return img4, labels4
  627. def load_mosaic9(self, index):
  628. # YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic
  629. labels9, segments9 = [], []
  630. s = self.img_size
  631. indices = [index] + random.choices(self.indices, k=8) # 8 additional image indices
  632. random.shuffle(indices)
  633. hp, wp = -1, -1 # height, width previous
  634. for i, index in enumerate(indices):
  635. # Load image
  636. img, _, (h, w) = self.load_image(index)
  637. # place img in img9
  638. if i == 0: # center
  639. img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
  640. h0, w0 = h, w
  641. c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates
  642. elif i == 1: # top
  643. c = s, s - h, s + w, s
  644. elif i == 2: # top right
  645. c = s + wp, s - h, s + wp + w, s
  646. elif i == 3: # right
  647. c = s + w0, s, s + w0 + w, s + h
  648. elif i == 4: # bottom right
  649. c = s + w0, s + hp, s + w0 + w, s + hp + h
  650. elif i == 5: # bottom
  651. c = s + w0 - w, s + h0, s + w0, s + h0 + h
  652. elif i == 6: # bottom left
  653. c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
  654. elif i == 7: # left
  655. c = s - w, s + h0 - h, s, s + h0
  656. elif i == 8: # top left
  657. c = s - w, s + h0 - hp - h, s, s + h0 - hp
  658. padx, pady = c[:2]
  659. x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coords
  660. # Labels
  661. labels, segments = self.labels[index].copy(), self.segments[index].copy()
  662. if labels.size:
  663. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady) # normalized xywh to pixel xyxy format
  664. segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
  665. labels9.append(labels)
  666. segments9.extend(segments)
  667. # Image
  668. img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:] # img9[ymin:ymax, xmin:xmax]
  669. hp, wp = h, w # height, width previous
  670. # Offset
  671. yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border) # mosaic center x, y
  672. img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]
  673. # Concat/clip labels
  674. labels9 = np.concatenate(labels9, 0)
  675. labels9[:, [1, 3]] -= xc
  676. labels9[:, [2, 4]] -= yc
  677. c = np.array([xc, yc]) # centers
  678. segments9 = [x - c for x in segments9]
  679. for x in (labels9[:, 1:], *segments9):
  680. np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
  681. # img9, labels9 = replicate(img9, labels9) # replicate
  682. # Augment
  683. img9, labels9 = random_perspective(img9,
  684. labels9,
  685. segments9,
  686. degrees=self.hyp['degrees'],
  687. translate=self.hyp['translate'],
  688. scale=self.hyp['scale'],
  689. shear=self.hyp['shear'],
  690. perspective=self.hyp['perspective'],
  691. border=self.mosaic_border) # border to remove
  692. return img9, labels9
  693. @staticmethod
  694. def collate_fn(batch):
  695. im, label, path, shapes = zip(*batch) # transposed
  696. for i, lb in enumerate(label):
  697. lb[:, 0] = i # add target image index for build_targets()
  698. return torch.stack(im, 0), torch.cat(label, 0), path, shapes
  699. @staticmethod
  700. def collate_fn4(batch):
  701. img, label, path, shapes = zip(*batch) # transposed
  702. n = len(shapes) // 4
  703. im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
  704. ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
  705. wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
  706. s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]]) # scale
  707. for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
  708. i *= 4
  709. if random.random() < 0.5:
  710. im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
  711. align_corners=False)[0].type(img[i].type())
  712. lb = label[i]
  713. else:
  714. im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
  715. lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
  716. im4.append(im)
  717. label4.append(lb)
  718. for i, lb in enumerate(label4):
  719. lb[:, 0] = i # add target image index for build_targets()
  720. return torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4
  721. # Ancillary functions --------------------------------------------------------------------------------------------------
  722. def create_folder(path='./new'):
  723. # Create folder
  724. if os.path.exists(path):
  725. shutil.rmtree(path) # delete output folder
  726. os.makedirs(path) # make new output folder
  727. def flatten_recursive(path=DATASETS_DIR / 'coco128'):
  728. # Flatten a recursive directory by bringing all files to top level
  729. new_path = Path(str(path) + '_flat')
  730. create_folder(new_path)
  731. for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
  732. shutil.copyfile(file, new_path / Path(file).name)
  733. def extract_boxes(path=DATASETS_DIR / 'coco128'): # from utils.datasets import *; extract_boxes()
  734. # Convert detection dataset into classification dataset, with one directory per class
  735. path = Path(path) # images dir
  736. shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None # remove existing
  737. files = list(path.rglob('*.*'))
  738. n = len(files) # number of files
  739. for im_file in tqdm(files, total=n):
  740. if im_file.suffix[1:] in IMG_FORMATS:
  741. # image
  742. im = cv2.imread(str(im_file))[..., ::-1] # BGR to RGB
  743. h, w = im.shape[:2]
  744. # labels
  745. lb_file = Path(img2label_paths([str(im_file)])[0])
  746. if Path(lb_file).exists():
  747. with open(lb_file) as f:
  748. lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32) # labels
  749. for j, x in enumerate(lb):
  750. c = int(x[0]) # class
  751. f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg' # new filename
  752. if not f.parent.is_dir():
  753. f.parent.mkdir(parents=True)
  754. b = x[1:] * [w, h, w, h] # box
  755. # b[2:] = b[2:].max() # rectangle to square
  756. b[2:] = b[2:] * 1.2 + 3 # pad
  757. b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(np.int)
  758. b[[0, 2]] = np.clip(b[[0, 2]], 0, w) # clip boxes outside of image
  759. b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
  760. assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'
  761. def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
  762. """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
  763. Usage: from utils.datasets import *; autosplit()
  764. Arguments
  765. path: Path to images directory
  766. weights: Train, val, test weights (list, tuple)
  767. annotated_only: Only use images with an annotated txt file
  768. """
  769. path = Path(path) # images dir
  770. files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS) # image files only
  771. n = len(files) # number of files
  772. random.seed(0) # for reproducibility
  773. indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
  774. txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
  775. [(path.parent / x).unlink(missing_ok=True) for x in txt] # remove existing
  776. print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
  777. for i, img in tqdm(zip(indices, files), total=n):
  778. if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
  779. with open(path.parent / txt[i], 'a') as f:
  780. f.write('./' + img.relative_to(path.parent).as_posix() + '\n') # add image to txt file
  781. def verify_image_label(args):
  782. # Verify one image-label pair
  783. im_file, lb_file, prefix = args
  784. nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, '', [] # number (missing, found, empty, corrupt), message, segments
  785. try:
  786. # verify images
  787. im = Image.open(im_file)
  788. im.verify() # PIL verify
  789. shape = exif_size(im) # image size
  790. assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
  791. assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
  792. if im.format.lower() in ('jpg', 'jpeg'):
  793. with open(im_file, 'rb') as f:
  794. f.seek(-2, 2)
  795. if f.read() != b'\xff\xd9': # corrupt JPEG
  796. ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
  797. msg = f'{prefix}WARNING: {im_file}: corrupt JPEG restored and saved'
  798. # verify labels
  799. if os.path.isfile(lb_file):
  800. nf = 1 # label found
  801. with open(lb_file) as f:
  802. lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
  803. if any(len(x) > 6 for x in lb): # is segment
  804. classes = np.array([x[0] for x in lb], dtype=np.float32)
  805. segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...)
  806. lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
  807. lb = np.array(lb, dtype=np.float32)
  808. nl = len(lb)
  809. if nl:
  810. assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
  811. assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
  812. assert (lb[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
  813. _, i = np.unique(lb, axis=0, return_index=True)
  814. if len(i) < nl: # duplicate row check
  815. lb = lb[i] # remove duplicates
  816. if segments:
  817. segments = segments[i]
  818. msg = f'{prefix}WARNING: {im_file}: {nl - len(i)} duplicate labels removed'
  819. else:
  820. ne = 1 # label empty
  821. lb = np.zeros((0, 5), dtype=np.float32)
  822. else:
  823. nm = 1 # label missing
  824. lb = np.zeros((0, 5), dtype=np.float32)
  825. return im_file, lb, shape, segments, nm, nf, ne, nc, msg
  826. except Exception as e:
  827. nc = 1
  828. msg = f'{prefix}WARNING: {im_file}: ignoring corrupt image/label: {e}'
  829. return [None, None, None, None, nm, nf, ne, nc, msg]
  830. def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False, profile=False, hub=False):
  831. """ Return dataset statistics dictionary with images and instances counts per split per class
  832. To run in parent directory: export PYTHONPATH="$PWD/yolov5"
  833. Usage1: from utils.datasets import *; dataset_stats('coco128.yaml', autodownload=True)
  834. Usage2: from utils.datasets import *; dataset_stats('path/to/coco128_with_yaml.zip')
  835. Arguments
  836. path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
  837. autodownload: Attempt to download dataset if not found locally
  838. verbose: Print stats dictionary
  839. """
  840. def round_labels(labels):
  841. # Update labels to integer class and 6 decimal place floats
  842. return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]
  843. def unzip(path):
  844. # Unzip data.zip TODO: CONSTRAINT: path/to/abc.zip MUST unzip to 'path/to/abc/'
  845. if str(path).endswith('.zip'): # path is data.zip
  846. assert Path(path).is_file(), f'Error unzipping {path}, file not found'
  847. ZipFile(path).extractall(path=path.parent) # unzip
  848. dir = path.with_suffix('') # dataset directory == zip name
  849. return True, str(dir), next(dir.rglob('*.yaml')) # zipped, data_dir, yaml_path
  850. else: # path is data.yaml
  851. return False, None, path
  852. def hub_ops(f, max_dim=1920):
  853. # HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
  854. f_new = im_dir / Path(f).name # dataset-hub image filename
  855. try: # use PIL
  856. im = Image.open(f)
  857. r = max_dim / max(im.height, im.width) # ratio
  858. if r < 1.0: # image too large
  859. im = im.resize((int(im.width * r), int(im.height * r)))
  860. im.save(f_new, 'JPEG', quality=75, optimize=True) # save
  861. except Exception as e: # use OpenCV
  862. print(f'WARNING: HUB ops PIL failure {f}: {e}')
  863. im = cv2.imread(f)
  864. im_height, im_width = im.shape[:2]
  865. r = max_dim / max(im_height, im_width) # ratio
  866. if r < 1.0: # image too large
  867. im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
  868. cv2.imwrite(str(f_new), im)
  869. zipped, data_dir, yaml_path = unzip(Path(path))
  870. with open(check_yaml(yaml_path), errors='ignore') as f:
  871. data = yaml.safe_load(f) # data dict
  872. if zipped:
  873. data['path'] = data_dir # TODO: should this be dir.resolve()?
  874. check_dataset(data, autodownload) # download dataset if missing
  875. hub_dir = Path(data['path'] + ('-hub' if hub else ''))
  876. stats = {'nc': data['nc'], 'names': data['names']} # statistics dictionary
  877. for split in 'train', 'val', 'test':
  878. if data.get(split) is None:
  879. stats[split] = None # i.e. no test set
  880. continue
  881. x = []
  882. dataset = LoadImagesAndLabels(data[split]) # load dataset
  883. for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
  884. x.append(np.bincount(label[:, 0].astype(int), minlength=data['nc']))
  885. x = np.array(x) # shape(128x80)
  886. stats[split] = {
  887. 'instance_stats': {
  888. 'total': int(x.sum()),
  889. 'per_class': x.sum(0).tolist()},
  890. 'image_stats': {
  891. 'total': dataset.n,
  892. 'unlabelled': int(np.all(x == 0, 1).sum()),
  893. 'per_class': (x > 0).sum(0).tolist()},
  894. 'labels': [{
  895. str(Path(k).name): round_labels(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}
  896. if hub:
  897. im_dir = hub_dir / 'images'
  898. im_dir.mkdir(parents=True, exist_ok=True)
  899. for _ in tqdm(ThreadPool(NUM_THREADS).imap(hub_ops, dataset.im_files), total=dataset.n, desc='HUB Ops'):
  900. pass
  901. # Profile
  902. stats_path = hub_dir / 'stats.json'
  903. if profile:
  904. for _ in range(1):
  905. file = stats_path.with_suffix('.npy')
  906. t1 = time.time()
  907. np.save(file, stats)
  908. t2 = time.time()
  909. x = np.load(file, allow_pickle=True)
  910. print(f'stats.npy times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')
  911. file = stats_path.with_suffix('.json')
  912. t1 = time.time()
  913. with open(file, 'w') as f:
  914. json.dump(stats, f) # save stats *.json
  915. t2 = time.time()
  916. with open(file) as f:
  917. x = json.load(f) # load hyps dict
  918. print(f'stats.json times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')
  919. # Save, print and return
  920. if hub:
  921. print(f'Saving {stats_path.resolve()}...')
  922. with open(stats_path, 'w') as f:
  923. json.dump(stats, f) # save stats.json
  924. if verbose:
  925. print(json.dumps(stats, indent=2, sort_keys=False))
  926. return stats