You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

945 lines
39KB

  1. # YOLOv5 dataset utils and dataloaders
  2. import glob
  3. import hashlib
  4. import json
  5. import logging
  6. import os
  7. import random
  8. import shutil
  9. import time
  10. from itertools import repeat
  11. from multiprocessing.pool import ThreadPool, Pool
  12. from pathlib import Path
  13. from threading import Thread
  14. import cv2
  15. import numpy as np
  16. import torch
  17. import torch.nn.functional as F
  18. import yaml
  19. from PIL import Image, ExifTags
  20. from torch.utils.data import Dataset
  21. from tqdm import tqdm
  22. from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
  23. from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
  24. xyn2xy, segments2boxes, clean_str
  25. from utils.torch_utils import torch_distributed_zero_first
  26. # Parameters
  27. help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
  28. img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo'] # acceptable image suffixes
  29. vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv'] # acceptable video suffixes
  30. num_threads = min(8, os.cpu_count()) # number of multiprocessing threads
  31. logger = logging.getLogger(__name__)
  32. # Get orientation exif tag
  33. for orientation in ExifTags.TAGS.keys():
  34. if ExifTags.TAGS[orientation] == 'Orientation':
  35. break
  36. def get_hash(paths):
  37. # Returns a single hash value of a list of paths (files or dirs)
  38. size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
  39. h = hashlib.md5(str(size).encode()) # hash sizes
  40. h.update(''.join(paths).encode()) # hash paths
  41. return h.hexdigest() # return hash
  42. def exif_size(img):
  43. # Returns exif-corrected PIL size
  44. s = img.size # (width, height)
  45. try:
  46. rotation = dict(img._getexif().items())[orientation]
  47. if rotation == 6: # rotation 270
  48. s = (s[1], s[0])
  49. elif rotation == 8: # rotation 90
  50. s = (s[1], s[0])
  51. except:
  52. pass
  53. return s
  54. def exif_transpose(image):
  55. """
  56. Transpose a PIL image accordingly if it has an EXIF Orientation tag.
  57. From https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py
  58. :param image: The image to transpose.
  59. :return: An image.
  60. """
  61. exif = image.getexif()
  62. orientation = exif.get(0x0112, 1) # default 1
  63. if orientation > 1:
  64. method = {2: Image.FLIP_LEFT_RIGHT,
  65. 3: Image.ROTATE_180,
  66. 4: Image.FLIP_TOP_BOTTOM,
  67. 5: Image.TRANSPOSE,
  68. 6: Image.ROTATE_270,
  69. 7: Image.TRANSVERSE,
  70. 8: Image.ROTATE_90,
  71. }.get(orientation)
  72. if method is not None:
  73. image = image.transpose(method)
  74. del exif[0x0112]
  75. image.info["exif"] = exif.tobytes()
  76. return image
  77. def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
  78. rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
  79. # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
  80. with torch_distributed_zero_first(rank):
  81. dataset = LoadImagesAndLabels(path, imgsz, batch_size,
  82. augment=augment, # augment images
  83. hyp=hyp, # augmentation hyperparameters
  84. rect=rect, # rectangular training
  85. cache_images=cache,
  86. single_cls=single_cls,
  87. stride=int(stride),
  88. pad=pad,
  89. image_weights=image_weights,
  90. prefix=prefix)
  91. batch_size = min(batch_size, len(dataset))
  92. nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, workers]) # number of workers
  93. sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
  94. loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
  95. # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
  96. dataloader = loader(dataset,
  97. batch_size=batch_size,
  98. num_workers=nw,
  99. sampler=sampler,
  100. pin_memory=True,
  101. collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
  102. return dataloader, dataset
  103. class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
  104. """ Dataloader that reuses workers
  105. Uses same syntax as vanilla DataLoader
  106. """
  107. def __init__(self, *args, **kwargs):
  108. super().__init__(*args, **kwargs)
  109. object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
  110. self.iterator = super().__iter__()
  111. def __len__(self):
  112. return len(self.batch_sampler.sampler)
  113. def __iter__(self):
  114. for i in range(len(self)):
  115. yield next(self.iterator)
  116. class _RepeatSampler(object):
  117. """ Sampler that repeats forever
  118. Args:
  119. sampler (Sampler)
  120. """
  121. def __init__(self, sampler):
  122. self.sampler = sampler
  123. def __iter__(self):
  124. while True:
  125. yield from iter(self.sampler)
  126. class LoadImages: # for inference
  127. def __init__(self, path, img_size=640, stride=32):
  128. p = str(Path(path).absolute()) # os-agnostic absolute path
  129. if '*' in p:
  130. files = sorted(glob.glob(p, recursive=True)) # glob
  131. elif os.path.isdir(p):
  132. files = sorted(glob.glob(os.path.join(p, '*.*'))) # dir
  133. elif os.path.isfile(p):
  134. files = [p] # files
  135. else:
  136. raise Exception(f'ERROR: {p} does not exist')
  137. images = [x for x in files if x.split('.')[-1].lower() in img_formats]
  138. videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
  139. ni, nv = len(images), len(videos)
  140. self.img_size = img_size
  141. self.stride = stride
  142. self.files = images + videos
  143. self.nf = ni + nv # number of files
  144. self.video_flag = [False] * ni + [True] * nv
  145. self.mode = 'image'
  146. if any(videos):
  147. self.new_video(videos[0]) # new video
  148. else:
  149. self.cap = None
  150. assert self.nf > 0, f'No images or videos found in {p}. ' \
  151. f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}'
  152. def __iter__(self):
  153. self.count = 0
  154. return self
  155. def __next__(self):
  156. if self.count == self.nf:
  157. raise StopIteration
  158. path = self.files[self.count]
  159. if self.video_flag[self.count]:
  160. # Read video
  161. self.mode = 'video'
  162. ret_val, img0 = self.cap.read()
  163. if not ret_val:
  164. self.count += 1
  165. self.cap.release()
  166. if self.count == self.nf: # last video
  167. raise StopIteration
  168. else:
  169. path = self.files[self.count]
  170. self.new_video(path)
  171. ret_val, img0 = self.cap.read()
  172. self.frame += 1
  173. print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: ', end='')
  174. else:
  175. # Read image
  176. self.count += 1
  177. img0 = cv2.imread(path) # BGR
  178. assert img0 is not None, 'Image Not Found ' + path
  179. print(f'image {self.count}/{self.nf} {path}: ', end='')
  180. # Padded resize
  181. img = letterbox(img0, self.img_size, stride=self.stride)[0]
  182. # Convert
  183. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  184. img = np.ascontiguousarray(img)
  185. return path, img, img0, self.cap
  186. def new_video(self, path):
  187. self.frame = 0
  188. self.cap = cv2.VideoCapture(path)
  189. self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
  190. def __len__(self):
  191. return self.nf # number of files
  192. class LoadWebcam: # for inference
  193. def __init__(self, pipe='0', img_size=640, stride=32):
  194. self.img_size = img_size
  195. self.stride = stride
  196. self.pipe = eval(pipe) if pipe.isnumeric() else pipe
  197. self.cap = cv2.VideoCapture(self.pipe) # video capture object
  198. self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3) # set buffer size
  199. def __iter__(self):
  200. self.count = -1
  201. return self
  202. def __next__(self):
  203. self.count += 1
  204. if cv2.waitKey(1) == ord('q'): # q to quit
  205. self.cap.release()
  206. cv2.destroyAllWindows()
  207. raise StopIteration
  208. # Read frame
  209. ret_val, img0 = self.cap.read()
  210. img0 = cv2.flip(img0, 1) # flip left-right
  211. # Print
  212. assert ret_val, f'Camera Error {self.pipe}'
  213. img_path = 'webcam.jpg'
  214. print(f'webcam {self.count}: ', end='')
  215. # Padded resize
  216. img = letterbox(img0, self.img_size, stride=self.stride)[0]
  217. # Convert
  218. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  219. img = np.ascontiguousarray(img)
  220. return img_path, img, img0, None
  221. def __len__(self):
  222. return 0
  223. class LoadStreams: # multiple IP or RTSP cameras
  224. def __init__(self, sources='streams.txt', img_size=640, stride=32):
  225. self.mode = 'stream'
  226. self.img_size = img_size
  227. self.stride = stride
  228. if os.path.isfile(sources):
  229. with open(sources, 'r') as f:
  230. sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
  231. else:
  232. sources = [sources]
  233. n = len(sources)
  234. self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
  235. self.sources = [clean_str(x) for x in sources] # clean source names for later
  236. for i, s in enumerate(sources): # index, source
  237. # Start thread to read frames from video stream
  238. print(f'{i + 1}/{n}: {s}... ', end='')
  239. if 'youtube.com/' in s or 'youtu.be/' in s: # if source is YouTube video
  240. check_requirements(('pafy', 'youtube_dl'))
  241. import pafy
  242. s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
  243. s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
  244. cap = cv2.VideoCapture(s)
  245. assert cap.isOpened(), f'Failed to open {s}'
  246. w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  247. h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  248. self.fps[i] = max(cap.get(cv2.CAP_PROP_FPS) % 100, 0) or 30.0 # 30 FPS fallback
  249. self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
  250. _, self.imgs[i] = cap.read() # guarantee first frame
  251. self.threads[i] = Thread(target=self.update, args=([i, cap]), daemon=True)
  252. print(f" success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
  253. self.threads[i].start()
  254. print('') # newline
  255. # check for common shapes
  256. s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0) # shapes
  257. self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
  258. if not self.rect:
  259. print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')
  260. def update(self, i, cap):
  261. # Read stream `i` frames in daemon thread
  262. n, f, read = 0, self.frames[i], 1 # frame number, frame array, inference every 'read' frame
  263. while cap.isOpened() and n < f:
  264. n += 1
  265. # _, self.imgs[index] = cap.read()
  266. cap.grab()
  267. if n % read == 0:
  268. success, im = cap.retrieve()
  269. self.imgs[i] = im if success else self.imgs[i] * 0
  270. time.sleep(1 / self.fps[i]) # wait time
  271. def __iter__(self):
  272. self.count = -1
  273. return self
  274. def __next__(self):
  275. self.count += 1
  276. if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
  277. cv2.destroyAllWindows()
  278. raise StopIteration
  279. # Letterbox
  280. img0 = self.imgs.copy()
  281. img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]
  282. # Stack
  283. img = np.stack(img, 0)
  284. # Convert
  285. img = img[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
  286. img = np.ascontiguousarray(img)
  287. return self.sources, img, img0, None
  288. def __len__(self):
  289. return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
  290. def img2label_paths(img_paths):
  291. # Define label paths as a function of image paths
  292. sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep # /images/, /labels/ substrings
  293. return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
  294. class LoadImagesAndLabels(Dataset): # for training/testing
  295. def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
  296. cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
  297. self.img_size = img_size
  298. self.augment = augment
  299. self.hyp = hyp
  300. self.image_weights = image_weights
  301. self.rect = False if image_weights else rect
  302. self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
  303. self.mosaic_border = [-img_size // 2, -img_size // 2]
  304. self.stride = stride
  305. self.path = path
  306. self.albumentations = Albumentations() if augment else None
  307. try:
  308. f = [] # image files
  309. for p in path if isinstance(path, list) else [path]:
  310. p = Path(p) # os-agnostic
  311. if p.is_dir(): # dir
  312. f += glob.glob(str(p / '**' / '*.*'), recursive=True)
  313. # f = list(p.rglob('**/*.*')) # pathlib
  314. elif p.is_file(): # file
  315. with open(p, 'r') as t:
  316. t = t.read().strip().splitlines()
  317. parent = str(p.parent) + os.sep
  318. f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
  319. # f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
  320. else:
  321. raise Exception(f'{prefix}{p} does not exist')
  322. self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
  323. # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats]) # pathlib
  324. assert self.img_files, f'{prefix}No images found'
  325. except Exception as e:
  326. raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')
  327. # Check cache
  328. self.label_files = img2label_paths(self.img_files) # labels
  329. cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache') # cached labels
  330. if cache_path.is_file():
  331. cache, exists = torch.load(cache_path), True # load
  332. if cache.get('version') != 0.3 or cache.get('hash') != get_hash(self.label_files + self.img_files):
  333. cache, exists = self.cache_labels(cache_path, prefix), False # re-cache
  334. else:
  335. cache, exists = self.cache_labels(cache_path, prefix), False # cache
  336. # Display cache
  337. nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupted, total
  338. if exists:
  339. d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
  340. tqdm(None, desc=prefix + d, total=n, initial=n) # display cache results
  341. if cache['msgs']:
  342. logging.info('\n'.join(cache['msgs'])) # display warnings
  343. assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'
  344. # Read cache
  345. [cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
  346. labels, shapes, self.segments = zip(*cache.values())
  347. self.labels = list(labels)
  348. self.shapes = np.array(shapes, dtype=np.float64)
  349. self.img_files = list(cache.keys()) # update
  350. self.label_files = img2label_paths(cache.keys()) # update
  351. if single_cls:
  352. for x in self.labels:
  353. x[:, 0] = 0
  354. n = len(shapes) # number of images
  355. bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
  356. nb = bi[-1] + 1 # number of batches
  357. self.batch = bi # batch index of image
  358. self.n = n
  359. self.indices = range(n)
  360. # Rectangular Training
  361. if self.rect:
  362. # Sort by aspect ratio
  363. s = self.shapes # wh
  364. ar = s[:, 1] / s[:, 0] # aspect ratio
  365. irect = ar.argsort()
  366. self.img_files = [self.img_files[i] for i in irect]
  367. self.label_files = [self.label_files[i] for i in irect]
  368. self.labels = [self.labels[i] for i in irect]
  369. self.shapes = s[irect] # wh
  370. ar = ar[irect]
  371. # Set training image shapes
  372. shapes = [[1, 1]] * nb
  373. for i in range(nb):
  374. ari = ar[bi == i]
  375. mini, maxi = ari.min(), ari.max()
  376. if maxi < 1:
  377. shapes[i] = [maxi, 1]
  378. elif mini > 1:
  379. shapes[i] = [1, 1 / mini]
  380. self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
  381. # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
  382. self.imgs = [None] * n
  383. if cache_images:
  384. gb = 0 # Gigabytes of cached images
  385. self.img_hw0, self.img_hw = [None] * n, [None] * n
  386. results = ThreadPool(num_threads).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
  387. pbar = tqdm(enumerate(results), total=n)
  388. for i, x in pbar:
  389. self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # img, hw_original, hw_resized = load_image(self, i)
  390. gb += self.imgs[i].nbytes
  391. pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
  392. pbar.close()
  393. def cache_labels(self, path=Path('./labels.cache'), prefix=''):
  394. # Cache dataset labels, check images and read shapes
  395. x = {} # dict
  396. nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
  397. desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
  398. with Pool(num_threads) as pool:
  399. pbar = tqdm(pool.imap_unordered(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))),
  400. desc=desc, total=len(self.img_files))
  401. for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
  402. nm += nm_f
  403. nf += nf_f
  404. ne += ne_f
  405. nc += nc_f
  406. if im_file:
  407. x[im_file] = [l, shape, segments]
  408. if msg:
  409. msgs.append(msg)
  410. pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
  411. pbar.close()
  412. if msgs:
  413. logging.info('\n'.join(msgs))
  414. if nf == 0:
  415. logging.info(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
  416. x['hash'] = get_hash(self.label_files + self.img_files)
  417. x['results'] = nf, nm, ne, nc, len(self.img_files)
  418. x['msgs'] = msgs # warnings
  419. x['version'] = 0.3 # cache version
  420. try:
  421. torch.save(x, path) # save cache for next time
  422. logging.info(f'{prefix}New cache created: {path}')
  423. except Exception as e:
  424. logging.info(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}') # path not writeable
  425. return x
  426. def __len__(self):
  427. return len(self.img_files)
  428. # def __iter__(self):
  429. # self.count = -1
  430. # print('ran dataset iter')
  431. # #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
  432. # return self
  433. def __getitem__(self, index):
  434. index = self.indices[index] # linear, shuffled, or image_weights
  435. hyp = self.hyp
  436. mosaic = self.mosaic and random.random() < hyp['mosaic']
  437. if mosaic:
  438. # Load mosaic
  439. img, labels = load_mosaic(self, index)
  440. shapes = None
  441. # MixUp augmentation
  442. if random.random() < hyp['mixup']:
  443. img, labels = mixup(img, labels, *load_mosaic(self, random.randint(0, self.n - 1)))
  444. else:
  445. # Load image
  446. img, (h0, w0), (h, w) = load_image(self, index)
  447. # Letterbox
  448. shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
  449. img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
  450. shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
  451. labels = self.labels[index].copy()
  452. if labels.size: # normalized xywh to pixel xyxy format
  453. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
  454. if self.augment:
  455. img, labels = random_perspective(img, labels,
  456. degrees=hyp['degrees'],
  457. translate=hyp['translate'],
  458. scale=hyp['scale'],
  459. shear=hyp['shear'],
  460. perspective=hyp['perspective'])
  461. nl = len(labels) # number of labels
  462. if nl:
  463. labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0]) # xyxy to xywh normalized
  464. if self.augment:
  465. # Albumentations
  466. img, labels = self.albumentations(img, labels)
  467. # HSV color-space
  468. augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
  469. # Flip up-down
  470. if random.random() < hyp['flipud']:
  471. img = np.flipud(img)
  472. if nl:
  473. labels[:, 2] = 1 - labels[:, 2]
  474. # Flip left-right
  475. if random.random() < hyp['fliplr']:
  476. img = np.fliplr(img)
  477. if nl:
  478. labels[:, 1] = 1 - labels[:, 1]
  479. # Cutouts
  480. # if random.random() < 0.9:
  481. # labels = cutout(img, labels)
  482. labels_out = torch.zeros((nl, 6))
  483. if nl:
  484. labels_out[:, 1:] = torch.from_numpy(labels)
  485. # Convert
  486. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  487. img = np.ascontiguousarray(img)
  488. return torch.from_numpy(img), labels_out, self.img_files[index], shapes
  489. @staticmethod
  490. def collate_fn(batch):
  491. img, label, path, shapes = zip(*batch) # transposed
  492. for i, l in enumerate(label):
  493. l[:, 0] = i # add target image index for build_targets()
  494. return torch.stack(img, 0), torch.cat(label, 0), path, shapes
  495. @staticmethod
  496. def collate_fn4(batch):
  497. img, label, path, shapes = zip(*batch) # transposed
  498. n = len(shapes) // 4
  499. img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
  500. ho = torch.tensor([[0., 0, 0, 1, 0, 0]])
  501. wo = torch.tensor([[0., 0, 1, 0, 0, 0]])
  502. s = torch.tensor([[1, 1, .5, .5, .5, .5]]) # scale
  503. for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
  504. i *= 4
  505. if random.random() < 0.5:
  506. im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[
  507. 0].type(img[i].type())
  508. l = label[i]
  509. else:
  510. im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
  511. l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
  512. img4.append(im)
  513. label4.append(l)
  514. for i, l in enumerate(label4):
  515. l[:, 0] = i # add target image index for build_targets()
  516. return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4
  517. # Ancillary functions --------------------------------------------------------------------------------------------------
  518. def load_image(self, index):
  519. # loads 1 image from dataset, returns img, original hw, resized hw
  520. img = self.imgs[index]
  521. if img is None: # not cached
  522. path = self.img_files[index]
  523. img = cv2.imread(path) # BGR
  524. assert img is not None, 'Image Not Found ' + path
  525. h0, w0 = img.shape[:2] # orig hw
  526. r = self.img_size / max(h0, w0) # ratio
  527. if r != 1: # if sizes are not equal
  528. img = cv2.resize(img, (int(w0 * r), int(h0 * r)),
  529. interpolation=cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR)
  530. return img, (h0, w0), img.shape[:2] # img, hw_original, hw_resized
  531. else:
  532. return self.imgs[index], self.img_hw0[index], self.img_hw[index] # img, hw_original, hw_resized
  533. def load_mosaic(self, index):
  534. # loads images in a 4-mosaic
  535. labels4, segments4 = [], []
  536. s = self.img_size
  537. yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border] # mosaic center x, y
  538. indices = [index] + random.choices(self.indices, k=3) # 3 additional image indices
  539. for i, index in enumerate(indices):
  540. # Load image
  541. img, _, (h, w) = load_image(self, index)
  542. # place img in img4
  543. if i == 0: # top left
  544. img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
  545. x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
  546. x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
  547. elif i == 1: # top right
  548. x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
  549. x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
  550. elif i == 2: # bottom left
  551. x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
  552. x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
  553. elif i == 3: # bottom right
  554. x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
  555. x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
  556. img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
  557. padw = x1a - x1b
  558. padh = y1a - y1b
  559. # Labels
  560. labels, segments = self.labels[index].copy(), self.segments[index].copy()
  561. if labels.size:
  562. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh) # normalized xywh to pixel xyxy format
  563. segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
  564. labels4.append(labels)
  565. segments4.extend(segments)
  566. # Concat/clip labels
  567. labels4 = np.concatenate(labels4, 0)
  568. for x in (labels4[:, 1:], *segments4):
  569. np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
  570. # img4, labels4 = replicate(img4, labels4) # replicate
  571. # Augment
  572. img4, labels4, segments4 = copy_paste(img4, labels4, segments4, probability=self.hyp['copy_paste'])
  573. img4, labels4 = random_perspective(img4, labels4, segments4,
  574. degrees=self.hyp['degrees'],
  575. translate=self.hyp['translate'],
  576. scale=self.hyp['scale'],
  577. shear=self.hyp['shear'],
  578. perspective=self.hyp['perspective'],
  579. border=self.mosaic_border) # border to remove
  580. return img4, labels4
  581. def load_mosaic9(self, index):
  582. # loads images in a 9-mosaic
  583. labels9, segments9 = [], []
  584. s = self.img_size
  585. indices = [index] + random.choices(self.indices, k=8) # 8 additional image indices
  586. for i, index in enumerate(indices):
  587. # Load image
  588. img, _, (h, w) = load_image(self, index)
  589. # place img in img9
  590. if i == 0: # center
  591. img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
  592. h0, w0 = h, w
  593. c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates
  594. elif i == 1: # top
  595. c = s, s - h, s + w, s
  596. elif i == 2: # top right
  597. c = s + wp, s - h, s + wp + w, s
  598. elif i == 3: # right
  599. c = s + w0, s, s + w0 + w, s + h
  600. elif i == 4: # bottom right
  601. c = s + w0, s + hp, s + w0 + w, s + hp + h
  602. elif i == 5: # bottom
  603. c = s + w0 - w, s + h0, s + w0, s + h0 + h
  604. elif i == 6: # bottom left
  605. c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
  606. elif i == 7: # left
  607. c = s - w, s + h0 - h, s, s + h0
  608. elif i == 8: # top left
  609. c = s - w, s + h0 - hp - h, s, s + h0 - hp
  610. padx, pady = c[:2]
  611. x1, y1, x2, y2 = [max(x, 0) for x in c] # allocate coords
  612. # Labels
  613. labels, segments = self.labels[index].copy(), self.segments[index].copy()
  614. if labels.size:
  615. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady) # normalized xywh to pixel xyxy format
  616. segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
  617. labels9.append(labels)
  618. segments9.extend(segments)
  619. # Image
  620. img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:] # img9[ymin:ymax, xmin:xmax]
  621. hp, wp = h, w # height, width previous
  622. # Offset
  623. yc, xc = [int(random.uniform(0, s)) for _ in self.mosaic_border] # mosaic center x, y
  624. img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]
  625. # Concat/clip labels
  626. labels9 = np.concatenate(labels9, 0)
  627. labels9[:, [1, 3]] -= xc
  628. labels9[:, [2, 4]] -= yc
  629. c = np.array([xc, yc]) # centers
  630. segments9 = [x - c for x in segments9]
  631. for x in (labels9[:, 1:], *segments9):
  632. np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective()
  633. # img9, labels9 = replicate(img9, labels9) # replicate
  634. # Augment
  635. img9, labels9 = random_perspective(img9, labels9, segments9,
  636. degrees=self.hyp['degrees'],
  637. translate=self.hyp['translate'],
  638. scale=self.hyp['scale'],
  639. shear=self.hyp['shear'],
  640. perspective=self.hyp['perspective'],
  641. border=self.mosaic_border) # border to remove
  642. return img9, labels9
  643. def create_folder(path='./new'):
  644. # Create folder
  645. if os.path.exists(path):
  646. shutil.rmtree(path) # delete output folder
  647. os.makedirs(path) # make new output folder
  648. def flatten_recursive(path='../datasets/coco128'):
  649. # Flatten a recursive directory by bringing all files to top level
  650. new_path = Path(path + '_flat')
  651. create_folder(new_path)
  652. for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
  653. shutil.copyfile(file, new_path / Path(file).name)
  654. def extract_boxes(path='../datasets/coco128'): # from utils.datasets import *; extract_boxes()
  655. # Convert detection dataset into classification dataset, with one directory per class
  656. path = Path(path) # images dir
  657. shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None # remove existing
  658. files = list(path.rglob('*.*'))
  659. n = len(files) # number of files
  660. for im_file in tqdm(files, total=n):
  661. if im_file.suffix[1:] in img_formats:
  662. # image
  663. im = cv2.imread(str(im_file))[..., ::-1] # BGR to RGB
  664. h, w = im.shape[:2]
  665. # labels
  666. lb_file = Path(img2label_paths([str(im_file)])[0])
  667. if Path(lb_file).exists():
  668. with open(lb_file, 'r') as f:
  669. lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32) # labels
  670. for j, x in enumerate(lb):
  671. c = int(x[0]) # class
  672. f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg' # new filename
  673. if not f.parent.is_dir():
  674. f.parent.mkdir(parents=True)
  675. b = x[1:] * [w, h, w, h] # box
  676. # b[2:] = b[2:].max() # rectangle to square
  677. b[2:] = b[2:] * 1.2 + 3 # pad
  678. b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(np.int)
  679. b[[0, 2]] = np.clip(b[[0, 2]], 0, w) # clip boxes outside of image
  680. b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
  681. assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'
  682. def autosplit(path='../datasets/coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
  683. """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
  684. Usage: from utils.datasets import *; autosplit()
  685. Arguments
  686. path: Path to images directory
  687. weights: Train, val, test weights (list, tuple)
  688. annotated_only: Only use images with an annotated txt file
  689. """
  690. path = Path(path) # images dir
  691. files = sum([list(path.rglob(f"*.{img_ext}")) for img_ext in img_formats], []) # image files only
  692. n = len(files) # number of files
  693. random.seed(0) # for reproducibility
  694. indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
  695. txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
  696. [(path.parent / x).unlink(missing_ok=True) for x in txt] # remove existing
  697. print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
  698. for i, img in tqdm(zip(indices, files), total=n):
  699. if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
  700. with open(path.parent / txt[i], 'a') as f:
  701. f.write('./' + img.relative_to(path.parent).as_posix() + '\n') # add image to txt file
  702. def verify_image_label(args):
  703. # Verify one image-label pair
  704. im_file, lb_file, prefix = args
  705. nm, nf, ne, nc = 0, 0, 0, 0 # number missing, found, empty, corrupt
  706. try:
  707. # verify images
  708. im = Image.open(im_file)
  709. im.verify() # PIL verify
  710. shape = exif_size(im) # image size
  711. assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
  712. assert im.format.lower() in img_formats, f'invalid image format {im.format}'
  713. if im.format.lower() in ('jpg', 'jpeg'):
  714. with open(im_file, 'rb') as f:
  715. f.seek(-2, 2)
  716. assert f.read() == b'\xff\xd9', 'corrupted JPEG'
  717. # verify labels
  718. segments = [] # instance segments
  719. if os.path.isfile(lb_file):
  720. nf = 1 # label found
  721. with open(lb_file, 'r') as f:
  722. l = [x.split() for x in f.read().strip().splitlines() if len(x)]
  723. if any([len(x) > 8 for x in l]): # is segment
  724. classes = np.array([x[0] for x in l], dtype=np.float32)
  725. segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l] # (cls, xy1...)
  726. l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
  727. l = np.array(l, dtype=np.float32)
  728. if len(l):
  729. assert l.shape[1] == 5, 'labels require 5 columns each'
  730. assert (l >= 0).all(), 'negative labels'
  731. assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
  732. assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
  733. else:
  734. ne = 1 # label empty
  735. l = np.zeros((0, 5), dtype=np.float32)
  736. else:
  737. nm = 1 # label missing
  738. l = np.zeros((0, 5), dtype=np.float32)
  739. return im_file, l, shape, segments, nm, nf, ne, nc, ''
  740. except Exception as e:
  741. nc = 1
  742. msg = f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}'
  743. return [None, None, None, None, nm, nf, ne, nc, msg]
  744. def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False):
  745. """ Return dataset statistics dictionary with images and instances counts per split per class
  746. Usage1: from utils.datasets import *; dataset_stats('coco128.yaml', verbose=True)
  747. Usage2: from utils.datasets import *; dataset_stats('../datasets/coco128.zip', verbose=True)
  748. Arguments
  749. path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
  750. autodownload: Attempt to download dataset if not found locally
  751. verbose: Print stats dictionary
  752. """
  753. def round_labels(labels):
  754. # Update labels to integer class and 6 decimal place floats
  755. return [[int(c), *[round(x, 6) for x in points]] for c, *points in labels]
  756. def unzip(path):
  757. # Unzip data.zip TODO: CONSTRAINT: path/to/abc.zip MUST unzip to 'path/to/abc/'
  758. if str(path).endswith('.zip'): # path is data.zip
  759. assert os.system(f'unzip -q {path} -d {path.parent}') == 0, f'Error unzipping {path}'
  760. data_dir = path.with_suffix('') # dataset directory
  761. return True, data_dir, list(data_dir.rglob('*.yaml'))[0] # zipped, data_dir, yaml_path
  762. else: # path is data.yaml
  763. return False, None, path
  764. zipped, data_dir, yaml_path = unzip(Path(path))
  765. with open(check_file(yaml_path)) as f:
  766. data = yaml.safe_load(f) # data dict
  767. if zipped:
  768. data['path'] = data_dir # TODO: should this be dir.resolve()?
  769. check_dataset(data, autodownload) # download dataset if missing
  770. nc = data['nc'] # number of classes
  771. stats = {'nc': nc, 'names': data['names']} # statistics dictionary
  772. for split in 'train', 'val', 'test':
  773. if data.get(split) is None:
  774. stats[split] = None # i.e. no test set
  775. continue
  776. x = []
  777. dataset = LoadImagesAndLabels(data[split], augment=False, rect=True) # load dataset
  778. if split == 'train':
  779. cache_path = Path(dataset.label_files[0]).parent.with_suffix('.cache') # *.cache path
  780. for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
  781. x.append(np.bincount(label[:, 0].astype(int), minlength=nc))
  782. x = np.array(x) # shape(128x80)
  783. stats[split] = {'instance_stats': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
  784. 'image_stats': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
  785. 'per_class': (x > 0).sum(0).tolist()},
  786. 'labels': [{str(Path(k).name): round_labels(v.tolist())} for k, v in
  787. zip(dataset.img_files, dataset.labels)]}
  788. # Save, print and return
  789. with open(cache_path.with_suffix('.json'), 'w') as f:
  790. json.dump(stats, f) # save stats *.json
  791. if verbose:
  792. print(json.dumps(stats, indent=2, sort_keys=False))
  793. # print(yaml.dump([stats], sort_keys=False, default_flow_style=False))
  794. return stats