您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

1020 行
41KB

  1. # Dataset utils and dataloaders
  2. import glob
  3. import logging
  4. import math
  5. import os
  6. import random
  7. import shutil
  8. import time
  9. from itertools import repeat
  10. from multiprocessing.pool import ThreadPool
  11. from pathlib import Path
  12. from threading import Thread
  13. import cv2
  14. import numpy as np
  15. import torch
  16. import torch.nn.functional as F
  17. from PIL import Image, ExifTags
  18. from torch.utils.data import Dataset
  19. from tqdm import tqdm
  20. from utils.general import xyxy2xywh, xywh2xyxy, xywhn2xyxy, clean_str
  21. from utils.torch_utils import torch_distributed_zero_first
# Parameters
help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'  # appended to dataset error/warning messages
img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng']  # acceptable image suffixes (lower-case, no dot)
vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']  # acceptable video suffixes (lower-case, no dot)
logger = logging.getLogger(__name__)

# Get orientation exif tag
# After this loop the module-level name `orientation` holds the numeric EXIF tag id whose
# name is 'Orientation'; exif_size() uses it as the key into the image's EXIF dict.
# (If no such tag existed the loop would end without `break` and `orientation` would be
# left bound to the last key iterated.)
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break
  31. def get_hash(files):
  32. # Returns a single hash value of a list of files
  33. return sum(os.path.getsize(f) for f in files if os.path.isfile(f))
  34. def exif_size(img):
  35. # Returns exif-corrected PIL size
  36. s = img.size # (width, height)
  37. try:
  38. rotation = dict(img._getexif().items())[orientation]
  39. if rotation == 6: # rotation 270
  40. s = (s[1], s[0])
  41. elif rotation == 8: # rotation 90
  42. s = (s[1], s[0])
  43. except:
  44. pass
  45. return s
  46. def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
  47. rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
  48. # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
  49. with torch_distributed_zero_first(rank):
  50. dataset = LoadImagesAndLabels(path, imgsz, batch_size,
  51. augment=augment, # augment images
  52. hyp=hyp, # augmentation hyperparameters
  53. rect=rect, # rectangular training
  54. cache_images=cache,
  55. single_cls=opt.single_cls,
  56. stride=int(stride),
  57. pad=pad,
  58. image_weights=image_weights,
  59. prefix=prefix)
  60. batch_size = min(batch_size, len(dataset))
  61. nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
  62. sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
  63. loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
  64. # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
  65. dataloader = loader(dataset,
  66. batch_size=batch_size,
  67. num_workers=nw,
  68. sampler=sampler,
  69. pin_memory=True,
  70. collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
  71. return dataloader, dataset
  72. class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
  73. """ Dataloader that reuses workers
  74. Uses same syntax as vanilla DataLoader
  75. """
  76. def __init__(self, *args, **kwargs):
  77. super().__init__(*args, **kwargs)
  78. object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
  79. self.iterator = super().__iter__()
  80. def __len__(self):
  81. return len(self.batch_sampler.sampler)
  82. def __iter__(self):
  83. for i in range(len(self)):
  84. yield next(self.iterator)
  85. class _RepeatSampler(object):
  86. """ Sampler that repeats forever
  87. Args:
  88. sampler (Sampler)
  89. """
  90. def __init__(self, sampler):
  91. self.sampler = sampler
  92. def __iter__(self):
  93. while True:
  94. yield from iter(self.sampler)
  95. class LoadImages: # for inference
  96. def __init__(self, path, img_size=640):
  97. p = str(Path(path)) # os-agnostic
  98. p = os.path.abspath(p) # absolute path
  99. if '*' in p:
  100. files = sorted(glob.glob(p, recursive=True)) # glob
  101. elif os.path.isdir(p):
  102. files = sorted(glob.glob(os.path.join(p, '*.*'))) # dir
  103. elif os.path.isfile(p):
  104. files = [p] # files
  105. else:
  106. raise Exception(f'ERROR: {p} does not exist')
  107. images = [x for x in files if x.split('.')[-1].lower() in img_formats]
  108. videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
  109. ni, nv = len(images), len(videos)
  110. self.img_size = img_size
  111. self.files = images + videos
  112. self.nf = ni + nv # number of files
  113. self.video_flag = [False] * ni + [True] * nv
  114. self.mode = 'image'
  115. if any(videos):
  116. self.new_video(videos[0]) # new video
  117. else:
  118. self.cap = None
  119. assert self.nf > 0, f'No images or videos found in {p}. ' \
  120. f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}'
  121. def __iter__(self):
  122. self.count = 0
  123. return self
  124. def __next__(self):
  125. if self.count == self.nf:
  126. raise StopIteration
  127. path = self.files[self.count]
  128. if self.video_flag[self.count]:
  129. # Read video
  130. self.mode = 'video'
  131. ret_val, img0 = self.cap.read()
  132. if not ret_val:
  133. self.count += 1
  134. self.cap.release()
  135. if self.count == self.nf: # last video
  136. raise StopIteration
  137. else:
  138. path = self.files[self.count]
  139. self.new_video(path)
  140. ret_val, img0 = self.cap.read()
  141. self.frame += 1
  142. print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.nframes}) {path}: ', end='')
  143. else:
  144. # Read image
  145. self.count += 1
  146. img0 = cv2.imread(path) # BGR
  147. assert img0 is not None, 'Image Not Found ' + path
  148. print(f'image {self.count}/{self.nf} {path}: ', end='')
  149. # Padded resize
  150. img = letterbox(img0, new_shape=self.img_size)[0]
  151. # Convert
  152. img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  153. img = np.ascontiguousarray(img)
  154. return path, img, img0, self.cap
  155. def new_video(self, path):
  156. self.frame = 0
  157. self.cap = cv2.VideoCapture(path)
  158. self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
  159. def __len__(self):
  160. return self.nf # number of files
  161. class LoadWebcam: # for inference
  162. def __init__(self, pipe='0', img_size=640):
  163. self.img_size = img_size
  164. if pipe.isnumeric():
  165. pipe = eval(pipe) # local camera
  166. # pipe = 'rtsp://192.168.1.64/1' # IP camera
  167. # pipe = 'rtsp://username:password@192.168.1.64/1' # IP camera with login
  168. # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg' # IP golf camera
  169. self.pipe = pipe
  170. self.cap = cv2.VideoCapture(pipe) # video capture object
  171. self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3) # set buffer size
  172. def __iter__(self):
  173. self.count = -1
  174. return self
  175. def __next__(self):
  176. self.count += 1
  177. if cv2.waitKey(1) == ord('q'): # q to quit
  178. self.cap.release()
  179. cv2.destroyAllWindows()
  180. raise StopIteration
  181. # Read frame
  182. if self.pipe == 0: # local camera
  183. ret_val, img0 = self.cap.read()
  184. img0 = cv2.flip(img0, 1) # flip left-right
  185. else: # IP camera
  186. n = 0
  187. while True:
  188. n += 1
  189. self.cap.grab()
  190. if n % 30 == 0: # skip frames
  191. ret_val, img0 = self.cap.retrieve()
  192. if ret_val:
  193. break
  194. # Print
  195. assert ret_val, f'Camera Error {self.pipe}'
  196. img_path = 'webcam.jpg'
  197. print(f'webcam {self.count}: ', end='')
  198. # Padded resize
  199. img = letterbox(img0, new_shape=self.img_size)[0]
  200. # Convert
  201. img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  202. img = np.ascontiguousarray(img)
  203. return img_path, img, img0, None
  204. def __len__(self):
  205. return 0
  206. class LoadStreams: # multiple IP or RTSP cameras
  207. def __init__(self, sources='streams.txt', img_size=640):
  208. self.mode = 'stream'
  209. self.img_size = img_size
  210. if os.path.isfile(sources):
  211. with open(sources, 'r') as f:
  212. sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
  213. else:
  214. sources = [sources]
  215. n = len(sources)
  216. self.imgs = [None] * n
  217. self.sources = [clean_str(x) for x in sources] # clean source names for later
  218. for i, s in enumerate(sources):
  219. # Start the thread to read frames from the video stream
  220. print(f'{i + 1}/{n}: {s}... ', end='')
  221. cap = cv2.VideoCapture(eval(s) if s.isnumeric() else s)
  222. assert cap.isOpened(), f'Failed to open {s}'
  223. w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  224. h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  225. fps = cap.get(cv2.CAP_PROP_FPS) % 100
  226. _, self.imgs[i] = cap.read() # guarantee first frame
  227. thread = Thread(target=self.update, args=([i, cap]), daemon=True)
  228. print(f' success ({w}x{h} at {fps:.2f} FPS).')
  229. thread.start()
  230. print('') # newline
  231. # check for common shapes
  232. s = np.stack([letterbox(x, new_shape=self.img_size)[0].shape for x in self.imgs], 0) # inference shapes
  233. self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
  234. if not self.rect:
  235. print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')
  236. def update(self, index, cap):
  237. # Read next stream frame in a daemon thread
  238. n = 0
  239. while cap.isOpened():
  240. n += 1
  241. # _, self.imgs[index] = cap.read()
  242. cap.grab()
  243. if n == 4: # read every 4th frame
  244. _, self.imgs[index] = cap.retrieve()
  245. n = 0
  246. time.sleep(0.01) # wait time
  247. def __iter__(self):
  248. self.count = -1
  249. return self
  250. def __next__(self):
  251. self.count += 1
  252. img0 = self.imgs.copy()
  253. if cv2.waitKey(1) == ord('q'): # q to quit
  254. cv2.destroyAllWindows()
  255. raise StopIteration
  256. # Letterbox
  257. img = [letterbox(x, new_shape=self.img_size, auto=self.rect)[0] for x in img0]
  258. # Stack
  259. img = np.stack(img, 0)
  260. # Convert
  261. img = img[:, :, :, ::-1].transpose(0, 3, 1, 2) # BGR to RGB, to bsx3x416x416
  262. img = np.ascontiguousarray(img)
  263. return self.sources, img, img0, None
  264. def __len__(self):
  265. return 0 # 1E12 frames = 32 streams at 30 FPS for 30 years
  266. def img2label_paths(img_paths):
  267. # Define label paths as a function of image paths
  268. sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep # /images/, /labels/ substrings
  269. return [x.replace(sa, sb, 1).replace('.' + x.split('.')[-1], '.txt') for x in img_paths]
  270. class LoadImagesAndLabels(Dataset): # for training/testing
  271. def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
  272. cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
  273. self.img_size = img_size
  274. self.augment = augment
  275. self.hyp = hyp
  276. self.image_weights = image_weights
  277. self.rect = False if image_weights else rect
  278. self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
  279. self.mosaic_border = [-img_size // 2, -img_size // 2]
  280. self.stride = stride
  281. try:
  282. f = [] # image files
  283. for p in path if isinstance(path, list) else [path]:
  284. p = Path(p) # os-agnostic
  285. if p.is_dir(): # dir
  286. f += glob.glob(str(p / '**' / '*.*'), recursive=True)
  287. elif p.is_file(): # file
  288. with open(p, 'r') as t:
  289. t = t.read().strip().splitlines()
  290. parent = str(p.parent) + os.sep
  291. f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
  292. else:
  293. raise Exception(f'{prefix}{p} does not exist')
  294. self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
  295. assert self.img_files, f'{prefix}No images found'
  296. except Exception as e:
  297. raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')
  298. # Check cache
  299. self.label_files = img2label_paths(self.img_files) # labels
  300. cache_path = Path(self.label_files[0]).parent.with_suffix('.cache') # cached labels
  301. if cache_path.is_file():
  302. cache = torch.load(cache_path) # load
  303. if cache['hash'] != get_hash(self.label_files + self.img_files) or 'results' not in cache: # changed
  304. cache = self.cache_labels(cache_path, prefix) # re-cache
  305. else:
  306. cache = self.cache_labels(cache_path, prefix) # cache
  307. # Display cache
  308. [nf, nm, ne, nc, n] = cache.pop('results') # found, missing, empty, corrupted, total
  309. desc = f"Scanning '{cache_path}' for images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
  310. tqdm(None, desc=prefix + desc, total=n, initial=n)
  311. assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'
  312. # Read cache
  313. cache.pop('hash') # remove hash
  314. labels, shapes = zip(*cache.values())
  315. self.labels = list(labels)
  316. self.shapes = np.array(shapes, dtype=np.float64)
  317. self.img_files = list(cache.keys()) # update
  318. self.label_files = img2label_paths(cache.keys()) # update
  319. if single_cls:
  320. for x in self.labels:
  321. x[:, 0] = 0
  322. n = len(shapes) # number of images
  323. bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
  324. nb = bi[-1] + 1 # number of batches
  325. self.batch = bi # batch index of image
  326. self.n = n
  327. self.indices = range(n)
  328. # Rectangular Training
  329. if self.rect:
  330. # Sort by aspect ratio
  331. s = self.shapes # wh
  332. ar = s[:, 1] / s[:, 0] # aspect ratio
  333. irect = ar.argsort()
  334. self.img_files = [self.img_files[i] for i in irect]
  335. self.label_files = [self.label_files[i] for i in irect]
  336. self.labels = [self.labels[i] for i in irect]
  337. self.shapes = s[irect] # wh
  338. ar = ar[irect]
  339. # Set training image shapes
  340. shapes = [[1, 1]] * nb
  341. for i in range(nb):
  342. ari = ar[bi == i]
  343. mini, maxi = ari.min(), ari.max()
  344. if maxi < 1:
  345. shapes[i] = [maxi, 1]
  346. elif mini > 1:
  347. shapes[i] = [1, 1 / mini]
  348. self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
  349. # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
  350. self.imgs = [None] * n
  351. if cache_images:
  352. gb = 0 # Gigabytes of cached images
  353. self.img_hw0, self.img_hw = [None] * n, [None] * n
  354. results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n))) # 8 threads
  355. pbar = tqdm(enumerate(results), total=n)
  356. for i, x in pbar:
  357. self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # img, hw_original, hw_resized = load_image(self, i)
  358. gb += self.imgs[i].nbytes
  359. pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
  360. def cache_labels(self, path=Path('./labels.cache'), prefix=''):
  361. # Cache dataset labels, check images and read shapes
  362. x = {} # dict
  363. nm, nf, ne, nc = 0, 0, 0, 0 # number missing, found, empty, duplicate
  364. pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
  365. for i, (im_file, lb_file) in enumerate(pbar):
  366. try:
  367. # verify images
  368. im = Image.open(im_file)
  369. im.verify() # PIL verify
  370. shape = exif_size(im) # image size
  371. assert (shape[0] > 9) & (shape[1] > 9), 'image size <10 pixels'
  372. # verify labels
  373. if os.path.isfile(lb_file):
  374. nf += 1 # label found
  375. with open(lb_file, 'r') as f:
  376. l = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32) # labels
  377. if len(l):
  378. assert l.shape[1] == 5, 'labels require 5 columns each'
  379. assert (l >= 0).all(), 'negative labels'
  380. assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
  381. assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
  382. else:
  383. ne += 1 # label empty
  384. l = np.zeros((0, 5), dtype=np.float32)
  385. else:
  386. nm += 1 # label missing
  387. l = np.zeros((0, 5), dtype=np.float32)
  388. x[im_file] = [l, shape]
  389. except Exception as e:
  390. nc += 1
  391. print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
  392. pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' for images and labels... " \
  393. f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
  394. if nf == 0:
  395. print(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
  396. x['hash'] = get_hash(self.label_files + self.img_files)
  397. x['results'] = [nf, nm, ne, nc, i + 1]
  398. torch.save(x, path) # save for next time
  399. logging.info(f'{prefix}New cache created: {path}')
  400. return x
  401. def __len__(self):
  402. return len(self.img_files)
  403. # def __iter__(self):
  404. # self.count = -1
  405. # print('ran dataset iter')
  406. # #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
  407. # return self
  408. def __getitem__(self, index):
  409. index = self.indices[index] # linear, shuffled, or image_weights
  410. hyp = self.hyp
  411. mosaic = self.mosaic and random.random() < hyp['mosaic']
  412. if mosaic:
  413. # Load mosaic
  414. img, labels = load_mosaic(self, index)
  415. shapes = None
  416. # MixUp https://arxiv.org/pdf/1710.09412.pdf
  417. if random.random() < hyp['mixup']:
  418. img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1))
  419. r = np.random.beta(8.0, 8.0) # mixup ratio, alpha=beta=8.0
  420. img = (img * r + img2 * (1 - r)).astype(np.uint8)
  421. labels = np.concatenate((labels, labels2), 0)
  422. else:
  423. # Load image
  424. img, (h0, w0), (h, w) = load_image(self, index)
  425. # Letterbox
  426. shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
  427. img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
  428. shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
  429. labels = self.labels[index].copy()
  430. if labels.size: # normalized xywh to pixel xyxy format
  431. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
  432. if self.augment:
  433. # Augment imagespace
  434. if not mosaic:
  435. img, labels = random_perspective(img, labels,
  436. degrees=hyp['degrees'],
  437. translate=hyp['translate'],
  438. scale=hyp['scale'],
  439. shear=hyp['shear'],
  440. perspective=hyp['perspective'])
  441. # Augment colorspace
  442. augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
  443. # Apply cutouts
  444. # if random.random() < 0.9:
  445. # labels = cutout(img, labels)
  446. nL = len(labels) # number of labels
  447. if nL:
  448. labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) # convert xyxy to xywh
  449. labels[:, [2, 4]] /= img.shape[0] # normalized height 0-1
  450. labels[:, [1, 3]] /= img.shape[1] # normalized width 0-1
  451. if self.augment:
  452. # flip up-down
  453. if random.random() < hyp['flipud']:
  454. img = np.flipud(img)
  455. if nL:
  456. labels[:, 2] = 1 - labels[:, 2]
  457. # flip left-right
  458. if random.random() < hyp['fliplr']:
  459. img = np.fliplr(img)
  460. if nL:
  461. labels[:, 1] = 1 - labels[:, 1]
  462. labels_out = torch.zeros((nL, 6))
  463. if nL:
  464. labels_out[:, 1:] = torch.from_numpy(labels)
  465. # Convert
  466. img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  467. img = np.ascontiguousarray(img)
  468. return torch.from_numpy(img), labels_out, self.img_files[index], shapes
  469. @staticmethod
  470. def collate_fn(batch):
  471. img, label, path, shapes = zip(*batch) # transposed
  472. for i, l in enumerate(label):
  473. l[:, 0] = i # add target image index for build_targets()
  474. return torch.stack(img, 0), torch.cat(label, 0), path, shapes
  475. @staticmethod
  476. def collate_fn4(batch):
  477. img, label, path, shapes = zip(*batch) # transposed
  478. n = len(shapes) // 4
  479. img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
  480. ho = torch.tensor([[0., 0, 0, 1, 0, 0]])
  481. wo = torch.tensor([[0., 0, 1, 0, 0, 0]])
  482. s = torch.tensor([[1, 1, .5, .5, .5, .5]]) # scale
  483. for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
  484. i *= 4
  485. if random.random() < 0.5:
  486. im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[
  487. 0].type(img[i].type())
  488. l = label[i]
  489. else:
  490. im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
  491. l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
  492. img4.append(im)
  493. label4.append(l)
  494. for i, l in enumerate(label4):
  495. l[:, 0] = i # add target image index for build_targets()
  496. return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4
  497. # Ancillary functions --------------------------------------------------------------------------------------------------
  498. def load_image(self, index):
  499. # loads 1 image from dataset, returns img, original hw, resized hw
  500. img = self.imgs[index]
  501. if img is None: # not cached
  502. path = self.img_files[index]
  503. img = cv2.imread(path) # BGR
  504. assert img is not None, 'Image Not Found ' + path
  505. h0, w0 = img.shape[:2] # orig hw
  506. r = self.img_size / max(h0, w0) # resize image to img_size
  507. if r != 1: # always resize down, only resize up if training with augmentation
  508. interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
  509. img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
  510. return img, (h0, w0), img.shape[:2] # img, hw_original, hw_resized
  511. else:
  512. return self.imgs[index], self.img_hw0[index], self.img_hw[index] # img, hw_original, hw_resized
  513. def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
  514. r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains
  515. hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
  516. dtype = img.dtype # uint8
  517. x = np.arange(0, 256, dtype=np.int16)
  518. lut_hue = ((x * r[0]) % 180).astype(dtype)
  519. lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
  520. lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
  521. img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
  522. cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed
  523. # Histogram equalization
  524. # if random.random() < 0.2:
  525. # for i in range(3):
  526. # img[:, :, i] = cv2.equalizeHist(img[:, :, i])
