# Dataset utils and dataloaders

import glob
import math
import os
import random
import shutil
import time
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from threading import Thread

import cv2
import numpy as np
import torch
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm

from utils.general import xyxy2xywh, xywh2xyxy
from utils.torch_utils import torch_distributed_zero_first

# Parameters
help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng']  # acceptable image suffixes
vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']  # acceptable video suffixes

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break


def get_hash(files):
    # Returns a single hash value of a list of files
    return sum(os.path.getsize(f) for f in files if os.path.isfile(f))


def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    try:
        rotation = dict(img._getexif().items())[orientation]
        if rotation == 6:  # rotation 270
            s = (s[1], s[0])
        elif rotation == 8:  # rotation 90
            s = (s[1], s[0])
    except Exception:  # image has no (or unreadable) EXIF data
        pass

    return s


def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
                      rank=-1, world_size=1, workers=8):
    # Make sure only the first process in DDP scans the dataset first, so the other processes can use its cache
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,  # augmentation hyperparameters
                                      rect=rect,  # rectangular training
                                      cache_images=cache,
                                      single_cls=opt.single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      rank=rank)

    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
    dataloader = InfiniteDataLoader(dataset,
                                    batch_size=batch_size,
                                    num_workers=nw,
                                    sampler=sampler,
                                    pin_memory=True,
                                    collate_fn=LoadImagesAndLabels.collate_fn)  # torch.utils.data.DataLoader()
    return dataloader, dataset
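
# Example (sketch): a minimal single-process create_dataloader call, assuming a
# hypothetical `opt` namespace with a `single_cls` flag, a YOLO-format dataset
# directory, and a `hyp` dict of augmentation hyperparameters (required when
# augment=True):
#
#   from argparse import Namespace
#   opt = Namespace(single_cls=False)
#   dataloader, dataset = create_dataloader('../coco128/images/train2017', imgsz=640,
#                                           batch_size=16, stride=32, opt=opt, hyp=hyp,
#                                           augment=True, rank=-1)
#   for imgs, targets, paths, shapes in dataloader:
#       pass  # imgs: uint8 tensor (16, 3, h, w); targets: (n, 6) rows of [image_idx, cls, xywh]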


class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
    """ Dataloader that reuses workers

    Uses same syntax as vanilla DataLoader
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """ Sampler that repeats forever

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)


class LoadImages:  # for inference
    def __init__(self, path, img_size=640):
        p = str(Path(path))  # os-agnostic
        p = os.path.abspath(p)  # absolute path
        if '*' in p:
            files = sorted(glob.glob(p, recursive=True))  # glob
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
        elif os.path.isfile(p):
            files = [p]  # files
        else:
            raise Exception('ERROR: %s does not exist' % p)

        images = [x for x in files if x.split('.')[-1].lower() in img_formats]
        videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'images'
        if any(videos):
            self.new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, 'No images or videos found in %s. Supported formats are:\nimages: %s\nvideos: %s' % \
                            (p, img_formats, vid_formats)

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            if not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()

            self.frame += 1
            print('video %g/%g (%g/%g) %s: ' % (self.count + 1, self.nf, self.frame, self.nframes, path), end='')

        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path
            print('image %g/%g %s: ' % (self.count, self.nf, path), end='')

        # Padded resize
        img = letterbox(img0, new_shape=self.img_size)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3xHxW
        img = np.ascontiguousarray(img)

        return path, img, img0, self.cap

    def new_video(self, path):
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nf  # number of files
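
# Example (sketch): LoadImages is a plain iterable, so inference loops can treat
# files and videos uniformly. The path below is hypothetical:
#
#   dataset = LoadImages('data/images', img_size=640)
#   for path, img, img0, vid_cap in dataset:
#       img = torch.from_numpy(img).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0
#       if img.ndimension() == 3:
#           img = img.unsqueeze(0)  # add batch dimension
#       # ... run the model on img, then scale detections back to img0 (the original frame)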


class LoadWebcam:  # for inference
    def __init__(self, pipe='0', img_size=640):
        self.img_size = img_size

        if pipe.isnumeric():
            pipe = int(pipe)  # local camera index (int(), not eval(), for safety)
        # pipe = 'rtsp://192.168.1.64/1'  # IP camera
        # pipe = 'rtsp://username:password@192.168.1.64/1'  # IP camera with login
        # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg'  # IP golf camera

        self.pipe = pipe
        self.cap = cv2.VideoCapture(pipe)  # video capture object
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)  # set buffer size

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if cv2.waitKey(1) == ord('q'):  # q to quit
            self.cap.release()
            cv2.destroyAllWindows()
            raise StopIteration

        # Read frame
        if self.pipe == 0:  # local camera
            ret_val, img0 = self.cap.read()
            img0 = cv2.flip(img0, 1)  # flip left-right
        else:  # IP camera
            n = 0
            while True:
                n += 1
                self.cap.grab()
                if n % 30 == 0:  # skip frames
                    ret_val, img0 = self.cap.retrieve()
                    if ret_val:
                        break

        # Print
        assert ret_val, 'Camera Error %s' % self.pipe
        img_path = 'webcam.jpg'
        print('webcam %g: ' % self.count, end='')

        # Padded resize
        img = letterbox(img0, new_shape=self.img_size)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3xHxW
        img = np.ascontiguousarray(img)

        return img_path, img, img0, None

    def __len__(self):
        return 0


class LoadStreams:  # multiple IP or RTSP cameras
    def __init__(self, sources='streams.txt', img_size=640):
        self.mode = 'images'
        self.img_size = img_size

        if os.path.isfile(sources):
            with open(sources, 'r') as f:
                sources = [x.strip() for x in f.read().splitlines() if len(x.strip())]
        else:
            sources = [sources]

        n = len(sources)
        self.imgs = [None] * n
        self.sources = sources
        for i, s in enumerate(sources):
            # Start the thread to read frames from the video stream
            print('%g/%g: %s... ' % (i + 1, n, s), end='')
            cap = cv2.VideoCapture(int(s) if s.isnumeric() else s)  # int(), not eval(), for safety
            assert cap.isOpened(), 'Failed to open %s' % s
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS) % 100
            _, self.imgs[i] = cap.read()  # guarantee first frame
            thread = Thread(target=self.update, args=([i, cap]), daemon=True)
            print(' success (%gx%g at %.2f FPS).' % (w, h, fps))
            thread.start()
        print('')  # newline

        # check for common shapes
        s = np.stack([letterbox(x, new_shape=self.img_size)[0].shape for x in self.imgs], 0)  # inference shapes
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        if not self.rect:
            print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')

    def update(self, index, cap):
        # Read next stream frame in a daemon thread
        n = 0
        while cap.isOpened():
            n += 1
            # _, self.imgs[index] = cap.read()
            cap.grab()
            if n == 4:  # read every 4th frame
                _, self.imgs[index] = cap.retrieve()
                n = 0
            time.sleep(0.01)  # wait time

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        img0 = self.imgs.copy()
        if cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        # Letterbox
        img = [letterbox(x, new_shape=self.img_size, auto=self.rect)[0] for x in img0]

        # Stack
        img = np.stack(img, 0)

        # Convert
        img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, to BSx3xHxW
        img = np.ascontiguousarray(img)

        return self.sources, img, img0, None

    def __len__(self):
        return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years
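
# Example (sketch): LoadStreams takes either a single source string or a text
# file listing one source per line, e.g. a hypothetical streams.txt containing:
#
#   rtsp://192.168.1.64/1
#   http://wmccpinetop.axiscam.net/mjpg/video.mjpg
#
#   dataset = LoadStreams('streams.txt', img_size=640)
#   for sources, img, img0, _ in dataset:  # img: BSx3xHxW batch, one image per stream
#       ...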


class LoadImagesAndLabels(Dataset):  # for training/testing
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1):
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride

        def img2label_paths(img_paths):
            # Define label paths as a function of image paths
            sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
            return [x.replace(sa, sb, 1).replace('.' + x.split('.')[-1], '.txt') for x in img_paths]

        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                elif p.is_file():  # file
                    with open(p, 'r') as t:
                        t = t.read().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                else:
                    raise Exception('%s does not exist' % p)
            self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
            assert self.img_files, 'No images found'
        except Exception as e:
            raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))

        # Check cache
        self.label_files = img2label_paths(self.img_files)  # labels
        cache_path = str(Path(self.label_files[0]).parent) + '.cache'  # cached labels
        if os.path.isfile(cache_path):
            cache = torch.load(cache_path)  # load
            if cache['hash'] != get_hash(self.label_files + self.img_files):  # dataset changed
                cache = self.cache_labels(cache_path)  # re-cache
        else:
            cache = self.cache_labels(cache_path)  # cache

        # Read cache
        cache.pop('hash')  # remove hash
        labels, shapes = zip(*cache.values())
        self.labels = list(labels)
        self.shapes = np.array(shapes, dtype=np.float64)
        self.img_files = list(cache.keys())  # update
        self.label_files = img2label_paths(cache.keys())  # update

        n = len(shapes)  # number of images
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n

        # Rectangular Training
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

        # Check labels
        create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False
        nm, nf, ne, ns, nd = 0, 0, 0, 0, 0  # number missing, found, empty, datasubset, duplicate
        pbar = enumerate(self.label_files)
        if rank in [-1, 0]:
            pbar = tqdm(pbar)
        for i, file in pbar:
            l = self.labels[i]  # label
            if l is not None and l.shape[0]:
                assert l.shape[1] == 5, '> 5 label columns: %s' % file
                assert (l >= 0).all(), 'negative labels: %s' % file
                assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels: %s' % file
                if np.unique(l, axis=0).shape[0] < l.shape[0]:  # duplicate rows
                    nd += 1  # print('WARNING: duplicate rows in %s' % self.label_files[i])
                if single_cls:
                    l[:, 0] = 0  # force dataset into single-class mode
                self.labels[i] = l
                nf += 1  # file found

                # Create subdataset (a smaller dataset)
                if create_datasubset and ns < 1E4:
                    if ns == 0:
                        create_folder(path='./datasubset')
                        os.makedirs('./datasubset/images')
                    exclude_classes = 43
                    if exclude_classes not in l[:, 0]:
                        ns += 1
                        # shutil.copy(src=self.img_files[i], dst='./datasubset/images/')  # copy image
                        with open('./datasubset/images.txt', 'a') as f:
                            f.write(self.img_files[i] + '\n')

                # Extract object detection boxes for a second stage classifier
                if extract_bounding_boxes:
                    p = Path(self.img_files[i])
                    img = cv2.imread(str(p))
                    h, w = img.shape[:2]
                    for j, x in enumerate(l):
                        f = '%s%sclassifier%s%g_%g_%s' % (p.parent.parent, os.sep, os.sep, x[0], j, p.name)
                        if not os.path.exists(Path(f).parent):
                            os.makedirs(Path(f).parent)  # make new output folder

                        b = x[1:] * [w, h, w, h]  # box
                        b[2:] = b[2:].max()  # rectangle to square
                        b[2:] = b[2:] * 1.3 + 30  # pad
                        b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)

                        b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                        b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                        assert cv2.imwrite(f, img[b[1]:b[3], b[0]:b[2]]), 'Failure extracting classifier boxes'
            else:
                ne += 1  # print('empty labels for image %s' % self.img_files[i])  # file empty
                # os.system("rm '%s' '%s'" % (self.img_files[i], self.label_files[i]))  # remove

            if rank in [-1, 0]:
                pbar.desc = 'Scanning labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
                    cache_path, nf, nm, ne, nd, n)
        if nf == 0:
            s = 'WARNING: No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
            print(s)
            assert not augment, '%s. Cannot train without labels.' % s

        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        self.imgs = [None] * n
        if cache_images:
            gb = 0  # Gigabytes of cached images
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))  # 8 threads
            pbar = tqdm(enumerate(results), total=n)
            for i, x in pbar:
                self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)
                gb += self.imgs[i].nbytes
                pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9)

    def cache_labels(self, path='labels.cache'):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
        for (img, label) in pbar:
            try:
                l = []
                im = Image.open(img)
                im.verify()  # PIL verify
                shape = exif_size(im)  # image size
                assert (shape[0] > 9) & (shape[1] > 9), 'image size <10 pixels'
                if os.path.isfile(label):
                    with open(label, 'r') as f:
                        l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)  # labels
                if len(l) == 0:
                    l = np.zeros((0, 5), dtype=np.float32)
                x[img] = [l, shape]
            except Exception as e:
                print('WARNING: Ignoring corrupted image and/or label %s: %s' % (img, e))

        x['hash'] = get_hash(self.label_files + self.img_files)
        torch.save(x, path)  # save for next time
        return x

    def __len__(self):
        return len(self.img_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     # self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        if self.image_weights:
            index = self.indices[index]

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        if mosaic:
            # Load mosaic
            img, labels = load_mosaic(self, index)
            shapes = None

            # MixUp https://arxiv.org/pdf/1710.09412.pdf
            if random.random() < hyp['mixup']:
                img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
                r = np.random.beta(8.0, 8.0)  # mixup ratio, alpha=beta=8.0
                img = (img * r + img2 * (1 - r)).astype(np.uint8)
                labels = np.concatenate((labels, labels2), 0)

        else:
            # Load image
            img, (h0, w0), (h, w) = load_image(self, index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            # Load labels
            labels = []
            x = self.labels[index]
            if x.size > 0:
                # Normalized xywh to pixel xyxy format
                labels = x.copy()
                labels[:, 1] = ratio[0] * w * (x[:, 1] - x[:, 3] / 2) + pad[0]  # pad width
                labels[:, 2] = ratio[1] * h * (x[:, 2] - x[:, 4] / 2) + pad[1]  # pad height
                labels[:, 3] = ratio[0] * w * (x[:, 1] + x[:, 3] / 2) + pad[0]
                labels[:, 4] = ratio[1] * h * (x[:, 2] + x[:, 4] / 2) + pad[1]

        if self.augment:
            # Augment imagespace
            if not mosaic:
                img, labels = random_perspective(img, labels,
                                                 degrees=hyp['degrees'],
                                                 translate=hyp['translate'],
                                                 scale=hyp['scale'],
                                                 shear=hyp['shear'],
                                                 perspective=hyp['perspective'])

            # Augment colorspace
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Apply cutouts
            # if random.random() < 0.9:
            #     labels = cutout(img, labels)

        nL = len(labels)  # number of labels
        if nL:
            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])  # convert xyxy to xywh
            labels[:, [2, 4]] /= img.shape[0]  # normalized height 0-1
            labels[:, [1, 3]] /= img.shape[1]  # normalized width 0-1

        if self.augment:
            # flip up-down
            if random.random() < hyp['flipud']:
                img = np.flipud(img)
                if nL:
                    labels[:, 2] = 1 - labels[:, 2]

            # flip left-right
            if random.random() < hyp['fliplr']:
                img = np.fliplr(img)
                if nL:
                    labels[:, 1] = 1 - labels[:, 1]

        labels_out = torch.zeros((nL, 6))
        if nL:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3xHxW
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.img_files[index], shapes

    @staticmethod
    def collate_fn(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes
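
# Note on the collated batch format: collate_fn stacks images into a single
# (bs, 3, h, w) tensor and concatenates all labels into one (n, 6) tensor of
# [image_index, class, x, y, w, h] rows (xywh normalized 0-1). The image index
# written into column 0 lets build_targets() map each label back to its source
# image. A minimal sketch, assuming a hypothetical two-sample batch:
#
#   batch = [dataset[0], dataset[1]]  # each item: (img, labels_out, path, shapes)
#   imgs, targets, paths, shapes = LoadImagesAndLabels.collate_fn(batch)
#   # imgs: e.g. torch.Size([2, 3, 640, 640]); targets[:, 0] is 0 or 1 per source image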


# Ancillary functions --------------------------------------------------------------------------------------------------
def load_image(self, index):
    # loads 1 image from dataset, returns img, original hw, resized hw
    img = self.imgs[index]
    if img is None:  # not cached
        path = self.img_files[index]
        img = cv2.imread(path)  # BGR
        assert img is not None, 'Image Not Found ' + path
        h0, w0 = img.shape[:2]  # orig hw
        r = self.img_size / max(h0, w0)  # resize image to img_size
        if r != 1:  # always resize down, only resize up if training with augmentation
            interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
        return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
    else:
        return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized


def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
    r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
    hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    dtype = img.dtype  # uint8

    x = np.arange(0, 256, dtype=np.int16)
    lut_hue = ((x * r[0]) % 180).astype(dtype)
    lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
    lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
    cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed

    # Histogram equalization
    # if random.random() < 0.2:
    #     for i in range(3):
    #         img[:, :, i] = cv2.equalizeHist(img[:, :, i])
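
# Example (sketch): augment_hsv modifies `img` in place via per-channel lookup
# tables, so no return value is needed. Assuming a hypothetical BGR uint8 image:
#
#   img = cv2.imread('example.jpg')  # hypothetical path
#   augment_hsv(img, hgain=0.015, sgain=0.7, vgain=0.4)  # gains from the default scratch hyps
#   # img now has randomly shifted hue/saturation/value, same shape and dtype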


def load_mosaic(self, index):
    # loads images in a mosaic
    labels4 = []
    s = self.img_size
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
    indices = [index] + [random.randint(0, len(self.labels) - 1) for _ in range(3)]  # 3 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        x = self.labels[index]
        labels = x.copy()
        if x.size > 0:  # Normalized xywh to pixel xyxy format
            labels[:, 1] = w * (x[:, 1] - x[:, 3] / 2) + padw
            labels[:, 2] = h * (x[:, 2] - x[:, 4] / 2) + padh
            labels[:, 3] = w * (x[:, 1] + x[:, 3] / 2) + padw
            labels[:, 4] = h * (x[:, 2] + x[:, 4] / 2) + padh
        labels4.append(labels)

    # Concat/clip labels
    if len(labels4):
        labels4 = np.concatenate(labels4, 0)
        np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:])  # use with random_perspective
        # img4, labels4 = replicate(img4, labels4)  # replicate

    # Augment
    img4, labels4 = random_perspective(img4, labels4,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img4, labels4


def replicate(img, labels):
    # Replicate labels
    h, w = img.shape[:2]
    boxes = labels[:, 1:].astype(int)
    x1, y1, x2, y2 = boxes.T
    s = ((x2 - x1) + (y2 - y1)) / 2  # side length (pixels)
    for i in s.argsort()[:round(s.size * 0.5)]:  # smallest indices
        x1b, y1b, x2b, y2b = boxes[i]
        bh, bw = y2b - y1b, x2b - x1b
        yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw))  # offset x, y
        x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
        img[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)

    return img, labels


def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
    # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, 32), np.mod(dh, 32)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
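
# Worked example (sketch): letterboxing a hypothetical 720x1280 (HxW) frame to
# new_shape=640 with auto=True computes r = min(640/720, 640/1280) = 0.5, resizes
# to 360x640, then pads the height up to the next 32-multiple (384), giving an
# output of shape (384, 640, 3), ratio (0.5, 0.5) and per-side padding (0.0, 12.0):
#
#   img = np.zeros((720, 1280, 3), dtype=np.uint8)  # stand-in frame
#   out, ratio, (dw, dh) = letterbox(img, new_shape=640)
#   assert out.shape == (384, 640, 3) and ratio == (0.5, 0.5)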


def random_perspective(img, targets=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0, border=(0, 0)):
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]

    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Center
    C = np.eye(3)
    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

    # Perspective
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined transformation matrix
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    # Visualize
    # import matplotlib.pyplot as plt
    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
    # ax[0].imshow(img[:, :, ::-1])  # base
    # ax[1].imshow(img2[:, :, ::-1])  # warped

    # Transform label coordinates
    n = len(targets)
    if n:
        # warp points
        xy = np.ones((n * 4, 3))
        xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
        xy = xy @ M.T  # transform
        if perspective:
            xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)  # rescale
        else:  # affine
            xy = xy[:, :2].reshape(n, 8)

        # create new boxes
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

        # # apply angle-based reduction of bounding boxes
        # radians = a * math.pi / 180
        # reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
        # x = (xy[:, 2] + xy[:, 0]) / 2
        # y = (xy[:, 3] + xy[:, 1]) / 2
        # w = (xy[:, 2] - xy[:, 0]) * reduction
        # h = (xy[:, 3] - xy[:, 1]) * reduction
        # xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T

        # clip boxes
        xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
        xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)

        # filter candidates
        i = box_candidates(box1=targets[:, 1:5].T * s, box2=xy.T)
        targets = targets[i]
        targets[:, 1:5] = xy[i]

    return img, targets


def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1):  # box1(4,n), box2(4,n)
    # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16))  # aspect ratio
    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + 1e-16) > area_thr) & (ar < ar_thr)  # candidates
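
# Worked example (sketch): a box that was 100x50 px before augmentation (box1
# column [0, 0, 100, 50]) and shrank to 40x4 px after it (box2 column [0, 0, 40, 4])
# passes the wh_thr=2 and ar_thr=20 checks (aspect ratio 10) but fails the area
# check, since 160 / 5000 = 0.032 < 0.1, so it is dropped from the augmented targets.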


def cutout(image, labels):
    # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
    h, w = image.shape[:2]

    def bbox_ioa(box1, box2):
        # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
        box2 = box2.transpose()

        # Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]

        # Intersection area
        inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
                     (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)

        # box2 area
        box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16

        # Intersection over box2 area
        return inter_area / box2_area

    # create random masks
    scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
    for s in scales:
        mask_h = random.randint(1, int(h * s))
        mask_w = random.randint(1, int(w * s))

        # box
        xmin = max(0, random.randint(0, w) - mask_w // 2)
        ymin = max(0, random.randint(0, h) - mask_h // 2)
        xmax = min(w, xmin + mask_w)
        ymax = min(h, ymin + mask_h)

        # apply random color mask
        image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]

        # return unobscured labels
        if len(labels) and s > 0.03:
            box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            labels = labels[ioa < 0.60]  # remove >60% obscured labels

    return labels


def create_folder(path='./new'):
    # Create folder
    if os.path.exists(path):
        shutil.rmtree(path)  # delete output folder
    os.makedirs(path)  # make new output folder


def flatten_recursive(path='../coco128'):
    # Flatten a recursive directory by bringing all files to top level
    new_path = Path(path + '_flat')
    create_folder(new_path)
    for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
        shutil.copyfile(file, new_path / Path(file).name)
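
# Example (sketch): flatten_recursive('../coco128') copies every file found under
# ../coco128 (a hypothetical dataset root) into a new sibling folder
# '../coco128_flat', discarding the directory structure. Note that files sharing
# a name in different subdirectories overwrite each other at the top level.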