import glob
import math
import os
import random
import shutil
import time
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from threading import Thread

import cv2
import numpy as np
import torch
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm

from utils.general import xyxy2xywh, xywh2xyxy, torch_distributed_zero_first

help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.dng']
vid_formats = ['.mov', '.avi', '.mp4', '.mpg', '.mpeg', '.m4v', '.wmv', '.mkv']

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break


def get_hash(files):
    # Returns a single hash value of a list of files
    return sum(os.path.getsize(f) for f in files if os.path.isfile(f))


def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    try:
        rotation = dict(img._getexif().items())[orientation]
        if rotation == 6:  # rotation 270
            s = (s[1], s[0])
        elif rotation == 8:  # rotation 90
            s = (s[1], s[0])
    except Exception:  # image has no usable EXIF data
        pass

    return s


def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
                      rank=-1, world_size=1, workers=8):
    # Make sure only the first process in DDP scans the dataset first, so the following processes can use the cache
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,  # augmentation hyperparameters
                                      rect=rect,  # rectangular training
                                      cache_images=cache,
                                      single_cls=opt.single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      rank=rank)

    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
    dataloader = InfiniteDataLoader(dataset,
                                    batch_size=batch_size,
                                    num_workers=nw,
                                    sampler=sampler,
                                    pin_memory=True,
                                    collate_fn=LoadImagesAndLabels.collate_fn)  # torch.utils.data.DataLoader()
    return dataloader, dataset
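
# Example usage (a minimal sketch, not part of the original file; `opt` only needs a
# `single_cls` attribute here, and 'data/train' is a hypothetical dataset path):
#
#   from argparse import Namespace
#   hyp = {'mosaic': 1.0, 'mixup': 0.0, 'degrees': 0.0, 'translate': 0.1, 'scale': 0.5,
#          'shear': 0.0, 'perspective': 0.0, 'hsv_h': 0.015, 'hsv_s': 0.7, 'hsv_v': 0.4,
#          'flipud': 0.0, 'fliplr': 0.5}  # keys read by LoadImagesAndLabels.__getitem__
#   dataloader, dataset = create_dataloader('data/train', imgsz=640, batch_size=16, stride=32,
#                                           opt=Namespace(single_cls=False), hyp=hyp, augment=True)
#   for imgs, targets, paths, shapes in dataloader:
#       pass  # imgs: uint8 tensor (bs, 3, h, w); targets: (n, 6) rows of [batch_idx, cls, x, y, w, h]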


class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
    """ Dataloader that reuses workers.

    Uses same syntax as vanilla DataLoader.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """ Sampler that repeats forever.

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)


class LoadImages:  # for inference
    def __init__(self, path, img_size=640):
        p = str(Path(path))  # os-agnostic
        p = os.path.abspath(p)  # absolute path
        if '*' in p:
            files = sorted(glob.glob(p, recursive=True))  # glob
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
        elif os.path.isfile(p):
            files = [p]  # files
        else:
            raise Exception('ERROR: %s does not exist' % p)

        images = [x for x in files if os.path.splitext(x)[-1].lower() in img_formats]
        videos = [x for x in files if os.path.splitext(x)[-1].lower() in vid_formats]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'images'
        if any(videos):
            self.new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, 'No images or videos found in %s. Supported formats are:\nimages: %s\nvideos: %s' % \
                            (p, img_formats, vid_formats)

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            if not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()

            self.frame += 1
            print('video %g/%g (%g/%g) %s: ' % (self.count + 1, self.nf, self.frame, self.nframes, path), end='')

        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path
            print('image %g/%g %s: ' % (self.count, self.nf, path), end='')

        # Padded resize
        img = letterbox(img0, new_shape=self.img_size)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        # cv2.imwrite(path + '.letterbox.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1])  # save letterbox image
        return path, img, img0, self.cap

    def new_video(self, path):
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nf  # number of files
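
# Example inference loop (a sketch, not part of the original file; 'data/images'
# is a hypothetical directory):
#
#   for path, img, img0, cap in LoadImages('data/images', img_size=640):
#       img = torch.from_numpy(img).float() / 255.0  # CHW RGB uint8 -> float 0-1
#       ...  # run the model on img[None]; draw results on img0, the original BGR frame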


class LoadWebcam:  # for inference
    def __init__(self, pipe=0, img_size=640):
        self.img_size = img_size

        if pipe == '0':
            pipe = 0  # local camera
        # pipe = 'rtsp://192.168.1.64/1'  # IP camera
        # pipe = 'rtsp://username:password@192.168.1.64/1'  # IP camera with login
        # pipe = 'rtsp://170.93.143.139/rtplive/470011e600ef003a004ee33696235daa'  # IP traffic camera
        # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg'  # IP golf camera

        # https://answers.opencv.org/question/215996/changing-gstreamer-pipeline-to-opencv-in-pythonsolved/
        # pipe = '"rtspsrc location="rtsp://username:password@192.168.1.64/1" latency=10 ! appsink'  # GStreamer

        # https://answers.opencv.org/question/200787/video-acceleration-gstremer-pipeline-in-videocapture/
        # https://stackoverflow.com/questions/54095699/install-gstreamer-support-for-opencv-python-package  # install help
        # pipe = "rtspsrc location=rtsp://root:root@192.168.0.91:554/axis-media/media.amp?videocodec=h264&resolution=3840x2160 protocols=GST_RTSP_LOWER_TRANS_TCP ! rtph264depay ! queue ! vaapih264dec ! videoconvert ! appsink"  # GStreamer

        self.pipe = pipe
        self.cap = cv2.VideoCapture(pipe)  # video capture object
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)  # set buffer size

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if cv2.waitKey(1) == ord('q'):  # q to quit
            self.cap.release()
            cv2.destroyAllWindows()
            raise StopIteration

        # Read frame
        if self.pipe == 0:  # local camera
            ret_val, img0 = self.cap.read()
            img0 = cv2.flip(img0, 1)  # flip left-right
        else:  # IP camera
            n = 0
            while True:
                n += 1
                self.cap.grab()
                if n % 30 == 0:  # skip frames
                    ret_val, img0 = self.cap.retrieve()
                    if ret_val:
                        break

        # Print
        assert ret_val, 'Camera Error %s' % self.pipe
        img_path = 'webcam.jpg'
        print('webcam %g: ' % self.count, end='')

        # Padded resize
        img = letterbox(img0, new_shape=self.img_size)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        return img_path, img, img0, None

    def __len__(self):
        return 0


class LoadStreams:  # multiple IP or RTSP cameras
    def __init__(self, sources='streams.txt', img_size=640):
        self.mode = 'images'
        self.img_size = img_size

        if os.path.isfile(sources):
            with open(sources, 'r') as f:
                sources = [x.strip() for x in f.read().splitlines() if len(x.strip())]
        else:
            sources = [sources]

        n = len(sources)
        self.imgs = [None] * n
        self.sources = sources
        for i, s in enumerate(sources):
            # Start the thread to read frames from the video stream
            print('%g/%g: %s... ' % (i + 1, n, s), end='')
            cap = cv2.VideoCapture(eval(s) if s.isnumeric() else s)
            assert cap.isOpened(), 'Failed to open %s' % s
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS) % 100
            _, self.imgs[i] = cap.read()  # guarantee first frame
            thread = Thread(target=self.update, args=([i, cap]), daemon=True)
            print(' success (%gx%g at %.2f FPS).' % (w, h, fps))
            thread.start()
        print('')  # newline

        # check for common shapes
        s = np.stack([letterbox(x, new_shape=self.img_size)[0].shape for x in self.imgs], 0)  # inference shapes
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        if not self.rect:
            print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')

    def update(self, index, cap):
        # Read next stream frame in a daemon thread
        n = 0
        while cap.isOpened():
            n += 1
            # _, self.imgs[index] = cap.read()
            cap.grab()
            if n == 4:  # read every 4th frame
                _, self.imgs[index] = cap.retrieve()
                n = 0
            time.sleep(0.01)  # wait time

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        img0 = self.imgs.copy()
        if cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        # Letterbox
        img = [letterbox(x, new_shape=self.img_size, auto=self.rect)[0] for x in img0]

        # Stack
        img = np.stack(img, 0)

        # Convert
        img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, to bsx3x416x416
        img = np.ascontiguousarray(img)

        return self.sources, img, img0, None

    def __len__(self):
        return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years
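
# Example usage (a sketch, not part of the original file): each iteration yields the
# latest frame from every source listed in streams.txt (one URL or device id per line):
#
#   for sources, img, img0, _ in LoadStreams('streams.txt', img_size=640):
#       pass  # img: (n_streams, 3, h, w) RGB batch; img0: list of original BGR frames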


class LoadImagesAndLabels(Dataset):  # for training/testing
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1):
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride

        def img2label_paths(img_paths):
            # Define label paths as a function of image paths
            sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
            return [x.replace(sa, sb, 1).replace(os.path.splitext(x)[-1], '.txt') for x in img_paths]

        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = str(Path(p))  # os-agnostic
                parent = str(Path(p).parent) + os.sep
                if os.path.isfile(p):  # file
                    with open(p, 'r') as t:
                        t = t.read().splitlines()
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                elif os.path.isdir(p):  # folder
                    f += glob.iglob(p + os.sep + '*.*')
                else:
                    raise Exception('%s does not exist' % p)
            self.img_files = sorted(
                [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats])
            assert len(self.img_files) > 0, 'No images found'
        except Exception as e:
            raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))

        # Check cache
        self.label_files = img2label_paths(self.img_files)  # labels
        cache_path = str(Path(self.label_files[0]).parent) + '.cache'  # cached labels
        if os.path.isfile(cache_path):
            cache = torch.load(cache_path)  # load
            if cache['hash'] != get_hash(self.label_files + self.img_files):  # dataset changed
                cache = self.cache_labels(cache_path)  # re-cache
        else:
            cache = self.cache_labels(cache_path)  # cache

        # Read cache
        cache.pop('hash')  # remove hash
        labels, shapes = zip(*cache.values())
        self.labels = list(labels)
        self.shapes = np.array(shapes, dtype=np.float64)
        self.img_files = list(cache.keys())  # update
        self.label_files = img2label_paths(cache.keys())  # update

        n = len(shapes)  # number of images
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n

        # Rectangular Training
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

        # Check labels
        create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False
        nm, nf, ne, ns, nd = 0, 0, 0, 0, 0  # number missing, found, empty, datasubset, duplicate
        pbar = enumerate(self.label_files)
        if rank in [-1, 0]:
            pbar = tqdm(pbar)
        for i, file in pbar:
            l = self.labels[i]  # label
            if l is not None and l.shape[0]:
                assert l.shape[1] == 5, 'labels require 5 columns each: %s' % file
                assert (l >= 0).all(), 'negative labels: %s' % file
                assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels: %s' % file
                if np.unique(l, axis=0).shape[0] < l.shape[0]:  # duplicate rows
                    nd += 1  # print('WARNING: duplicate rows in %s' % self.label_files[i])
                if single_cls:
                    l[:, 0] = 0  # force dataset into single-class mode
                self.labels[i] = l
                nf += 1  # file found

                # Create subdataset (a smaller dataset)
                if create_datasubset and ns < 1E4:
                    if ns == 0:
                        create_folder(path='./datasubset')
                        os.makedirs('./datasubset/images')
                    exclude_classes = 43
                    if exclude_classes not in l[:, 0]:
                        ns += 1
                        # shutil.copy(src=self.img_files[i], dst='./datasubset/images/')  # copy image
                        with open('./datasubset/images.txt', 'a') as f:
                            f.write(self.img_files[i] + '\n')

                # Extract object detection boxes for a second stage classifier
                if extract_bounding_boxes:
                    p = Path(self.img_files[i])
                    img = cv2.imread(str(p))
                    h, w = img.shape[:2]
                    for j, x in enumerate(l):
                        f = '%s%sclassifier%s%g_%g_%s' % (p.parent.parent, os.sep, os.sep, x[0], j, p.name)
                        if not os.path.exists(Path(f).parent):
                            os.makedirs(Path(f).parent)  # make new output folder

                        b = x[1:] * [w, h, w, h]  # box
                        b[2:] = b[2:].max()  # rectangle to square
                        b[2:] = b[2:] * 1.3 + 30  # pad
                        b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)

                        b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                        b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                        assert cv2.imwrite(f, img[b[1]:b[3], b[0]:b[2]]), 'Failure extracting classifier boxes'
            else:
                ne += 1  # print('empty labels for image %s' % self.img_files[i])  # file empty
                # os.system("rm '%s' '%s'" % (self.img_files[i], self.label_files[i]))  # remove

            if rank in [-1, 0]:
                pbar.desc = 'Scanning labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
                    cache_path, nf, nm, ne, nd, n)
        if nf == 0:
            s = 'WARNING: No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
            print(s)
            assert not augment, '%s. Can not train without labels.' % s

        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        self.imgs = [None] * n
        if cache_images:
            gb = 0  # Gigabytes of cached images
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))  # 8 threads
            pbar = tqdm(enumerate(results), total=n)
            for i, x in pbar:
                self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)
                gb += self.imgs[i].nbytes
                pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9)

    def cache_labels(self, path='labels.cache'):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
        for (img, label) in pbar:
            try:
                l = []
                im = Image.open(img)
                im.verify()  # PIL verify
                shape = exif_size(im)  # image size
                assert (shape[0] > 9) & (shape[1] > 9), 'image size <10 pixels'
                if os.path.isfile(label):
                    with open(label, 'r') as f:
                        l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)  # labels
                if len(l) == 0:
                    l = np.zeros((0, 5), dtype=np.float32)
                x[img] = [l, shape]
            except Exception as e:
                print('WARNING: Ignoring corrupted image and/or label %s: %s' % (img, e))

        x['hash'] = get_hash(self.label_files + self.img_files)
        torch.save(x, path)  # save for next time
        return x

    def __len__(self):
        return len(self.img_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     # self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        if self.image_weights:
            index = self.indices[index]

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        if mosaic:
            # Load mosaic
            img, labels = load_mosaic(self, index)
            shapes = None

            # MixUp https://arxiv.org/pdf/1710.09412.pdf
            if random.random() < hyp['mixup']:
                img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
                r = np.random.beta(8.0, 8.0)  # mixup ratio, alpha=beta=8.0
                img = (img * r + img2 * (1 - r)).astype(np.uint8)
                labels = np.concatenate((labels, labels2), 0)

        else:
            # Load image
            img, (h0, w0), (h, w) = load_image(self, index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            # Load labels
            labels = []
            x = self.labels[index]
            if x.size > 0:
                # Normalized xywh to pixel xyxy format
                labels = x.copy()
                labels[:, 1] = ratio[0] * w * (x[:, 1] - x[:, 3] / 2) + pad[0]  # pad width
                labels[:, 2] = ratio[1] * h * (x[:, 2] - x[:, 4] / 2) + pad[1]  # pad height
                labels[:, 3] = ratio[0] * w * (x[:, 1] + x[:, 3] / 2) + pad[0]
                labels[:, 4] = ratio[1] * h * (x[:, 2] + x[:, 4] / 2) + pad[1]

        if self.augment:
            # Augment imagespace
            if not mosaic:
                img, labels = random_perspective(img, labels,
                                                 degrees=hyp['degrees'],
                                                 translate=hyp['translate'],
                                                 scale=hyp['scale'],
                                                 shear=hyp['shear'],
                                                 perspective=hyp['perspective'])

            # Augment colorspace
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Apply cutouts
            # if random.random() < 0.9:
            #     labels = cutout(img, labels)

        nL = len(labels)  # number of labels
        if nL:
            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])  # convert xyxy to xywh
            labels[:, [2, 4]] /= img.shape[0]  # normalized height 0-1
            labels[:, [1, 3]] /= img.shape[1]  # normalized width 0-1

        if self.augment:
            # flip up-down
            if random.random() < hyp['flipud']:
                img = np.flipud(img)
                if nL:
                    labels[:, 2] = 1 - labels[:, 2]

            # flip left-right
            if random.random() < hyp['fliplr']:
                img = np.fliplr(img)
                if nL:
                    labels[:, 1] = 1 - labels[:, 1]

        labels_out = torch.zeros((nL, 6))
        if nL:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.img_files[index], shapes

    @staticmethod
    def collate_fn(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes
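
# Note (not part of the original file): collate_fn concatenates all labels into one
# (n, 6) tensor instead of padding per image; column 0 is written above so that
# build_targets() can map each label row back to its image. Sketch of the shapes for
# a batch of 2 images with 3 and 1 labels:
#   imgs    -> torch.Size([2, 3, 640, 640])
#   targets -> 4 rows of [batch_idx, cls, x, y, w, h] with batch_idx in {0, 1}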


# Ancillary functions --------------------------------------------------------------------------------------------------
def load_image(self, index):
    # loads 1 image from dataset, returns img, original hw, resized hw
    img = self.imgs[index]
    if img is None:  # not cached
        path = self.img_files[index]
        img = cv2.imread(path)  # BGR
        assert img is not None, 'Image Not Found ' + path
        h0, w0 = img.shape[:2]  # orig hw
        r = self.img_size / max(h0, w0)  # resize image to img_size
        if r != 1:  # always resize down, only resize up if training with augmentation
            interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
        return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
    else:
        return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized


def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
    r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
    hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    dtype = img.dtype  # uint8

    x = np.arange(0, 256, dtype=np.int16)
    lut_hue = ((x * r[0]) % 180).astype(dtype)
    lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
    lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
    cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed

    # Histogram equalization
    # if random.random() < 0.2:
    #     for i in range(3):
    #         img[:, :, i] = cv2.equalizeHist(img[:, :, i])
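
# Note (not part of the original file): the three lookup tables above remap every
# possible uint8 value at once, so the random gains are applied with three cv2.LUT
# calls instead of per-pixel arithmetic; hue wraps modulo 180 because OpenCV stores
# H in [0, 180) for uint8 images.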


def load_mosaic(self, index):
    # loads images in a mosaic

    labels4 = []
    s = self.img_size
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
    indices = [index] + [random.randint(0, len(self.labels) - 1) for _ in range(3)]  # 3 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        x = self.labels[index]
        labels = x.copy()
        if x.size > 0:  # Normalized xywh to pixel xyxy format
            labels[:, 1] = w * (x[:, 1] - x[:, 3] / 2) + padw
            labels[:, 2] = h * (x[:, 2] - x[:, 4] / 2) + padh
            labels[:, 3] = w * (x[:, 1] + x[:, 3] / 2) + padw
            labels[:, 4] = h * (x[:, 2] + x[:, 4] / 2) + padh
        labels4.append(labels)

    # Concat/clip labels
    if len(labels4):
        labels4 = np.concatenate(labels4, 0)
        np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:])  # use with random_perspective
        # img4, labels4 = replicate(img4, labels4)  # replicate

    # Augment
    img4, labels4 = random_perspective(img4, labels4,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img4, labels4
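
# Note (not part of the original file): the mosaic canvas is 2s x 2s; passing
# mosaic_border = [-s // 2, -s // 2] as `border` makes random_perspective crop the
# canvas back down to s x s, so mosaic samples still train at img_size.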


def replicate(img, labels):
    # Replicate labels
    h, w = img.shape[:2]
    boxes = labels[:, 1:].astype(int)
    x1, y1, x2, y2 = boxes.T
    s = ((x2 - x1) + (y2 - y1)) / 2  # side length (pixels)
    for i in s.argsort()[:round(s.size * 0.5)]:  # smallest indices
        x1b, y1b, x2b, y2b = boxes[i]
        bh, bw = y2b - y1b, x2b - x1b
        yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw))  # offset x, y
        x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
        img[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)

    return img, labels


def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
    # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, 32), np.mod(dh, 32)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
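
# Worked example (a sketch, not part of the original file): letterboxing a 720x1280
# BGR frame to img_size 640 with auto=True scales by r = min(640/720, 640/1280) = 0.5
# to 640x360, then pads the height up to the next multiple of 32 (384), i.e. 12 px
# on top and bottom:
#
#   img, ratio, (dw, dh) = letterbox(np.zeros((720, 1280, 3), dtype=np.uint8), 640)
#   # img.shape == (384, 640, 3); ratio == (0.5, 0.5); (dw, dh) == (0.0, 12.0)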


def random_perspective(img, targets=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0, border=(0, 0)):
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]

    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Center
    C = np.eye(3)
    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

    # Perspective
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined rotation matrix
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    # Visualize
    # import matplotlib.pyplot as plt
    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
    # ax[0].imshow(img[:, :, ::-1])  # base
    # ax[1].imshow(img2[:, :, ::-1])  # warped

    # Transform label coordinates
    n = len(targets)
    if n:
        # warp points
        xy = np.ones((n * 4, 3))
        xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
        xy = xy @ M.T  # transform
        if perspective:
            xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)  # rescale
        else:  # affine
            xy = xy[:, :2].reshape(n, 8)

        # create new boxes
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

        # # apply angle-based reduction of bounding boxes
        # radians = a * math.pi / 180
        # reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
        # x = (xy[:, 2] + xy[:, 0]) / 2
        # y = (xy[:, 3] + xy[:, 1]) / 2
        # w = (xy[:, 2] - xy[:, 0]) * reduction
        # h = (xy[:, 3] - xy[:, 1]) * reduction
        # xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T

        # clip boxes
        xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
        xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)

        # filter candidates
        i = box_candidates(box1=targets[:, 1:5].T * s, box2=xy.T)
        targets = targets[i]
        targets[:, 1:5] = xy[i]

    return img, targets


def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1):  # box1(4,n), box2(4,n)
    # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16))  # aspect ratio
    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + 1e-16) > area_thr) & (ar < ar_thr)  # candidates
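
# Example (a sketch, not part of the original file): a 100x80 px box squashed to
# 30x2 px by augmentation is rejected because h2 is not > wh_thr (2 px) and its
# area ratio 60/8000 = 0.0075 falls below area_thr:
#
#   box_candidates(np.array([[0], [0], [100], [80]]), np.array([[0], [0], [30], [2]]))
#   # -> array([False])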


def cutout(image, labels):
    # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
    h, w = image.shape[:2]

    def bbox_ioa(box1, box2):
        # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
        box2 = box2.transpose()

        # Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]

        # Intersection area
        inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
                     (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)

        # box2 area
        box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16

        # Intersection over box2 area
        return inter_area / box2_area

    # create random masks
    scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
    for s in scales:
        mask_h = random.randint(1, int(h * s))
        mask_w = random.randint(1, int(w * s))

        # box
        xmin = max(0, random.randint(0, w) - mask_w // 2)
        ymin = max(0, random.randint(0, h) - mask_h // 2)
        xmax = min(w, xmin + mask_w)
        ymax = min(h, ymin + mask_h)

        # apply random color mask
        image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]

        # return unobscured labels
        if len(labels) and s > 0.03:
            box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            labels = labels[ioa < 0.60]  # remove >60% obscured labels

    return labels


def reduce_img_size(path='path/images', img_size=1024):  # from utils.datasets import *; reduce_img_size()
    # creates a new ./images_reduced folder with reduced size images of maximum size img_size
    path_new = path + '_reduced'  # reduced images path
    create_folder(path_new)
    for f in tqdm(glob.glob('%s/*.*' % path)):
        try:
            img = cv2.imread(f)
            h, w = img.shape[:2]
            r = img_size / max(h, w)  # size ratio
            if r < 1.0:
                img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_AREA)  # _LINEAR fastest
            fnew = f.replace(path, path_new)  # .replace(Path(f).suffix, '.jpg')
            cv2.imwrite(fnew, img)
        except Exception:
            print('WARNING: image failure %s' % f)


def recursive_dataset2bmp(dataset='path/dataset_bmp'):  # from utils.datasets import *; recursive_dataset2bmp()
    # Converts dataset to bmp (for faster training)
    formats = [x.lower() for x in img_formats] + [x.upper() for x in img_formats]
    for a, b, files in os.walk(dataset):
        for file in tqdm(files, desc=a):
            p = a + '/' + file
            s = Path(file).suffix
            if s == '.txt':  # replace text
                with open(p, 'r') as f:
                    lines = f.read()
                for f in formats:
                    lines = lines.replace(f, '.bmp')
                with open(p, 'w') as f:
                    f.write(lines)
            elif s in formats:  # replace image
                cv2.imwrite(p.replace(s, '.bmp'), cv2.imread(p))
                if s != '.bmp':
                    os.system("rm '%s'" % p)


def imagelist2folder(path='path/images.txt'):  # from utils.datasets import *; imagelist2folder()
    # Copies all the images in a text file (list of images) into a folder
    create_folder(path[:-4])
    with open(path, 'r') as f:
        for line in f.read().splitlines():
            os.system('cp "%s" %s' % (line, path[:-4]))
            print(line)


def create_folder(path='./new'):
    # Create folder
    if os.path.exists(path):
        shutil.rmtree(path)  # delete output folder
    os.makedirs(path)  # make new output folder