# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Dataloaders and dataset utils
"""

import glob
import hashlib
import json
import math
import os
import random
import shutil
import time
from itertools import repeat
from multiprocessing.pool import Pool, ThreadPool
from pathlib import Path
from threading import Thread
from urllib.parse import urlparse
from zipfile import ZipFile

import numpy as np
import torch
import torch.nn.functional as F
import yaml
from PIL import ExifTags, Image, ImageOps
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
from tqdm import tqdm

from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
                           cv2, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
from utils.torch_utils import torch_distributed_zero_first

# Parameters
HELP_URL = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp'  # include image suffixes
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv'  # include video suffixes
BAR_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}'  # tqdm bar format

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break

def get_hash(paths):
    # Returns a single hash value of a list of paths (files or dirs)
    size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))  # sizes
    h = hashlib.md5(str(size).encode())  # hash sizes
    h.update(''.join(paths).encode())  # hash paths
    return h.hexdigest()  # return hash
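
# Example (illustrative sketch; the file paths are hypothetical):
#   get_hash(['data/labels/im0.txt', 'data/images/im0.jpg'])  # -> md5 hex digest keyed on total size + joined paths
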
def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    try:
        rotation = dict(img._getexif().items())[orientation]
        if rotation == 6:  # rotation 270
            s = (s[1], s[0])
        elif rotation == 8:  # rotation 90
            s = (s[1], s[0])
    except Exception:
        pass
    return s

def exif_transpose(image):
    """
    Transpose a PIL image accordingly if it has an EXIF Orientation tag.
    Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()

    :param image: The image to transpose.
    :return: An image.
    """
    exif = image.getexif()
    orientation = exif.get(0x0112, 1)  # default 1
    if orientation > 1:
        method = {
            2: Image.FLIP_LEFT_RIGHT,
            3: Image.ROTATE_180,
            4: Image.FLIP_TOP_BOTTOM,
            5: Image.TRANSPOSE,
            6: Image.ROTATE_270,
            7: Image.TRANSVERSE,
            8: Image.ROTATE_90}.get(orientation)
        if method is not None:
            image = image.transpose(method)
            del exif[0x0112]
            image.info["exif"] = exif.tobytes()
    return image
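
# Example (sketch; the image path is hypothetical). Phone photos often carry an EXIF Orientation tag,
# so an upright copy can be obtained with:
#   im = exif_transpose(Image.open('photo.jpg'))
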
def create_dataloader(path,
                      imgsz,
                      batch_size,
                      stride,
                      single_cls=False,
                      hyp=None,
                      augment=False,
                      cache=False,
                      pad=0.0,
                      rect=False,
                      rank=-1,
                      workers=8,
                      image_weights=False,
                      quad=False,
                      prefix='',
                      shuffle=False):
    if rect and shuffle:
        LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
        shuffle = False
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        dataset = LoadImagesAndLabels(
            path,
            imgsz,
            batch_size,
            augment=augment,  # augmentation
            hyp=hyp,  # hyperparameters
            rect=rect,  # rectangular batches
            cache_images=cache,
            single_cls=single_cls,
            stride=int(stride),
            pad=pad,
            image_weights=image_weights,
            prefix=prefix)

    batch_size = min(batch_size, len(dataset))
    nd = torch.cuda.device_count()  # number of CUDA devices
    nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
    return loader(dataset,
                  batch_size=batch_size,
                  shuffle=shuffle and sampler is None,
                  num_workers=nw,
                  sampler=sampler,
                  pin_memory=True,
                  collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn), dataset
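
# Example (sketch; the dataset path and `hyp` dict are hypothetical and must match your data config):
#   train_loader, dataset = create_dataloader('../datasets/coco128/images/train2017', imgsz=640, batch_size=16,
#                                             stride=32, hyp=hyp, augment=True, workers=8, shuffle=True)
#   for imgs, targets, paths, shapes in train_loader:
#       ...  # imgs: uint8 BCHW tensor; targets: (n, 6) rows of [image_index, class, x, y, w, h], xywh normalized
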
class InfiniteDataLoader(dataloader.DataLoader):
    """ Dataloader that reuses workers

    Uses same syntax as vanilla DataLoader
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler:
    """ Sampler that repeats forever

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)
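
# Design note: a vanilla DataLoader tears down its worker processes at the end of every epoch. Wrapping the
# batch_sampler in _RepeatSampler means the single iterator created in __init__ never exhausts, so workers
# persist across epochs; __len__ still reports one epoch's length, so ordinary `for` loops terminate normally.
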
class LoadImages:
    # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
    def __init__(self, path, img_size=640, stride=32, auto=True):
        p = str(Path(path).resolve())  # os-agnostic absolute path
        if '*' in p:
            files = sorted(glob.glob(p, recursive=True))  # glob
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
        elif os.path.isfile(p):
            files = [p]  # files
        else:
            raise Exception(f'ERROR: {p} does not exist')

        images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
        videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        self.auto = auto
        if any(videos):
            self.new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' \
                            f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            while not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()

            self.frame += 1
            s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '

        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, f'Image Not Found {path}'
            s = f'image {self.count}/{self.nf} {path}: '

        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]

        # Convert
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)

        return path, img, img0, self.cap, s

    def new_video(self, path):
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nf  # number of files
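
# Example (sketch; the source directory is hypothetical):
#   for path, img, img0, cap, s in LoadImages('data/images', img_size=640, stride=32):
#       ...  # img: letterboxed CHW RGB uint8 array ready for the model; img0: original BGR image
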
class LoadWebcam:  # for inference
    # YOLOv5 local webcam dataloader, i.e. `python detect.py --source 0`
    def __init__(self, pipe='0', img_size=640, stride=32):
        self.img_size = img_size
        self.stride = stride
        self.pipe = eval(pipe) if pipe.isnumeric() else pipe
        self.cap = cv2.VideoCapture(self.pipe)  # video capture object
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)  # set buffer size

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if cv2.waitKey(1) == ord('q'):  # q to quit
            self.cap.release()
            cv2.destroyAllWindows()
            raise StopIteration

        # Read frame
        ret_val, img0 = self.cap.read()
        img0 = cv2.flip(img0, 1)  # flip left-right

        # Print
        assert ret_val, f'Camera Error {self.pipe}'
        img_path = 'webcam.jpg'
        s = f'webcam {self.count}: '

        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride)[0]

        # Convert
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)

        return img_path, img, img0, None, s

    def __len__(self):
        return 0
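
# Example (sketch): stream from the default local camera; press 'q' in an OpenCV window to stop:
#   for img_path, img, img0, _, s in LoadWebcam(pipe='0', img_size=640):
#       ...
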
class LoadStreams:
    # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP streams`
    def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
        self.mode = 'stream'
        self.img_size = img_size
        self.stride = stride

        if os.path.isfile(sources):
            with open(sources) as f:
                sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
        else:
            sources = [sources]

        n = len(sources)
        self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
        self.sources = [clean_str(x) for x in sources]  # clean source names for later
        self.auto = auto
        for i, s in enumerate(sources):  # index, source
            # Start thread to read frames from video stream
            st = f'{i + 1}/{n}: {s}... '
            if urlparse(s).hostname in ('youtube.com', 'youtu.be'):  # if source is YouTube video
                check_requirements(('pafy', 'youtube_dl==2020.12.2'))
                import pafy
                s = pafy.new(s).getbest(preftype="mp4").url  # YouTube URL
            s = eval(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
            cap = cv2.VideoCapture(s)
            assert cap.isOpened(), f'{st}Failed to open {s}'
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS)  # warning: may return 0 or nan
            self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf')  # infinite stream fallback
            self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30  # 30 FPS fallback

            _, self.imgs[i] = cap.read()  # guarantee first frame
            self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
            LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
            self.threads[i].start()
        LOGGER.info('')  # newline

        # check for common shapes
        s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs])
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        if not self.rect:
            LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')

    def update(self, i, cap, stream):
        # Read stream `i` frames in daemon thread
        n, f, read = 0, self.frames[i], 1  # frame number, frame array, inference every 'read' frame
        while cap.isOpened() and n < f:
            n += 1
            # _, self.imgs[index] = cap.read()
            cap.grab()
            if n % read == 0:
                success, im = cap.retrieve()
                if success:
                    self.imgs[i] = im
                else:
                    LOGGER.warning('WARNING: Video stream unresponsive, please check your IP camera connection.')
                    self.imgs[i] = np.zeros_like(self.imgs[i])
                    cap.open(stream)  # re-open stream if signal was lost
            time.sleep(1 / self.fps[i])  # wait time

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        # Letterbox
        img0 = self.imgs.copy()
        img = [letterbox(x, self.img_size, stride=self.stride, auto=self.rect and self.auto)[0] for x in img0]

        # Stack
        img = np.stack(img, 0)

        # Convert
        img = img[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW
        img = np.ascontiguousarray(img)

        return self.sources, img, img0, None, ''

    def __len__(self):
        return len(self.sources)  # 1E12 frames = 32 streams at 30 FPS for 30 years
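
# Example (sketch; the RTSP URL is hypothetical):
#   for sources, img, img0, _, s in LoadStreams('rtsp://example.com/media.mp4', img_size=640):
#       ...  # img is a batched BCHW RGB array, one entry per stream
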
def img2label_paths(img_paths):
    # Define label paths as a function of image paths
    sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
    return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
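
# Example (sketch, assuming the standard /images/ -> /labels/ dataset layout, shown with POSIX separators):
#   img2label_paths(['../datasets/coco128/images/train2017/im0.jpg'])
#   # -> ['../datasets/coco128/labels/train2017/im0.txt']
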
class LoadImagesAndLabels(Dataset):
    # YOLOv5 train_loader/val_loader, loads images and labels for training and validation
    cache_version = 0.6  # dataset labels *.cache version

    def __init__(self,
                 path,
                 img_size=640,
                 batch_size=16,
                 augment=False,
                 hyp=None,
                 rect=False,
                 image_weights=False,
                 cache_images=False,
                 single_cls=False,
                 stride=32,
                 pad=0.0,
                 prefix=''):
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride
        self.path = path
        self.albumentations = Albumentations() if augment else None

        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('*.*'))  # pathlib
                elif p.is_file():  # file
                    with open(p) as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                else:
                    raise Exception(f'{prefix}{p} does not exist')
            self.im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
            assert self.im_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}')

        # Check cache
        self.label_files = img2label_paths(self.im_files)  # labels
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
        try:
            cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
            assert cache['version'] == self.cache_version  # same version
            assert cache['hash'] == get_hash(self.label_files + self.im_files)  # same hash
        except Exception:
            cache, exists = self.cache_labels(cache_path, prefix), False  # cache

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
        if exists:
            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"
            tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=BAR_FORMAT)  # display cache results
            if cache['msgs']:
                LOGGER.info('\n'.join(cache['msgs']))  # display warnings
        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {HELP_URL}'

        # Read cache
        [cache.pop(k) for k in ('hash', 'version', 'msgs')]  # remove items
        labels, shapes, self.segments = zip(*cache.values())
        self.labels = list(labels)
        self.shapes = np.array(shapes, dtype=np.float64)
        self.im_files = list(cache.keys())  # update
        self.label_files = img2label_paths(cache.keys())  # update
        n = len(shapes)  # number of images
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n
        self.indices = range(n)

        # Update labels
        include_class = []  # filter labels to include only these classes (optional)
        include_class_array = np.array(include_class).reshape(1, -1)
        for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
            if include_class:
                j = (label[:, 0:1] == include_class_array).any(1)
                self.labels[i] = label[j]
                if segment:
                    self.segments[i] = segment[j]
            if single_cls:  # single-class training, merge all classes into 0
                self.labels[i][:, 0] = 0
                if segment:
                    self.segments[i][:, 0] = 0

        # Rectangular Training
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.im_files = [self.im_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

        # Cache images into RAM/disk for faster training (WARNING: large datasets may exceed system resources)
        self.ims = [None] * n
        self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
        if cache_images:
            gb = 0  # Gigabytes of cached images
            self.im_hw0, self.im_hw = [None] * n, [None] * n
            fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
            results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
            pbar = tqdm(enumerate(results), total=n, bar_format=BAR_FORMAT)
            for i, x in pbar:
                if cache_images == 'disk':
                    gb += self.npy_files[i].stat().st_size
                else:  # 'ram'
                    self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                    gb += self.ims[i].nbytes
                pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB {cache_images})'
            pbar.close()
    def cache_labels(self, path=Path('./labels.cache'), prefix=''):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
        with Pool(NUM_THREADS) as pool:
            pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
                        desc=desc,
                        total=len(self.im_files),
                        bar_format=BAR_FORMAT)
            for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x[im_file] = [lb, shape, segments]
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt"

        pbar.close()
        if msgs:
            LOGGER.info('\n'.join(msgs))
        if nf == 0:
            LOGGER.warning(f'{prefix}WARNING: No labels found in {path}. See {HELP_URL}')
        x['hash'] = get_hash(self.label_files + self.im_files)
        x['results'] = nf, nm, ne, nc, len(self.im_files)
        x['msgs'] = msgs  # warnings
        x['version'] = self.cache_version  # cache version
        try:
            np.save(path, x)  # save cache for next time
            path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
            LOGGER.info(f'{prefix}New cache created: {path}')
        except Exception as e:
            LOGGER.warning(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}')  # not writeable
        return x
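
    # Cache file layout (informal sketch of the dict saved above):
    #   {im_file: [labels (n, 5) array, (width, height), segments], ...,
    #    'hash': str, 'results': (nf, nm, ne, nc, n), 'msgs': [str, ...], 'version': 0.6}
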
    def __len__(self):
        return len(self.im_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     # self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self
    def __getitem__(self, index):
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        if mosaic:
            # Load mosaic
            img, labels = self.load_mosaic(index)
            shapes = None

            # MixUp augmentation
            if random.random() < hyp['mixup']:
                img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1)))

        else:
            # Load image
            img, (h0, w0), (h, w) = self.load_image(index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

            if self.augment:
                img, labels = random_perspective(img,
                                                 labels,
                                                 degrees=hyp['degrees'],
                                                 translate=hyp['translate'],
                                                 scale=hyp['scale'],
                                                 shear=hyp['shear'],
                                                 perspective=hyp['perspective'])

        nl = len(labels)  # number of labels
        if nl:
            labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)

        if self.augment:
            # Albumentations
            img, labels = self.albumentations(img, labels)
            nl = len(labels)  # update after albumentations

            # HSV color-space
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Flip up-down
            if random.random() < hyp['flipud']:
                img = np.flipud(img)
                if nl:
                    labels[:, 2] = 1 - labels[:, 2]

            # Flip left-right
            if random.random() < hyp['fliplr']:
                img = np.fliplr(img)
                if nl:
                    labels[:, 1] = 1 - labels[:, 1]

            # Cutouts
            # labels = cutout(img, labels, p=0.5)
            # nl = len(labels)  # update after cutout

        labels_out = torch.zeros((nl, 6))
        if nl:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.im_files[index], shapes
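
    # Note: each item is (CHW RGB uint8 image tensor, labels_out, file path, shapes), where labels_out is an
    # (nl, 6) tensor of [image_index (filled later by collate_fn), class, x, y, w, h] with xywh normalized to 0-1.
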
    def load_image(self, i):
        # Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
        im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
        if im is None:  # not cached in RAM
            if fn.exists():  # load npy
                im = np.load(fn)
            else:  # read image
                im = cv2.imread(f)  # BGR
                assert im is not None, f'Image Not Found {f}'
            h0, w0 = im.shape[:2]  # orig hw
            r = self.img_size / max(h0, w0)  # ratio
            if r != 1:  # if sizes are not equal
                im = cv2.resize(im, (int(w0 * r), int(h0 * r)),
                                interpolation=cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA)
            return im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
        else:
            return self.ims[i], self.im_hw0[i], self.im_hw[i]  # im, hw_original, hw_resized

    def cache_images_to_disk(self, i):
        # Saves an image as an *.npy file for faster loading
        f = self.npy_files[i]
        if not f.exists():
            np.save(f.as_posix(), cv2.imread(self.im_files[i]))
    def load_mosaic(self, index):
        # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
        labels4, segments4 = [], []
        s = self.img_size
        yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border)  # mosaic center x, y
        indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
        random.shuffle(indices)
        for i, index in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(index)

            # place img in img4
            if i == 0:  # top left
                img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
            elif i == 3:  # bottom right
                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

            img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
            padw = x1a - x1b
            padh = y1a - y1b

            # Labels
            labels, segments = self.labels[index].copy(), self.segments[index].copy()
            if labels.size:
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
                segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
            labels4.append(labels)
            segments4.extend(segments)

        # Concat/clip labels
        labels4 = np.concatenate(labels4, 0)
        for x in (labels4[:, 1:], *segments4):
            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
        # img4, labels4 = replicate(img4, labels4)  # replicate

        # Augment
        img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
        img4, labels4 = random_perspective(img4,
                                           labels4,
                                           segments4,
                                           degrees=self.hyp['degrees'],
                                           translate=self.hyp['translate'],
                                           scale=self.hyp['scale'],
                                           shear=self.hyp['shear'],
                                           perspective=self.hyp['perspective'],
                                           border=self.mosaic_border)  # border to remove

        return img4, labels4
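
    # Mosaic layout (sketch): four images meet at a random center (xc, yc) on a 2s x 2s canvas,
    # which random_perspective() then crops back to s x s:
    #   +--------+--------+
    #   | i=0 TL | i=1 TR |
    #   +----(xc, yc)-----+
    #   | i=2 BL | i=3 BR |
    #   +--------+--------+
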
    def load_mosaic9(self, index):
        # YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic
        labels9, segments9 = [], []
        s = self.img_size
        indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
        random.shuffle(indices)
        hp, wp = -1, -1  # height, width previous
        for i, index in enumerate(indices):
            # Load image
            img, _, (h, w) = self.load_image(index)

            # place img in img9
            if i == 0:  # center
                img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 9 tiles
                h0, w0 = h, w
                c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
            elif i == 1:  # top
                c = s, s - h, s + w, s
            elif i == 2:  # top right
                c = s + wp, s - h, s + wp + w, s
            elif i == 3:  # right
                c = s + w0, s, s + w0 + w, s + h
            elif i == 4:  # bottom right
                c = s + w0, s + hp, s + w0 + w, s + hp + h
            elif i == 5:  # bottom
                c = s + w0 - w, s + h0, s + w0, s + h0 + h
            elif i == 6:  # bottom left
                c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
            elif i == 7:  # left
                c = s - w, s + h0 - h, s, s + h0
            elif i == 8:  # top left
                c = s - w, s + h0 - hp - h, s, s + h0 - hp

            padx, pady = c[:2]
            x1, y1, x2, y2 = (max(x, 0) for x in c)  # allocate coords

            # Labels
            labels, segments = self.labels[index].copy(), self.segments[index].copy()
            if labels.size:
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
                segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
            labels9.append(labels)
            segments9.extend(segments)

            # Image
            img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
            hp, wp = h, w  # height, width previous

        # Offset
        yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border)  # mosaic center x, y
        img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

        # Concat/clip labels
        labels9 = np.concatenate(labels9, 0)
        labels9[:, [1, 3]] -= xc
        labels9[:, [2, 4]] -= yc
        c = np.array([xc, yc])  # centers
        segments9 = [x - c for x in segments9]

        for x in (labels9[:, 1:], *segments9):
            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
        # img9, labels9 = replicate(img9, labels9)  # replicate

        # Augment
        img9, labels9 = random_perspective(img9,
                                           labels9,
                                           segments9,
                                           degrees=self.hyp['degrees'],
                                           translate=self.hyp['translate'],
                                           scale=self.hyp['scale'],
                                           shear=self.hyp['shear'],
                                           perspective=self.hyp['perspective'],
                                           border=self.mosaic_border)  # border to remove

        return img9, labels9
    @staticmethod
    def collate_fn(batch):
        im, label, path, shapes = zip(*batch)  # transposed
        for i, lb in enumerate(label):
            lb[:, 0] = i  # add target image index for build_targets()
        return torch.stack(im, 0), torch.cat(label, 0), path, shapes

    @staticmethod
    def collate_fn4(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        n = len(shapes) // 4
        im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

        ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
        wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
        s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]])  # scale
        for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
            i *= 4
            if random.random() < 0.5:
                im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
                                   align_corners=False)[0].type(img[i].type())
                lb = label[i]
            else:
                im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
                lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
            im4.append(im)
            label4.append(lb)

        for i, lb in enumerate(label4):
            lb[:, 0] = i  # add target image index for build_targets()
        return torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4
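
    # collate_fn4 note: each group of 4 images collapses to 1 training image, either one image upsampled 2x or
    # a 2x2 tile of all four; `ho`/`wo` shift the tiled labels by one normalized image height/width before `s`
    # rescales xywh by 0.5 so they stay normalized on the doubled canvas.

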
# Ancillary functions --------------------------------------------------------------------------------------------------
def create_folder(path='./new'):
    # Create folder
    if os.path.exists(path):
        shutil.rmtree(path)  # delete output folder
    os.makedirs(path)  # make new output folder


def flatten_recursive(path=DATASETS_DIR / 'coco128'):
    # Flatten a recursive directory by bringing all files to top level
    new_path = Path(str(path) + '_flat')
    create_folder(new_path)
    for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
        shutil.copyfile(file, new_path / Path(file).name)

def extract_boxes(path=DATASETS_DIR / 'coco128'):  # from utils.datasets import *; extract_boxes()
    # Convert detection dataset into classification dataset, with one directory per class
    path = Path(path)  # images dir
    shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None  # remove existing
    files = list(path.rglob('*.*'))
    n = len(files)  # number of files
    for im_file in tqdm(files, total=n):
        if im_file.suffix[1:] in IMG_FORMATS:
            # image
            im = cv2.imread(str(im_file))[..., ::-1]  # BGR to RGB
            h, w = im.shape[:2]

            # labels
            lb_file = Path(img2label_paths([str(im_file)])[0])
            if Path(lb_file).exists():
                with open(lb_file) as f:
                    lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32)  # labels

                for j, x in enumerate(lb):
                    c = int(x[0])  # class
                    f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg'  # new filename
                    if not f.parent.is_dir():
                        f.parent.mkdir(parents=True)

                    b = x[1:] * [w, h, w, h]  # box
                    # b[2:] = b[2:].max()  # rectangle to square
                    b[2:] = b[2:] * 1.2 + 3  # pad
                    b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)

                    b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                    b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                    assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'

def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
    """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
    Usage: from utils.datasets import *; autosplit()
    Arguments
        path:            Path to images directory
        weights:         Train, val, test weights (list, tuple)
        annotated_only:  Only use images with an annotated txt file
    """
    path = Path(path)  # images dir
    files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS)  # image files only
    n = len(files)  # number of files
    random.seed(0)  # for reproducibility
    indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split

    txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']  # 3 txt files
    [(path.parent / x).unlink(missing_ok=True) for x in txt]  # remove existing

    print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
    for i, img in tqdm(zip(indices, files), total=n):
        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
            with open(path.parent / txt[i], 'a') as f:
                f.write('./' + img.relative_to(path.parent).as_posix() + '\n')  # add image to txt file

def verify_image_label(args):
    # Verify one image-label pair
    im_file, lb_file, prefix = args
    nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, '', []  # number (missing, found, empty, corrupt), message, segments
    try:
        # verify images
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
        assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
        if im.format.lower() in ('jpg', 'jpeg'):
            with open(im_file, 'rb') as f:
                f.seek(-2, 2)
                if f.read() != b'\xff\xd9':  # corrupt JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
                    msg = f'{prefix}WARNING: {im_file}: corrupt JPEG restored and saved'

        # verify labels
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file) as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                if any(len(x) > 6 for x in lb):  # is segment
                    classes = np.array([x[0] for x in lb], dtype=np.float32)
                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                    lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                lb = np.array(lb, dtype=np.float32)
            nl = len(lb)
            if nl:
                assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
                assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
                assert (lb[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = [segments[x] for x in i]  # keep matching segments (list indexing, not array)
                    msg = f'{prefix}WARNING: {im_file}: {nl - len(i)} duplicate labels removed'
            else:
                ne = 1  # label empty
                lb = np.zeros((0, 5), dtype=np.float32)
        else:
            nm = 1  # label missing
            lb = np.zeros((0, 5), dtype=np.float32)
        return im_file, lb, shape, segments, nm, nf, ne, nc, msg
    except Exception as e:
        nc = 1
        msg = f'{prefix}WARNING: {im_file}: ignoring corrupt image/label: {e}'
        return [None, None, None, None, nm, nf, ne, nc, msg]
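
# Example (sketch; file names are hypothetical). Designed for Pool.imap above, but callable directly:
#   im_file, lb, shape, segments, nm, nf, ne, nc, msg = verify_image_label(('im0.jpg', 'im0.txt', ''))
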
def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False, profile=False, hub=False):
    """ Return dataset statistics dictionary with images and instances counts per split per class
    To run in parent directory: export PYTHONPATH="$PWD/yolov5"
    Usage1: from utils.datasets import *; dataset_stats('coco128.yaml', autodownload=True)
    Usage2: from utils.datasets import *; dataset_stats('path/to/coco128_with_yaml.zip')
    Arguments
        path:           Path to data.yaml or data.zip (with data.yaml inside data.zip)
        autodownload:   Attempt to download dataset if not found locally
        verbose:        Print stats dictionary
    """

    def round_labels(labels):
        # Update labels to integer class and 4 decimal place floats
        return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]

    def unzip(path):
        # Unzip data.zip TODO: CONSTRAINT: path/to/abc.zip MUST unzip to 'path/to/abc/'
        if str(path).endswith('.zip'):  # path is data.zip
            assert Path(path).is_file(), f'Error unzipping {path}, file not found'
            ZipFile(path).extractall(path=path.parent)  # unzip
            dir = path.with_suffix('')  # dataset directory == zip name
            return True, str(dir), next(dir.rglob('*.yaml'))  # zipped, data_dir, yaml_path
        else:  # path is data.yaml
            return False, None, path

    def hub_ops(f, max_dim=1920):
        # HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
        f_new = im_dir / Path(f).name  # dataset-hub image filename
        try:  # use PIL
            im = Image.open(f)
            r = max_dim / max(im.height, im.width)  # ratio
            if r < 1.0:  # image too large
                im = im.resize((int(im.width * r), int(im.height * r)))
            im.save(f_new, 'JPEG', quality=75, optimize=True)  # save
        except Exception as e:  # use OpenCV
            print(f'WARNING: HUB ops PIL failure {f}: {e}')
            im = cv2.imread(f)
            im_height, im_width = im.shape[:2]
            r = max_dim / max(im_height, im_width)  # ratio
            if r < 1.0:  # image too large
                im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
            cv2.imwrite(str(f_new), im)

    zipped, data_dir, yaml_path = unzip(Path(path))
    with open(check_yaml(yaml_path), errors='ignore') as f:
        data = yaml.safe_load(f)  # data dict
        if zipped:
            data['path'] = data_dir  # TODO: should this be dir.resolve()?
    check_dataset(data, autodownload)  # download dataset if missing
    hub_dir = Path(data['path'] + ('-hub' if hub else ''))
    stats = {'nc': data['nc'], 'names': data['names']}  # statistics dictionary
    for split in 'train', 'val', 'test':
        if data.get(split) is None:
            stats[split] = None  # i.e. no test set
            continue
        x = []
        dataset = LoadImagesAndLabels(data[split])  # load dataset
        for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
            x.append(np.bincount(label[:, 0].astype(int), minlength=data['nc']))
        x = np.array(x)  # shape(128x80)
        stats[split] = {
            'instance_stats': {
                'total': int(x.sum()),
                'per_class': x.sum(0).tolist()},
            'image_stats': {
                'total': dataset.n,
                'unlabelled': int(np.all(x == 0, 1).sum()),
                'per_class': (x > 0).sum(0).tolist()},
            'labels': [{
                str(Path(k).name): round_labels(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}

        if hub:
            im_dir = hub_dir / 'images'
            im_dir.mkdir(parents=True, exist_ok=True)
            for _ in tqdm(ThreadPool(NUM_THREADS).imap(hub_ops, dataset.im_files), total=dataset.n, desc='HUB Ops'):
                pass

    # Profile
    stats_path = hub_dir / 'stats.json'
    if profile:
        for _ in range(1):
            file = stats_path.with_suffix('.npy')
            t1 = time.time()
            np.save(file, stats)
            t2 = time.time()
            x = np.load(file, allow_pickle=True)
            print(f'stats.npy times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')

            file = stats_path.with_suffix('.json')
            t1 = time.time()
            with open(file, 'w') as f:
                json.dump(stats, f)  # save stats *.json
            t2 = time.time()
            with open(file) as f:
                x = json.load(f)  # load hyps dict
            print(f'stats.json times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')

    # Save, print and return
    if hub:
        print(f'Saving {stats_path.resolve()}...')
        with open(stats_path, 'w') as f:
            json.dump(stats, f)  # save stats.json

    if verbose:
        print(json.dumps(stats, indent=2, sort_keys=False))
    return stats
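
# Returned stats structure (informal sketch):
#   {'nc': int, 'names': [...],
#    '<split>': {'instance_stats': {'total', 'per_class'},
#                'image_stats': {'total', 'unlabelled', 'per_class'},
#                'labels': [{filename: [[cls, x, y, w, h], ...]}, ...]}}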