def load_mosaic(self, index):
    """Build a 2x2 mosaic training sample around image ``index``.

    Image ``index`` plus 3 images picked at random from ``self.indices`` are pasted as the
    top-left, top-right, bottom-left and bottom-right tiles of a ``2s x 2s`` canvas filled
    with 114 (``s = self.img_size``), meeting at a random centre point. Each tile's labels
    are converted from normalized xywh to pixel xyxy in canvas coordinates, clipped to
    ``[0, 2s]``, and the whole sample is passed through ``random_perspective`` with
    ``border=self.mosaic_border``.

    Returns:
        (img4, labels4): the augmented mosaic and its ``[cls, x1, y1, x2, y2]`` labels.
    """
    labels4 = []
    s = self.img_size
    # mosaic centre, unpacked as (y, x); each drawn from [-b, 2s + b] for border b
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
    indices = [index] + [self.indices[random.randint(0, self.n - 1)] for _ in range(3)]  # 3 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4
        # (x1a..y2a) is the destination region on the canvas, (x1b..y2b) the matching crop of img
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        # offset from tile-pixel coordinates to canvas coordinates
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        labels = self.labels[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
        labels4.append(labels)

    # Concat/clip labels
    if len(labels4):
        labels4 = np.concatenate(labels4, 0)
        np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:])  # use with random_perspective
        # img4, labels4 = replicate(img4, labels4)  # replicate

    # Augment
    img4, labels4 = random_perspective(img4, labels4,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img4, labels4
def load_mosaic9(self, index):
    """Build a 3x3 mosaic training sample around image ``index``.

    Image ``index`` is placed at the centre of a ``3s x 3s`` canvas filled with 114
    (``s = self.img_size``). Eight images picked at random from ``self.indices`` are then
    laid out around it: top, top right, right, bottom right, bottom, bottom left, left,
    top left. A random ``2s x 2s`` window is cropped, the labels are shifted into that
    window and clipped to ``[0, 2s]``, and the result goes through ``random_perspective``
    with ``border=self.mosaic_border``.

    Returns:
        (img9, labels9): the augmented mosaic and its ``[cls, x1, y1, x2, y2]`` labels.
    """
    labels9 = []
    s = self.img_size
    indices = [index] + [self.indices[random.randint(0, self.n - 1)] for _ in range(8)]  # 8 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img9
        # c = (xmin, ymin, xmax, ymax) of this tile on the canvas; (h0, w0) is the centre tile's
        # size and (hp, wp) the previous tile's size, both set in earlier iterations
        if i == 0:  # center
            img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 9 tiles (3s x 3s canvas)
            h0, w0 = h, w
            c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
        elif i == 1:  # top
            c = s, s - h, s + w, s
        elif i == 2:  # top right
            c = s + wp, s - h, s + wp + w, s
        elif i == 3:  # right
            c = s + w0, s, s + w0 + w, s + h
        elif i == 4:  # bottom right
            c = s + w0, s + hp, s + w0 + w, s + hp + h
        elif i == 5:  # bottom
            c = s + w0 - w, s + h0, s + w0, s + h0 + h
        elif i == 6:  # bottom left
            c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
        elif i == 7:  # left
            c = s - w, s + h0 - h, s, s + h0
        elif i == 8:  # top left
            c = s - w, s + h0 - hp - h, s, s + h0 - hp

        padx, pady = c[:2]  # unclipped tile origin, used to offset labels and crop the tile
        x1, y1, x2, y2 = [max(x, 0) for x in c]  # allocate coords

        # Labels
        labels = self.labels[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
        labels9.append(labels)

        # Image
        img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
        hp, wp = h, w  # height, width previous

    # Offset
    # random top-left corner (y, x) of the 2s x 2s crop, each in [0, s)
    yc, xc = [int(random.uniform(0, s)) for x in self.mosaic_border]  # mosaic center x, y
    img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

    # Concat/clip labels
    if len(labels9):
        labels9 = np.concatenate(labels9, 0)
        labels9[:, [1, 3]] -= xc
        labels9[:, [2, 4]] -= yc

        np.clip(labels9[:, 1:], 0, 2 * s, out=labels9[:, 1:])  # use with random_perspective
        # img9, labels9 = replicate(img9, labels9)  # replicate

    # Augment
    img9, labels9 = random_perspective(img9, labels9,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img9, labels9
  630. def replicate(img, labels):
  631. # Replicate labels
  632. h, w = img.shape[:2]
  633. boxes = labels[:, 1:].astype(int)
  634. x1, y1, x2, y2 = boxes.T
  635. s = ((x2 - x1) + (y2 - y1)) / 2 # side length (pixels)
  636. for i in s.argsort()[:round(s.size * 0.5)]: # smallest indices
  637. x1b, y1b, x2b, y2b = boxes[i]
  638. bh, bw = y2b - y1b, x2b - x1b
  639. yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw)) # offset x, y
  640. x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
  641. img[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
  642. labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)
  643. return img, labels
  644. def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
  645. # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
  646. shape = img.shape[:2] # current shape [height, width]
  647. if isinstance(new_shape, int):
  648. new_shape = (new_shape, new_shape)
  649. # Scale ratio (new / old)
  650. r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
  651. if not scaleup: # only scale down, do not scale up (for better test mAP)
  652. r = min(r, 1.0)
  653. # Compute padding
  654. ratio = r, r # width, height ratios
  655. new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
  656. dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
  657. if auto: # minimum rectangle
  658. dw, dh = np.mod(dw, 32), np.mod(dh, 32) # wh padding
  659. elif scaleFill: # stretch
  660. dw, dh = 0.0, 0.0
  661. new_unpad = (new_shape[1], new_shape[0])
  662. ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
  663. dw /= 2 # divide padding into 2 sides
  664. dh /= 2
  665. if shape[::-1] != new_unpad: # resize
  666. img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
  667. top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
  668. left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
  669. img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
  670. return img, ratio, (dw, dh)
def random_perspective(img, targets=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0, border=(0, 0)):
    """Apply one random perspective/affine warp to an image and its boxes.

    The warp is ``M = T @ S @ R @ P @ C`` (applied right to left):
      C - move the image centre to the origin,
      P - random perspective terms in ``[-perspective, perspective]``,
      R - rotation by ``[-degrees, degrees]`` and scale in ``[1 - scale, 1 + scale]``,
      S - x/y shear by ``[-shear, shear]`` degrees,
      T - translation to ``(0.5 +/- translate)`` times the output width/height.
    ``cv2.warpPerspective`` is used when ``perspective`` is non-zero, else ``cv2.warpAffine``;
    uncovered pixels are filled with 114.

    Args:
        img: image array; output size is ``img.shape[:2]`` plus ``2 * border`` per axis.
        targets: array of rows ``[cls, x1, y1, x2, y2]`` in pixels, or an empty sequence.
        border: ``(y, x)`` margins added twice to the output height/width; negative values
            (such as ``self.mosaic_border`` from load_mosaic) make the output smaller than ``img``.

    Returns:
        (img, targets): the warped image and the transformed targets, clipped to the output
        size. Rows whose new boxes are rejected by ``box_candidates`` are dropped.
    """
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]

    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Center
    C = np.eye(3)
    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

    # Perspective
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined rotation matrix
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    # Visualize
    # import matplotlib.pyplot as plt
    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
    # ax[0].imshow(img[:, :, ::-1])  # base
    # ax[1].imshow(img2[:, :, ::-1])  # warped

    # Transform label coordinates
    n = len(targets)
    if n:
        # warp points: all 4 corners of every box, as homogeneous coordinates
        xy = np.ones((n * 4, 3))
        xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
        xy = xy @ M.T  # transform
        if perspective:
            xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)  # rescale
        else:  # affine
            xy = xy[:, :2].reshape(n, 8)

        # create new boxes: axis-aligned bounds of the 4 warped corners
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

        # # apply angle-based reduction of bounding boxes
        # radians = a * math.pi / 180
        # reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
        # x = (xy[:, 2] + xy[:, 0]) / 2
        # y = (xy[:, 3] + xy[:, 1]) / 2
        # w = (xy[:, 2] - xy[:, 0]) * reduction
        # h = (xy[:, 3] - xy[:, 1]) * reduction
        # xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T

        # clip boxes
        xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
        xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)

        # filter candidates: compare the original boxes (scaled by s) with the warped ones
        i = box_candidates(box1=targets[:, 1:5].T * s, box2=xy.T)
        targets = targets[i]
        targets[:, 1:5] = xy[i]

    return img, targets
  742. def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
  743. # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
  744. w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
  745. w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
  746. ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
  747. return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
  748. def cutout(image, labels):
  749. # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
  750. h, w = image.shape[:2]
  751. def bbox_ioa(box1, box2):
  752. # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
  753. box2 = box2.transpose()
  754. # Get the coordinates of bounding boxes
  755. b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
  756. b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
  757. # Intersection area
  758. inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
  759. (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)
  760. # box2 area
  761. box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16
  762. # Intersection over box2 area
  763. return inter_area / box2_area
  764. # create random masks
  765. scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction
  766. for s in scales:
  767. mask_h = random.randint(1, int(h * s))
  768. mask_w = random.randint(1, int(w * s))
  769. # box
  770. xmin = max(0, random.randint(0, w) - mask_w // 2)
  771. ymin = max(0, random.randint(0, h) - mask_h // 2)
  772. xmax = min(w, xmin + mask_w)
  773. ymax = min(h, ymin + mask_h)
  774. # apply random color mask
  775. image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]
  776. # return unobscured labels
  777. if len(labels) and s > 0.03:
  778. box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
  779. ioa = bbox_ioa(box, labels[:, 1:5]) # intersection over area
  780. labels = labels[ioa < 0.60] # remove >60% obscured labels
  781. return labels
  782. def create_folder(path='./new'):
  783. # Create folder
  784. if os.path.exists(path):
  785. shutil.rmtree(path) # delete output folder
  786. os.makedirs(path) # make new output folder
  787. def flatten_recursive(path='../coco128'):
  788. # Flatten a recursive directory by bringing all files to top level
  789. new_path = Path(path + '_flat')
  790. create_folder(new_path)
  791. for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
  792. shutil.copyfile(file, new_path / Path(file).name)
  793. def extract_boxes(path='../coco128/'): # from utils.datasets import *; extract_boxes('../coco128')
  794. # Convert detection dataset into classification dataset, with one directory per class
  795. path = Path(path) # images dir
  796. shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None # remove existing
  797. files = list(path.rglob('*.*'))
  798. n = len(files) # number of files
  799. for im_file in tqdm(files, total=n):
  800. if im_file.suffix[1:] in img_formats:
  801. # image
  802. im = cv2.imread(str(im_file))[..., ::-1] # BGR to RGB
  803. h, w = im.shape[:2]
  804. # labels
  805. lb_file = Path(img2label_paths([str(im_file)])[0])
  806. if Path(lb_file).exists():
  807. with open(lb_file, 'r') as f:
  808. lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32) # labels
  809. for j, x in enumerate(lb):
  810. c = int(x[0]) # class
  811. f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg' # new filename
  812. if not f.parent.is_dir():
  813. f.parent.mkdir(parents=True)
  814. b = x[1:] * [w, h, w, h] # box
  815. # b[2:] = b[2:].max() # rectangle to square
  816. b[2:] = b[2:] * 1.2 + 3 # pad
  817. b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(np.int)
  818. b[[0, 2]] = np.clip(b[[0, 2]], 0, w) # clip boxes outside of image
  819. b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
  820. assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'
  821. def autosplit(path='../coco128', weights=(0.9, 0.1, 0.0)): # from utils.datasets import *; autosplit('../coco128')
  822. """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
  823. # Arguments
  824. path: Path to images directory
  825. weights: Train, val, test weights (list)
  826. """
  827. path = Path(path) # images dir
  828. files = list(path.rglob('*.*'))
  829. n = len(files) # number of files
  830. indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
  831. txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
  832. [(path / x).unlink() for x in txt if (path / x).exists()] # remove existing
  833. for i, img in tqdm(zip(indices, files), total=n):
  834. if img.suffix[1:] in img_formats:
  835. with open(path / txt[i], 'a') as f:
  836. f.write(str(img) + '\n') # add image to txt file