You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1035 lines
44KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Dataloaders and dataset utils
  4. """
  5. import glob
  6. import hashlib
  7. import json
  8. import os
  9. import random
  10. import shutil
  11. import time
  12. from itertools import repeat
  13. from multiprocessing.pool import Pool, ThreadPool
  14. from pathlib import Path
  15. from threading import Thread
  16. from zipfile import ZipFile
  17. import cv2
  18. import numpy as np
  19. import torch
  20. import torch.nn.functional as F
  21. import yaml
  22. from PIL import ExifTags, Image, ImageOps
  23. from torch.utils.data import Dataset
  24. from tqdm import tqdm
  25. from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
  26. from utils.general import (LOGGER, check_dataset, check_requirements, check_yaml, clean_str, segments2boxes, xyn2xy,
  27. xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
  28. from utils.torch_utils import torch_distributed_zero_first
  29. # Parameters
  30. HELP_URL = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
  31. IMG_FORMATS = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo'] # acceptable image suffixes
  32. VID_FORMATS = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv'] # acceptable video suffixes
  33. NUM_THREADS = min(8, os.cpu_count()) # number of multiprocessing threads
  34. # Get orientation exif tag
  35. for orientation in ExifTags.TAGS.keys():
  36. if ExifTags.TAGS[orientation] == 'Orientation':
  37. break
  38. def get_hash(paths):
  39. # Returns a single hash value of a list of paths (files or dirs)
  40. size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
  41. h = hashlib.md5(str(size).encode()) # hash sizes
  42. h.update(''.join(paths).encode()) # hash paths
  43. return h.hexdigest() # return hash
  44. def exif_size(img):
  45. # Returns exif-corrected PIL size
  46. s = img.size # (width, height)
  47. try:
  48. rotation = dict(img._getexif().items())[orientation]
  49. if rotation == 6: # rotation 270
  50. s = (s[1], s[0])
  51. elif rotation == 8: # rotation 90
  52. s = (s[1], s[0])
  53. except:
  54. pass
  55. return s
  56. def exif_transpose(image):
  57. """
  58. Transpose a PIL image accordingly if it has an EXIF Orientation tag.
  59. Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()
  60. :param image: The image to transpose.
  61. :return: An image.
  62. """
  63. exif = image.getexif()
  64. orientation = exif.get(0x0112, 1) # default 1
  65. if orientation > 1:
  66. method = {2: Image.FLIP_LEFT_RIGHT,
  67. 3: Image.ROTATE_180,
  68. 4: Image.FLIP_TOP_BOTTOM,
  69. 5: Image.TRANSPOSE,
  70. 6: Image.ROTATE_270,
  71. 7: Image.TRANSVERSE,
  72. 8: Image.ROTATE_90,
  73. }.get(orientation)
  74. if method is not None:
  75. image = image.transpose(method)
  76. del exif[0x0112]
  77. image.info["exif"] = exif.tobytes()
  78. return image
  79. def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
  80. rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
  81. # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
  82. with torch_distributed_zero_first(rank):
  83. dataset = LoadImagesAndLabels(path, imgsz, batch_size,
  84. augment=augment, # augment images
  85. hyp=hyp, # augmentation hyperparameters
  86. rect=rect, # rectangular training
  87. cache_images=cache,
  88. single_cls=single_cls,
  89. stride=int(stride),
  90. pad=pad,
  91. image_weights=image_weights,
  92. prefix=prefix)
  93. batch_size = min(batch_size, len(dataset))
  94. nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, workers]) # number of workers
  95. sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
  96. loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
  97. # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
  98. dataloader = loader(dataset,
  99. batch_size=batch_size,
  100. num_workers=nw,
  101. sampler=sampler,
  102. pin_memory=True,
  103. collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
  104. return dataloader, dataset
  105. class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
  106. """ Dataloader that reuses workers
  107. Uses same syntax as vanilla DataLoader
  108. """
  109. def __init__(self, *args, **kwargs):
  110. super().__init__(*args, **kwargs)
  111. object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
  112. self.iterator = super().__iter__()
  113. def __len__(self):
  114. return len(self.batch_sampler.sampler)
  115. def __iter__(self):
  116. for i in range(len(self)):
  117. yield next(self.iterator)
  118. class _RepeatSampler:
  119. """ Sampler that repeats forever
  120. Args:
  121. sampler (Sampler)
  122. """
  123. def __init__(self, sampler):
  124. self.sampler = sampler
  125. def __iter__(self):
  126. while True:
  127. yield from iter(self.sampler)
  128. class LoadImages:
  129. # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
  130. def __init__(self, path, img_size=640, stride=32, auto=True):
  131. p = str(Path(path).resolve()) # os-agnostic absolute path
  132. if '*' in p:
  133. files = sorted(glob.glob(p, recursive=True)) # glob
  134. elif os.path.isdir(p):
  135. files = sorted(glob.glob(os.path.join(p, '*.*'))) # dir
  136. elif os.path.isfile(p):
  137. files = [p] # files
  138. else:
  139. raise Exception(f'ERROR: {p} does not exist')
  140. images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
  141. videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
  142. ni, nv = len(images), len(videos)
  143. self.img_size = img_size
  144. self.stride = stride
  145. self.files = images + videos
  146. self.nf = ni + nv # number of files
  147. self.video_flag = [False] * ni + [True] * nv
  148. self.mode = 'image'
  149. self.auto = auto
  150. if any(videos):
  151. self.new_video(videos[0]) # new video
  152. else:
  153. self.cap = None
  154. assert self.nf > 0, f'No images or videos found in {p}. ' \
  155. f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
  156. def __iter__(self):
  157. self.count = 0
  158. return self
  159. def __next__(self):
  160. if self.count == self.nf:
  161. raise StopIteration
  162. path = self.files[self.count]
  163. if self.video_flag[self.count]:
  164. # Read video
  165. self.mode = 'video'
  166. ret_val, img0 = self.cap.read()
  167. if not ret_val:
  168. self.count += 1
  169. self.cap.release()
  170. if self.count == self.nf: # last video
  171. raise StopIteration
  172. else:
  173. path = self.files[self.count]
  174. self.new_video(path)
  175. ret_val, img0 = self.cap.read()
  176. self.frame += 1
  177. s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
  178. else:
  179. # Read image
  180. self.count += 1
  181. img0 = cv2.imread(path) # BGR
  182. assert img0 is not None, f'Image Not Found {path}'
  183. s = f'image {self.count}/{self.nf} {path}: '
  184. # Padded resize
  185. img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]
  186. # Convert
  187. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  188. img = np.ascontiguousarray(img)
  189. return path, img, img0, self.cap, s
  190. def new_video(self, path):
  191. self.frame = 0
  192. self.cap = cv2.VideoCapture(path)
  193. self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
  194. def __len__(self):
  195. return self.nf # number of files
  196. class LoadWebcam: # for inference
  197. # YOLOv5 local webcam dataloader, i.e. `python detect.py --source 0`
  198. def __init__(self, pipe='0', img_size=640, stride=32):
  199. self.img_size = img_size
  200. self.stride = stride
  201. self.pipe = eval(pipe) if pipe.isnumeric() else pipe
  202. self.cap = cv2.VideoCapture(self.pipe) # video capture object
  203. self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3) # set buffer size
  204. def __iter__(self):
  205. self.count = -1
  206. return self
  207. def __next__(self):
  208. self.count += 1
  209. if cv2.waitKey(1) == ord('q'): # q to quit
  210. self.cap.release()
  211. cv2.destroyAllWindows()
  212. raise StopIteration
  213. # Read frame
  214. ret_val, img0 = self.cap.read()
  215. img0 = cv2.flip(img0, 1) # flip left-right
  216. # Print
  217. assert ret_val, f'Camera Error {self.pipe}'
  218. img_path = 'webcam.jpg'
  219. s = f'webcam {self.count}: '
  220. # Padded resize
  221. img = letterbox(img0, self.img_size, stride=self.stride)[0]
  222. # Convert
  223. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  224. img = np.ascontiguousarray(img)
  225. return img_path, img, img0, None, s
  226. def __len__(self):
  227. return 0
  228. class LoadStreams:
  229. # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
  230. def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
  231. self.mode = 'stream'
  232. self.img_size = img_size
  233. self.stride = stride
  234. if os.path.isfile(sources):
  235. with open(sources) as f:
  236. sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
  237. else:
  238. sources = [sources]
  239. n = len(sources)
  240. self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
  241. self.sources = [clean_str(x) for x in sources] # clean source names for later
  242. self.auto = auto
  243. for i, s in enumerate(sources): # index, source
  244. # Start thread to read frames from video stream
  245. st = f'{i + 1}/{n}: {s}... '
  246. if 'youtube.com/' in s or 'youtu.be/' in s: # if source is YouTube video
  247. check_requirements(('pafy', 'youtube_dl'))
  248. import pafy
  249. s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
  250. s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
  251. cap = cv2.VideoCapture(s)
  252. assert cap.isOpened(), f'{st}Failed to open {s}'
  253. w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  254. h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  255. self.fps[i] = max(cap.get(cv2.CAP_PROP_FPS) % 100, 0) or 30.0 # 30 FPS fallback
  256. self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
  257. _, self.imgs[i] = cap.read() # guarantee first frame
  258. self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
  259. LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
  260. self.threads[i].start()
  261. LOGGER.info('') # newline
  262. # check for common shapes
  263. s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs])
  264. self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
  265. if not self.rect:
  266. LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')
  267. def update(self, i, cap, stream):
  268. # Read stream `i` frames in daemon thread
  269. n, f, read = 0, self.frames[i], 1 # frame number, frame array, inference every 'read' frame
  270. while cap.isOpened() and n < f:
  271. n += 1
  272. # _, self.imgs[index] = cap.read()
  273. cap.grab()
  274. if n % read == 0:
  275. success, im = cap.retrieve()
  276. if success:
  277. self.imgs[i] = im
  278. else:
  279. LOGGER.warning('WARNING: Video stream unresponsive, please check your IP camera connection.')
  280. self.imgs[i] *= 0
  281. cap.open(stream) # re-open stream if signal was lost
  282. time.sleep(1 / self.fps[i]) # wait time
  283. def __iter__(self):
  284. self.count = -1
  285. return self
  286. def __next__(self):
  287. self.count += 1
  288. if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
  289. cv2.destroyAllWindows()
  290. raise StopIteration
  291. # Letterbox
  292. img0 = self.imgs.copy()
  293. img = [letterbox(x, self.img_size, stride=self.stride, auto=self.rect and self.auto)[0] for x in img0]
  294. # Stack
  295. img = np.stack(img, 0)
  296. # Convert
  297. img = img[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
  298. img = np.ascontiguousarray(img)
  299. return self.sources, img, img0, None, ''
  300. def __len__(self):
  301. return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
  302. def img2label_paths(img_paths):
  303. # Define label paths as a function of image paths
  304. sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep # /images/, /labels/ substrings
  305. return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
  306. class LoadImagesAndLabels(Dataset):
  307. # YOLOv5 train_loader/val_loader, loads images and labels for training and validation
  308. cache_version = 0.6 # dataset labels *.cache version
  309. def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
  310. cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
  311. self.img_size = img_size
  312. self.augment = augment
  313. self.hyp = hyp
  314. self.image_weights = image_weights
  315. self.rect = False if image_weights else rect
  316. self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
  317. self.mosaic_border = [-img_size // 2, -img_size // 2]
  318. self.stride = stride
  319. self.path = path
  320. self.albumentations = Albumentations() if augment else None
  321. try:
  322. f = [] # image files
  323. for p in path if isinstance(path, list) else [path]:
  324. p = Path(p) # os-agnostic
  325. if p.is_dir(): # dir
  326. f += glob.glob(str(p / '**' / '*.*'), recursive=True)
  327. # f = list(p.rglob('*.*')) # pathlib
  328. elif p.is_file(): # file
  329. with open(p) as t:
  330. t = t.read().strip().splitlines()
  331. parent = str(p.parent) + os.sep
  332. f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
  333. # f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
  334. else:
  335. raise Exception(f'{prefix}{p} does not exist')
  336. self.img_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
  337. # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
  338. assert self.img_files, f'{prefix}No images found'
  339. except Exception as e:
  340. raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}')
  341. # Check cache
  342. self.label_files = img2label_paths(self.img_files) # labels
  343. cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
  344. try:
  345. cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
  346. assert cache['version'] == self.cache_version # same version
  347. assert cache['hash'] == get_hash(self.label_files + self.img_files) # same hash
  348. except:
  349. cache, exists = self.cache_labels(cache_path, prefix), False # cache
  350. # Display cache
  351. nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupted, total
  352. if exists:
  353. d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
  354. tqdm(None, desc=prefix + d, total=n, initial=n) # display cache results
  355. if cache['msgs']:
  356. LOGGER.info('\n'.join(cache['msgs'])) # display warnings
  357. assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {HELP_URL}'
  358. # Read cache
  359. [cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
  360. labels, shapes, self.segments = zip(*cache.values())
  361. self.labels = list(labels)
  362. self.shapes = np.array(shapes, dtype=np.float64)
  363. self.img_files = list(cache.keys()) # update
  364. self.label_files = img2label_paths(cache.keys()) # update
  365. n = len(shapes) # number of images
  366. bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
  367. nb = bi[-1] + 1 # number of batches
  368. self.batch = bi # batch index of image
  369. self.n = n
  370. self.indices = range(n)
  371. # Update labels
  372. include_class = [] # filter labels to include only these classes (optional)
  373. include_class_array = np.array(include_class).reshape(1, -1)
  374. for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
  375. if include_class:
  376. j = (label[:, 0:1] == include_class_array).any(1)
  377. self.labels[i] = label[j]
  378. if segment:
  379. self.segments[i] = segment[j]
  380. if single_cls: # single-class training, merge all classes into 0
  381. self.labels[i][:, 0] = 0
  382. if segment:
  383. self.segments[i][:, 0] = 0
  384. # Rectangular Training
  385. if self.rect:
  386. # Sort by aspect ratio
  387. s = self.shapes # wh
  388. ar = s[:, 1] / s[:, 0] # aspect ratio
  389. irect = ar.argsort()
  390. self.img_files = [self.img_files[i] for i in irect]
  391. self.label_files = [self.label_files[i] for i in irect]
  392. self.labels = [self.labels[i] for i in irect]
  393. self.shapes = s[irect] # wh
  394. ar = ar[irect]
  395. # Set training image shapes
  396. shapes = [[1, 1]] * nb
  397. for i in range(nb):
  398. ari = ar[bi == i]
  399. mini, maxi = ari.min(), ari.max()
  400. if maxi < 1:
  401. shapes[i] = [maxi, 1]
  402. elif mini > 1:
  403. shapes[i] = [1, 1 / mini]
  404. self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
  405. # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
  406. self.imgs, self.img_npy = [None] * n, [None] * n
  407. if cache_images:
  408. if cache_images == 'disk':
  409. self.im_cache_dir = Path(Path(self.img_files[0]).parent.as_posix() + '_npy')
  410. self.img_npy = [self.im_cache_dir / Path(f).with_suffix('.npy').name for f in self.img_files]
  411. self.im_cache_dir.mkdir(parents=True, exist_ok=True)
  412. gb = 0 # Gigabytes of cached images
  413. self.img_hw0, self.img_hw = [None] * n, [None] * n
  414. results = ThreadPool(NUM_THREADS).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
  415. pbar = tqdm(enumerate(results), total=n)
  416. for i, x in pbar:
  417. if cache_images == 'disk':
  418. if not self.img_npy[i].exists():
  419. np.save(self.img_npy[i].as_posix(), x[0])
  420. gb += self.img_npy[i].stat().st_size
  421. else:
  422. self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
  423. gb += self.imgs[i].nbytes
  424. pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB {cache_images})'
  425. pbar.close()
  426. def cache_labels(self, path=Path('./labels.cache'), prefix=''):
  427. # Cache dataset labels, check images and read shapes
  428. x = {} # dict
  429. nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
  430. desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
  431. with Pool(NUM_THREADS) as pool:
  432. pbar = tqdm(pool.imap(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))),
  433. desc=desc, total=len(self.img_files))
  434. for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
  435. nm += nm_f
  436. nf += nf_f
  437. ne += ne_f
  438. nc += nc_f
  439. if im_file:
  440. x[im_file] = [l, shape, segments]
  441. if msg:
  442. msgs.append(msg)
  443. pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
  444. pbar.close()
  445. if msgs:
  446. LOGGER.info('\n'.join(msgs))
  447. if nf == 0:
  448. LOGGER.warning(f'{prefix}WARNING: No labels found in {path}. See {HELP_URL}')
  449. x['hash'] = get_hash(self.label_files + self.img_files)
  450. x['results'] = nf, nm, ne, nc, len(self.img_files)
  451. x['msgs'] = msgs # warnings
  452. x['version'] = self.cache_version # cache version
  453. try:
  454. np.save(path, x) # save cache for next time
  455. path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
  456. LOGGER.info(f'{prefix}New cache created: {path}')
  457. except Exception as e:
  458. LOGGER.warning(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}') # not writeable
  459. return x
  460. def __len__(self):
  461. return len(self.img_files)
  462. # def __iter__(self):
  463. # self.count = -1
  464. # print('ran dataset iter')
  465. # #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
  466. # return self
  467. def __getitem__(self, index):
  468. index = self.indices[index] # linear, shuffled, or image_weights
  469. hyp = self.hyp
  470. mosaic = self.mosaic and random.random() < hyp['mosaic']
  471. if mosaic:
  472. # Load mosaic
  473. img, labels = load_mosaic(self, index)
  474. shapes = None
  475. # MixUp augmentation
  476. if random.random() < hyp['mixup']:
  477. img, labels = mixup(img, labels, *load_mosaic(self, random.randint(0, self.n - 1)))
  478. else:
  479. # Load image
  480. img, (h0, w0), (h, w) = load_image(self, index)
  481. # Letterbox
  482. shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
  483. img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
  484. shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
  485. labels = self.labels[index].copy()
  486. if labels.size: # normalized xywh to pixel xyxy format
  487. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
  488. if self.augment:
  489. img, labels = random_perspective(img, labels,
  490. degrees=hyp['degrees'],
  491. translate=hyp['translate'],
  492. scale=hyp['scale'],
  493. shear=hyp['shear'],
  494. perspective=hyp['perspective'])
  495. nl = len(labels) # number of labels
  496. if nl:
  497. labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)
  498. if self.augment:
  499. # Albumentations
  500. img, labels = self.albumentations(img, labels)
  501. nl = len(labels) # update after albumentations
  502. # HSV color-space
  503. augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
  504. # Flip up-down
  505. if random.random() < hyp['flipud']:
  506. img = np.flipud(img)
  507. if nl:
  508. labels[:, 2] = 1 - labels[:, 2]
  509. # Flip left-right
  510. if random.random() < hyp['fliplr']:
  511. img = np.fliplr(img)
  512. if nl:
  513. labels[:, 1] = 1 - labels[:, 1]
  514. # Cutouts
  515. # labels = cutout(img, labels, p=0.5)
  516. labels_out = torch.zeros((nl, 6))
  517. if nl:
  518. labels_out[:, 1:] = torch.from_numpy(labels)
  519. # Convert
  520. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  521. img = np.ascontiguousarray(img)
  522. return torch.from_numpy(img), labels_out, self.img_files[index], shapes
  523. @staticmethod
  524. def collate_fn(batch):
  525. img, label, path, shapes = zip(*batch) # transposed
  526. for i, l in enumerate(label):
  527. l[:, 0] = i # add target image index for build_targets()
  528. return torch.stack(img, 0), torch.cat(label, 0), path, shapes
  529. @staticmethod
  530. def collate_fn4(batch):
  531. img, label, path, shapes = zip(*batch) # transposed
  532. n = len(shapes) // 4
  533. img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
  534. ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
  535. wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
  536. s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]]) # scale
  537. for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
  538. i *= 4
  539. if random.random() < 0.5:
  540. im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear', align_corners=False)[
  541. 0].type(img[i].type())
  542. l = label[i]
  543. else:
  544. im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
  545. l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
  546. img4.append(im)
  547. label4.append(l)
  548. for i, l in enumerate(label4):
  549. l[:, 0] = i # add target image index for build_targets()
  550. return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4
  551. # Ancillary functions --------------------------------------------------------------------------------------------------
  552. def load_image(self, i):
  553. # loads 1 image from dataset index 'i', returns im, original hw, resized hw
  554. im = self.imgs[i]
  555. if im is None: # not cached in ram
  556. npy = self.img_npy[i]
  557. if npy and npy.exists(): # load npy
  558. im = np.load(npy)
  559. else: # read image
  560. path = self.img_files[i]
  561. im = cv2.imread(path) # BGR
  562. assert im is not None, f'Image Not Found {path}'
  563. h0, w0 = im.shape[:2] # orig hw
  564. r = self.img_size / max(h0, w0) # ratio
  565. if r != 1: # if sizes are not equal
  566. im = cv2.resize(im, (int(w0 * r), int(h0 * r)),
  567. interpolation=cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR)
  568. return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
  569. else:
  570. return self.imgs[i], self.img_hw0[i], self.img_hw[i] # im, hw_original, hw_resized
def load_mosaic(self, index):
    """YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic.

    Let s = ``self.img_size``. The four images are ``index`` plus 3 indices drawn with ``random.choices`` from
    ``self.indices``, then shuffled. They are pasted around a random centre (yc, xc) on a (2s, 2s) canvas filled
    with value 114. Labels and segments are moved into canvas pixel coordinates and clipped to [0, 2s]. Finally
    ``copy_paste`` and ``random_perspective`` (with ``border=self.mosaic_border``) are applied.

    Args:
        self: the dataset object (presumably LoadImagesAndLabels) providing imgs/labels/segments/hyp
        index: dataset index of the anchor image

    Returns:
        (img4, labels4) as returned by random_perspective
    """
    # NOTE: the order of the random.* calls below determines the RNG stream; reordering changes results
    labels4, segments4 = [], []
    s = self.img_size
    yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border)  # mosaic center y, x
    indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
    random.shuffle(indices)
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4: (x1a..y2a) is the destination on the canvas, (x1b..y2b) the matching source crop
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        # Offset that maps this tile's pixel coordinates onto canvas coordinates
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
        labels4.append(labels)
        segments4.extend(segments)

    # Concat/clip labels
    labels4 = np.concatenate(labels4, 0)
    for x in (labels4[:, 1:], *segments4):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img4, labels4 = replicate(img4, labels4)  # replicate

    # Augment
    img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
    img4, labels4 = random_perspective(img4, labels4, segments4,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img4, labels4
def load_mosaic9(self, index):
    """YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic.

    Let s = ``self.img_size``. The image ``index`` plus 8 indices drawn with ``random.choices`` are shuffled and
    laid out on a (3s, 3s) canvas filled with 114. The first image goes at (s, s) and the others follow clockwise
    from the top. A random (2s, 2s) window is then cropped. Labels/segments are shifted by the crop offset, clipped
    to [0, 2s], and ``random_perspective`` (with ``border=self.mosaic_border``) is applied.

    Args:
        self: the dataset object (presumably LoadImagesAndLabels) providing imgs/labels/segments/hyp
        index: dataset index of the first (centre) image

    Returns:
        (img9, labels9) as returned by random_perspective
    """
    labels9, segments9 = [], []
    s = self.img_size
    indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
    random.shuffle(indices)
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img9; c = (xmin, ymin, xmax, ymax) on the canvas. h0/w0 are the centre image's size and
        # hp/wp the previous image's size (set at the end of each iteration), used to chain neighbouring tiles
        if i == 0:  # center
            img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 9 tiles
            h0, w0 = h, w
            c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
        elif i == 1:  # top
            c = s, s - h, s + w, s
        elif i == 2:  # top right
            c = s + wp, s - h, s + wp + w, s
        elif i == 3:  # right
            c = s + w0, s, s + w0 + w, s + h
        elif i == 4:  # bottom right
            c = s + w0, s + hp, s + w0 + w, s + hp + h
        elif i == 5:  # bottom
            c = s + w0 - w, s + h0, s + w0, s + h0 + h
        elif i == 6:  # bottom left
            c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
        elif i == 7:  # left
            c = s - w, s + h0 - h, s, s + h0
        elif i == 8:  # top left
            c = s - w, s + h0 - hp - h, s, s + h0 - hp

        padx, pady = c[:2]
        x1, y1, x2, y2 = (max(x, 0) for x in c)  # allocate coords (clamped to the canvas' top-left edge)

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
        labels9.append(labels)
        segments9.extend(segments)

        # Image
        img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
        hp, wp = h, w  # height, width previous

    # Offset: random top-left corner (y, x) of the (2s, 2s) crop taken from the (3s, 3s) canvas
    yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border)  # mosaic center x, y
    img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

    # Concat/clip labels
    labels9 = np.concatenate(labels9, 0)
    labels9[:, [1, 3]] -= xc
    labels9[:, [2, 4]] -= yc
    c = np.array([xc, yc])  # centers
    segments9 = [x - c for x in segments9]

    for x in (labels9[:, 1:], *segments9):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img9, labels9 = replicate(img9, labels9)  # replicate

    # Augment
    img9, labels9 = random_perspective(img9, labels9, segments9,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img9, labels9
  683. def create_folder(path='./new'):
  684. # Create folder
  685. if os.path.exists(path):
  686. shutil.rmtree(path) # delete output folder
  687. os.makedirs(path) # make new output folder
  688. def flatten_recursive(path='../datasets/coco128'):
  689. # Flatten a recursive directory by bringing all files to top level
  690. new_path = Path(path + '_flat')
  691. create_folder(new_path)
  692. for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
  693. shutil.copyfile(file, new_path / Path(file).name)
  694. def extract_boxes(path='../datasets/coco128'): # from utils.datasets import *; extract_boxes()
  695. # Convert detection dataset into classification dataset, with one directory per class
  696. path = Path(path) # images dir
  697. shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None # remove existing
  698. files = list(path.rglob('*.*'))
  699. n = len(files) # number of files
  700. for im_file in tqdm(files, total=n):
  701. if im_file.suffix[1:] in IMG_FORMATS:
  702. # image
  703. im = cv2.imread(str(im_file))[..., ::-1] # BGR to RGB
  704. h, w = im.shape[:2]
  705. # labels
  706. lb_file = Path(img2label_paths([str(im_file)])[0])
  707. if Path(lb_file).exists():
  708. with open(lb_file) as f:
  709. lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32) # labels
  710. for j, x in enumerate(lb):
  711. c = int(x[0]) # class
  712. f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg' # new filename
  713. if not f.parent.is_dir():
  714. f.parent.mkdir(parents=True)
  715. b = x[1:] * [w, h, w, h] # box
  716. # b[2:] = b[2:].max() # rectangle to square
  717. b[2:] = b[2:] * 1.2 + 3 # pad
  718. b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(np.int)
  719. b[[0, 2]] = np.clip(b[[0, 2]], 0, w) # clip boxes outside of image
  720. b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
  721. assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'
  722. def autosplit(path='../datasets/coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
  723. """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
  724. Usage: from utils.datasets import *; autosplit()
  725. Arguments
  726. path: Path to images directory
  727. weights: Train, val, test weights (list, tuple)
  728. annotated_only: Only use images with an annotated txt file
  729. """
  730. path = Path(path) # images dir
  731. files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS) # image files only
  732. n = len(files) # number of files
  733. random.seed(0) # for reproducibility
  734. indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
  735. txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
  736. [(path.parent / x).unlink(missing_ok=True) for x in txt] # remove existing
  737. print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
  738. for i, img in tqdm(zip(indices, files), total=n):
  739. if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
  740. with open(path.parent / txt[i], 'a') as f:
  741. f.write('./' + img.relative_to(path.parent).as_posix() + '\n') # add image to txt file
  742. def verify_image_label(args):
  743. # Verify one image-label pair
  744. im_file, lb_file, prefix = args
  745. nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, '', [] # number (missing, found, empty, corrupt), message, segments
  746. try:
  747. # verify images
  748. im = Image.open(im_file)
  749. im.verify() # PIL verify
  750. shape = exif_size(im) # image size
  751. assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
  752. assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
  753. if im.format.lower() in ('jpg', 'jpeg'):
  754. with open(im_file, 'rb') as f:
  755. f.seek(-2, 2)
  756. if f.read() != b'\xff\xd9': # corrupt JPEG
  757. ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
  758. msg = f'{prefix}WARNING: {im_file}: corrupt JPEG restored and saved'
  759. # verify labels
  760. if os.path.isfile(lb_file):
  761. nf = 1 # label found
  762. with open(lb_file) as f:
  763. l = [x.split() for x in f.read().strip().splitlines() if len(x)]
  764. if any([len(x) > 8 for x in l]): # is segment
  765. classes = np.array([x[0] for x in l], dtype=np.float32)
  766. segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l] # (cls, xy1...)
  767. l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
  768. l = np.array(l, dtype=np.float32)
  769. nl = len(l)
  770. if nl:
  771. assert l.shape[1] == 5, f'labels require 5 columns, {l.shape[1]} columns detected'
  772. assert (l >= 0).all(), f'negative label values {l[l < 0]}'
  773. assert (l[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {l[:, 1:][l[:, 1:] > 1]}'
  774. _, i = np.unique(l, axis=0, return_index=True)
  775. if len(i) < nl: # duplicate row check
  776. l = l[i] # remove duplicates
  777. if segments:
  778. segments = segments[i]
  779. msg = f'{prefix}WARNING: {im_file}: {nl - len(i)} duplicate labels removed'
  780. else:
  781. ne = 1 # label empty
  782. l = np.zeros((0, 5), dtype=np.float32)
  783. else:
  784. nm = 1 # label missing
  785. l = np.zeros((0, 5), dtype=np.float32)
  786. return im_file, l, shape, segments, nm, nf, ne, nc, msg
  787. except Exception as e:
  788. nc = 1
  789. msg = f'{prefix}WARNING: {im_file}: ignoring corrupt image/label: {e}'
  790. return [None, None, None, None, nm, nf, ne, nc, msg]
  791. def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False, profile=False, hub=False):
  792. """ Return dataset statistics dictionary with images and instances counts per split per class
  793. To run in parent directory: export PYTHONPATH="$PWD/yolov5"
  794. Usage1: from utils.datasets import *; dataset_stats('coco128.yaml', autodownload=True)
  795. Usage2: from utils.datasets import *; dataset_stats('../datasets/coco128_with_yaml.zip')
  796. Arguments
  797. path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
  798. autodownload: Attempt to download dataset if not found locally
  799. verbose: Print stats dictionary
  800. """
  801. def round_labels(labels):
  802. # Update labels to integer class and 6 decimal place floats
  803. return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]
  804. def unzip(path):
  805. # Unzip data.zip TODO: CONSTRAINT: path/to/abc.zip MUST unzip to 'path/to/abc/'
  806. if str(path).endswith('.zip'): # path is data.zip
  807. assert Path(path).is_file(), f'Error unzipping {path}, file not found'
  808. ZipFile(path).extractall(path=path.parent) # unzip
  809. dir = path.with_suffix('') # dataset directory == zip name
  810. return True, str(dir), next(dir.rglob('*.yaml')) # zipped, data_dir, yaml_path
  811. else: # path is data.yaml
  812. return False, None, path
  813. def hub_ops(f, max_dim=1920):
  814. # HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
  815. f_new = im_dir / Path(f).name # dataset-hub image filename
  816. try: # use PIL
  817. im = Image.open(f)
  818. r = max_dim / max(im.height, im.width) # ratio
  819. if r < 1.0: # image too large
  820. im = im.resize((int(im.width * r), int(im.height * r)))
  821. im.save(f_new, quality=75) # save
  822. except Exception as e: # use OpenCV
  823. print(f'WARNING: HUB ops PIL failure {f}: {e}')
  824. im = cv2.imread(f)
  825. im_height, im_width = im.shape[:2]
  826. r = max_dim / max(im_height, im_width) # ratio
  827. if r < 1.0: # image too large
  828. im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_LINEAR)
  829. cv2.imwrite(str(f_new), im)
  830. zipped, data_dir, yaml_path = unzip(Path(path))
  831. with open(check_yaml(yaml_path), errors='ignore') as f:
  832. data = yaml.safe_load(f) # data dict
  833. if zipped:
  834. data['path'] = data_dir # TODO: should this be dir.resolve()?
  835. check_dataset(data, autodownload) # download dataset if missing
  836. hub_dir = Path(data['path'] + ('-hub' if hub else ''))
  837. stats = {'nc': data['nc'], 'names': data['names']} # statistics dictionary
  838. for split in 'train', 'val', 'test':
  839. if data.get(split) is None:
  840. stats[split] = None # i.e. no test set
  841. continue
  842. x = []
  843. dataset = LoadImagesAndLabels(data[split]) # load dataset
  844. for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
  845. x.append(np.bincount(label[:, 0].astype(int), minlength=data['nc']))
  846. x = np.array(x) # shape(128x80)
  847. stats[split] = {'instance_stats': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
  848. 'image_stats': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
  849. 'per_class': (x > 0).sum(0).tolist()},
  850. 'labels': [{str(Path(k).name): round_labels(v.tolist())} for k, v in
  851. zip(dataset.img_files, dataset.labels)]}
  852. if hub:
  853. im_dir = hub_dir / 'images'
  854. im_dir.mkdir(parents=True, exist_ok=True)
  855. for _ in tqdm(ThreadPool(NUM_THREADS).imap(hub_ops, dataset.img_files), total=dataset.n, desc='HUB Ops'):
  856. pass
  857. # Profile
  858. stats_path = hub_dir / 'stats.json'
  859. if profile:
  860. for _ in range(1):
  861. file = stats_path.with_suffix('.npy')
  862. t1 = time.time()
  863. np.save(file, stats)
  864. t2 = time.time()
  865. x = np.load(file, allow_pickle=True)
  866. print(f'stats.npy times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')
  867. file = stats_path.with_suffix('.json')
  868. t1 = time.time()
  869. with open(file, 'w') as f:
  870. json.dump(stats, f) # save stats *.json
  871. t2 = time.time()
  872. with open(file) as f:
  873. x = json.load(f) # load hyps dict
  874. print(f'stats.json times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')
  875. # Save, print and return
  876. if hub:
  877. print(f'Saving {stats_path.resolve()}...')
  878. with open(stats_path, 'w') as f:
  879. json.dump(stats, f) # save stats.json
  880. if verbose:
  881. print(json.dumps(stats, indent=2, sort_keys=False))
  882. return stats