import glob
import math
import os
import random
import shutil
import time
from pathlib import Path
from threading import Thread

import cv2
import numpy as np
import torch
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm

from utils.general import xyxy2xywh, xywh2xyxy, torch_distributed_zero_first

help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.dng']
vid_formats = ['.mov', '.avi', '.mp4', '.mpg', '.mpeg', '.m4v', '.wmv', '.mkv']

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break


def get_hash(files):
    # Returns a single hash value of a list of files
    return sum(os.path.getsize(f) for f in files if os.path.isfile(f))


def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    try:
        rotation = dict(img._getexif().items())[orientation]
        if rotation == 6:  # rotation 270
            s = (s[1], s[0])
        elif rotation == 8:  # rotation 90
            s = (s[1], s[0])
    except Exception:
        pass  # image has no EXIF orientation data
    return s


def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
                      rank=-1, world_size=1, workers=8):
    # Make sure only the first process in DDP processes the dataset first, so the other processes can use the cache
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,  # augmentation hyperparameters
                                      rect=rect,  # rectangular training
                                      cache_images=cache,
                                      single_cls=opt.single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      rank=rank)

    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             num_workers=nw,
                                             sampler=train_sampler,
                                             pin_memory=True,
                                             collate_fn=LoadImagesAndLabels.collate_fn)
    return dataloader, dataset
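

# Illustrative usage (a sketch, not part of the original training code): building a
# single-process loader with rank=-1. `opt` only needs a `single_cls` attribute here;
# the dataset path and hyperparameter values below are hypothetical placeholders, but
# the hyp keys match those consumed by LoadImagesAndLabels.__getitem__ below.
#
#   from types import SimpleNamespace
#   opt = SimpleNamespace(single_cls=False)
#   hyp = {'mixup': 0.0, 'degrees': 0.0, 'translate': 0.1, 'scale': 0.5, 'shear': 0.0,
#          'perspective': 0.0, 'hsv_h': 0.015, 'hsv_s': 0.7, 'hsv_v': 0.4,
#          'flipud': 0.0, 'fliplr': 0.5}
#   dataloader, dataset = create_dataloader('../coco128/images/train2017', imgsz=640, batch_size=16,
#                                           stride=32, opt=opt, hyp=hyp, augment=True, rank=-1)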


class LoadImages:  # for inference
    def __init__(self, path, img_size=640):
        p = str(Path(path))  # os-agnostic
        p = os.path.abspath(p)  # absolute path
        if '*' in p:
            files = sorted(glob.glob(p))  # glob
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
        elif os.path.isfile(p):
            files = [p]  # files
        else:
            raise Exception('ERROR: %s does not exist' % p)

        images = [x for x in files if os.path.splitext(x)[-1].lower() in img_formats]
        videos = [x for x in files if os.path.splitext(x)[-1].lower() in vid_formats]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'images'
        if any(videos):
            self.new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, 'No images or videos found in %s. Supported formats are:\nimages: %s\nvideos: %s' % \
                            (p, img_formats, vid_formats)

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            if not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()

            self.frame += 1
            print('video %g/%g (%g/%g) %s: ' % (self.count + 1, self.nf, self.frame, self.nframes, path), end='')

        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path
            print('image %g/%g %s: ' % (self.count, self.nf, path), end='')

        # Padded resize
        img = letterbox(img0, new_shape=self.img_size)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        img = np.ascontiguousarray(img)

        # cv2.imwrite(path + '.letterbox.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1])  # save letterbox image
        return path, img, img0, self.cap

    def new_video(self, path):
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nf  # number of files
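

# Illustrative inference loop (a sketch, assuming a hypothetical `model` callable):
# LoadImages yields (path, letterboxed CHW RGB array, original BGR image, video capture).
#
#   dataset = LoadImages('data/images', img_size=640)  # hypothetical path
#   for path, img, im0s, vid_cap in dataset:
#       img = torch.from_numpy(img).float() / 255.0  # uint8 to float 0-1
#       if img.ndimension() == 3:
#           img = img.unsqueeze(0)  # add batch dimension
#       # pred = model(img)  # hypothetical model call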


class LoadWebcam:  # for inference
    def __init__(self, pipe=0, img_size=640):
        self.img_size = img_size

        if pipe == '0':
            pipe = 0  # local camera
        # pipe = 'rtsp://192.168.1.64/1'  # IP camera
        # pipe = 'rtsp://username:password@192.168.1.64/1'  # IP camera with login
        # pipe = 'rtsp://170.93.143.139/rtplive/470011e600ef003a004ee33696235daa'  # IP traffic camera
        # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg'  # IP golf camera

        # https://answers.opencv.org/question/215996/changing-gstreamer-pipeline-to-opencv-in-pythonsolved/
        # pipe = '"rtspsrc location="rtsp://username:password@192.168.1.64/1" latency=10 ! appsink'  # GStreamer

        # https://answers.opencv.org/question/200787/video-acceleration-gstremer-pipeline-in-videocapture/
        # https://stackoverflow.com/questions/54095699/install-gstreamer-support-for-opencv-python-package  # install help
        # pipe = "rtspsrc location=rtsp://root:root@192.168.0.91:554/axis-media/media.amp?videocodec=h264&resolution=3840x2160 protocols=GST_RTSP_LOWER_TRANS_TCP ! rtph264depay ! queue ! vaapih264dec ! videoconvert ! appsink"  # GStreamer

        self.pipe = pipe
        self.cap = cv2.VideoCapture(pipe)  # video capture object
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)  # set buffer size

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if cv2.waitKey(1) == ord('q'):  # q to quit
            self.cap.release()
            cv2.destroyAllWindows()
            raise StopIteration

        # Read frame
        if self.pipe == 0:  # local camera
            ret_val, img0 = self.cap.read()
            img0 = cv2.flip(img0, 1)  # flip left-right
        else:  # IP camera
            n = 0
            while True:
                n += 1
                self.cap.grab()
                if n % 30 == 0:  # skip frames
                    ret_val, img0 = self.cap.retrieve()
                    if ret_val:
                        break

        # Print
        assert ret_val, 'Camera Error %s' % self.pipe
        img_path = 'webcam.jpg'
        print('webcam %g: ' % self.count, end='')

        # Padded resize
        img = letterbox(img0, new_shape=self.img_size)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        img = np.ascontiguousarray(img)

        return img_path, img, img0, None

    def __len__(self):
        return 0


class LoadStreams:  # multiple IP or RTSP cameras
    def __init__(self, sources='streams.txt', img_size=640):
        self.mode = 'images'
        self.img_size = img_size

        if os.path.isfile(sources):
            with open(sources, 'r') as f:
                sources = [x.strip() for x in f.read().splitlines() if len(x.strip())]
        else:
            sources = [sources]

        n = len(sources)
        self.imgs = [None] * n
        self.sources = sources
        for i, s in enumerate(sources):
            # Start the thread to read frames from the video stream
            print('%g/%g: %s... ' % (i + 1, n, s), end='')
            cap = cv2.VideoCapture(0 if s == '0' else s)
            assert cap.isOpened(), 'Failed to open %s' % s
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS) % 100
            _, self.imgs[i] = cap.read()  # guarantee first frame
            thread = Thread(target=self.update, args=([i, cap]), daemon=True)
            print(' success (%gx%g at %.2f FPS).' % (w, h, fps))
            thread.start()
        print('')  # newline

        # check for common shapes
        s = np.stack([letterbox(x, new_shape=self.img_size)[0].shape for x in self.imgs], 0)  # inference shapes
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        if not self.rect:
            print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')

    def update(self, index, cap):
        # Read next stream frame in a daemon thread
        n = 0
        while cap.isOpened():
            n += 1
            # _, self.imgs[index] = cap.read()
            cap.grab()
            if n == 4:  # read every 4th frame
                _, self.imgs[index] = cap.retrieve()
                n = 0
            time.sleep(0.01)  # wait time

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        img0 = self.imgs.copy()
        if cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        # Letterbox
        img = [letterbox(x, new_shape=self.img_size, auto=self.rect)[0] for x in img0]

        # Stack
        img = np.stack(img, 0)

        # Convert
        img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, BHWC to BCHW
        img = np.ascontiguousarray(img)

        return self.sources, img, img0, None

    def __len__(self):
        return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years
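

# Illustrative usage (a sketch): each iteration returns the source list, a stacked
# BCHW RGB batch with one frame per stream, and the list of original BGR frames.
#
#   dataset = LoadStreams('streams.txt', img_size=640)
#   for sources, img, im0s, _ in dataset:
#       img = torch.from_numpy(img).float() / 255.0  # uint8 to float 0-1
#       # pred = model(img)  # hypothetical model call, batch size = number of streams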


class LoadImagesAndLabels(Dataset):  # for training/testing
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1):
        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = str(Path(p))  # os-agnostic
                parent = str(Path(p).parent) + os.sep
                if os.path.isfile(p):  # file
                    with open(p, 'r') as t:
                        t = t.read().splitlines()
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                elif os.path.isdir(p):  # folder
                    f += glob.iglob(p + os.sep + '*.*')
                else:
                    raise Exception('%s does not exist' % p)
            self.img_files = sorted(
                [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats])
        except Exception as e:
            raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))

        n = len(self.img_files)
        assert n > 0, 'No images found in %s. See %s' % (path, help_url)
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches

        self.n = n  # number of images
        self.batch = bi  # batch index of image
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride

        # Define labels
        self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt') for x in
                            self.img_files]

        # Check cache
        cache_path = str(Path(self.label_files[0]).parent) + '.cache'  # cached labels
        if os.path.isfile(cache_path):
            cache = torch.load(cache_path)  # load
            if cache['hash'] != get_hash(self.label_files + self.img_files):  # dataset changed
                cache = self.cache_labels(cache_path)  # re-cache
        else:
            cache = self.cache_labels(cache_path)  # cache

        # Get labels
        labels, shapes = zip(*[cache[x] for x in self.img_files])
        self.shapes = np.array(shapes, dtype=np.float64)
        self.labels = list(labels)

        # Rectangular Training  https://github.com/ultralytics/yolov3/issues/232
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

        # Cache labels
        create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False
        nm, nf, ne, ns, nd = 0, 0, 0, 0, 0  # number missing, found, empty, datasubset, duplicate
        pbar = enumerate(self.label_files)
        if rank in [-1, 0]:
            pbar = tqdm(pbar)
        for i, file in pbar:
            l = self.labels[i]  # label
            if l is not None and l.shape[0]:
                assert l.shape[1] == 5, 'labels require 5 columns each: %s' % file
                assert (l >= 0).all(), 'negative labels: %s' % file
                assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels: %s' % file
                if np.unique(l, axis=0).shape[0] < l.shape[0]:  # duplicate rows
                    nd += 1  # print('WARNING: duplicate rows in %s' % self.label_files[i])
                if single_cls:
                    l[:, 0] = 0  # force dataset into single-class mode
                self.labels[i] = l
                nf += 1  # file found

                # Create subdataset (a smaller dataset)
                if create_datasubset and ns < 1E4:
                    if ns == 0:
                        create_folder(path='./datasubset')
                        os.makedirs('./datasubset/images')
                    exclude_classes = 43
                    if exclude_classes not in l[:, 0]:
                        ns += 1
                        # shutil.copy(src=self.img_files[i], dst='./datasubset/images/')  # copy image
                        with open('./datasubset/images.txt', 'a') as f:
                            f.write(self.img_files[i] + '\n')

                # Extract object detection boxes for a second stage classifier
                if extract_bounding_boxes:
                    p = Path(self.img_files[i])
                    img = cv2.imread(str(p))
                    h, w = img.shape[:2]
                    for j, x in enumerate(l):
                        f = '%s%sclassifier%s%g_%g_%s' % (p.parent.parent, os.sep, os.sep, x[0], j, p.name)
                        if not os.path.exists(Path(f).parent):
                            os.makedirs(Path(f).parent)  # make new output folder

                        b = x[1:] * [w, h, w, h]  # box
                        b[2:] = b[2:].max()  # rectangle to square
                        b[2:] = b[2:] * 1.3 + 30  # pad
                        b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)

                        b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                        b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                        assert cv2.imwrite(f, img[b[1]:b[3], b[0]:b[2]]), 'Failure extracting classifier boxes'
            else:
                ne += 1  # print('empty labels for image %s' % self.img_files[i])  # file empty
                # os.system("rm '%s' '%s'" % (self.img_files[i], self.label_files[i]))  # remove

            if rank in [-1, 0]:
                pbar.desc = 'Scanning labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
                    cache_path, nf, nm, ne, nd, n)
        if nf == 0:
            s = 'WARNING: No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
            print(s)
            assert not augment, '%s. Can not train without labels.' % s

        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        self.imgs = [None] * n
        if cache_images:
            gb = 0  # Gigabytes of cached images
            pbar = tqdm(range(len(self.img_files)), desc='Caching images')
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            for i in pbar:  # max 10k images
                self.imgs[i], self.img_hw0[i], self.img_hw[i] = load_image(self, i)  # img, hw_original, hw_resized
                gb += self.imgs[i].nbytes
                pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9)

    def cache_labels(self, path='labels.cache'):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
        for (img, label) in pbar:
            try:
                l = []
                image = Image.open(img)
                image.verify()  # PIL verify
                # _ = io.imread(img)  # skimage verify (from skimage import io)
                shape = exif_size(image)  # image size
                assert (shape[0] > 9) & (shape[1] > 9), 'image size <10 pixels'
                if os.path.isfile(label):
                    with open(label, 'r') as f:
                        l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)  # labels
                if len(l) == 0:
                    l = np.zeros((0, 5), dtype=np.float32)
                x[img] = [l, shape]
            except Exception as e:
                x[img] = [None, None]
                print('WARNING: %s: %s' % (img, e))

        x['hash'] = get_hash(self.label_files + self.img_files)
        torch.save(x, path)  # save for next time
        return x

    def __len__(self):
        return len(self.img_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     # self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        if self.image_weights:
            index = self.indices[index]

        hyp = self.hyp
        if self.mosaic:
            # Load mosaic
            img, labels = load_mosaic(self, index)
            shapes = None

            # MixUp https://arxiv.org/pdf/1710.09412.pdf
            if random.random() < hyp['mixup']:
                img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
                r = np.random.beta(8.0, 8.0)  # mixup ratio, alpha=beta=8.0
                img = (img * r + img2 * (1 - r)).astype(np.uint8)
                labels = np.concatenate((labels, labels2), 0)

        else:
            # Load image
            img, (h0, w0), (h, w) = load_image(self, index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            # Load labels
            labels = []
            x = self.labels[index]
            if x.size > 0:
                # Normalized xywh to pixel xyxy format
                labels = x.copy()
                labels[:, 1] = ratio[0] * w * (x[:, 1] - x[:, 3] / 2) + pad[0]  # pad width
                labels[:, 2] = ratio[1] * h * (x[:, 2] - x[:, 4] / 2) + pad[1]  # pad height
                labels[:, 3] = ratio[0] * w * (x[:, 1] + x[:, 3] / 2) + pad[0]
                labels[:, 4] = ratio[1] * h * (x[:, 2] + x[:, 4] / 2) + pad[1]

        if self.augment:
            # Augment imagespace
            if not self.mosaic:
                img, labels = random_perspective(img, labels,
                                                 degrees=hyp['degrees'],
                                                 translate=hyp['translate'],
                                                 scale=hyp['scale'],
                                                 shear=hyp['shear'],
                                                 perspective=hyp['perspective'])

            # Augment colorspace
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Apply cutouts
            # if random.random() < 0.9:
            #     labels = cutout(img, labels)

        nL = len(labels)  # number of labels
        if nL:
            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])  # convert xyxy to xywh
            labels[:, [2, 4]] /= img.shape[0]  # normalized height 0-1
            labels[:, [1, 3]] /= img.shape[1]  # normalized width 0-1

        if self.augment:
            # flip up-down
            if random.random() < hyp['flipud']:
                img = np.flipud(img)
                if nL:
                    labels[:, 2] = 1 - labels[:, 2]

            # flip left-right
            if random.random() < hyp['fliplr']:
                img = np.fliplr(img)
                if nL:
                    labels[:, 1] = 1 - labels[:, 1]

        labels_out = torch.zeros((nL, 6))
        if nL:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.img_files[index], shapes

    @staticmethod
    def collate_fn(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes
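

# Illustrative batch format (a sketch): collate_fn stacks images into a BCHW tensor and
# concatenates labels into an (N, 6) tensor of [image_index, class, x, y, w, h] rows,
# where column 0 lets build_targets() map each label back to its image in the batch.
#
#   imgs, targets, paths, shapes = next(iter(dataloader))
#   # imgs: torch.uint8, shape (batch_size, 3, H, W)
#   # targets: torch.float32, shape (total_labels_in_batch, 6)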


# Ancillary functions --------------------------------------------------------------------------------------------------
def load_image(self, index):
    # loads 1 image from dataset, returns img, original hw, resized hw
    img = self.imgs[index]
    if img is None:  # not cached
        path = self.img_files[index]
        img = cv2.imread(path)  # BGR
        assert img is not None, 'Image Not Found ' + path
        h0, w0 = img.shape[:2]  # orig hw
        r = self.img_size / max(h0, w0)  # resize image to img_size
        if r != 1:  # always resize down, only resize up if training with augmentation
            interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
        return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
    else:
        return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized


def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
    r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
    hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    dtype = img.dtype  # uint8

    x = np.arange(0, 256, dtype=np.int16)
    lut_hue = ((x * r[0]) % 180).astype(dtype)
    lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
    lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
    cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed

    # Histogram equalization
    # if random.random() < 0.2:
    #     for i in range(3):
    #         img[:, :, i] = cv2.equalizeHist(img[:, :, i])
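

# Illustrative usage (a sketch): augment_hsv applies random gains via per-channel lookup
# tables and writes back into the BGR image in place, so no return value is needed.
#
#   img = cv2.imread('example.jpg')  # hypothetical image path
#   augment_hsv(img, hgain=0.015, sgain=0.7, vgain=0.4)  # img is now augmented in place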


def load_mosaic(self, index):
    # loads images in a mosaic
    labels4 = []
    s = self.img_size
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
    indices = [index] + [random.randint(0, len(self.labels) - 1) for _ in range(3)]  # 3 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, max(xc, w), min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        x = self.labels[index]
        labels = x.copy()
        if x.size > 0:  # Normalized xywh to pixel xyxy format
            labels[:, 1] = w * (x[:, 1] - x[:, 3] / 2) + padw
            labels[:, 2] = h * (x[:, 2] - x[:, 4] / 2) + padh
            labels[:, 3] = w * (x[:, 1] + x[:, 3] / 2) + padw
            labels[:, 4] = h * (x[:, 2] + x[:, 4] / 2) + padh
        labels4.append(labels)

    # Concat/clip labels
    if len(labels4):
        labels4 = np.concatenate(labels4, 0)
        # np.clip(labels4[:, 1:] - s / 2, 0, s, out=labels4[:, 1:])  # use with center crop
        np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:])  # use with random_perspective

    # Replicate
    # img4, labels4 = replicate(img4, labels4)

    # Augment
    # img4 = img4[s // 2: int(s * 1.5), s // 2:int(s * 1.5)]  # center crop (WARNING, requires box pruning)
    img4, labels4 = random_perspective(img4, labels4,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img4, labels4


def replicate(img, labels):
    # Replicate labels
    h, w = img.shape[:2]
    boxes = labels[:, 1:].astype(int)
    x1, y1, x2, y2 = boxes.T
    s = ((x2 - x1) + (y2 - y1)) / 2  # side length (pixels)
    for i in s.argsort()[:round(s.size * 0.5)]:  # smallest indices
        x1b, y1b, x2b, y2b = boxes[i]
        bh, bw = y2b - y1b, x2b - x1b
        yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw))  # offset x, y
        x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
        img[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)

    return img, labels


def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
    # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, 64), np.mod(dh, 64)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
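

# Worked example (a sketch): a 720x1280 (HxW) image letterboxed to new_shape=640 with
# auto=True scales by r = 640/1280 = 0.5 to 360x640, then pads the height up to the
# nearest multiple of 64 (384), i.e. dh = 24 split as 12px top and 12px bottom:
#
#   img, ratio, (dw, dh) = letterbox(np.zeros((720, 1280, 3), dtype=np.uint8))
#   # img.shape == (384, 640, 3), ratio == (0.5, 0.5), (dw, dh) == (0.0, 12.0)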


def random_perspective(img, targets=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0, border=(0, 0)):
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]

    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Center
    C = np.eye(3)
    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

    # Perspective
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined rotation matrix
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    # Visualize
    # import matplotlib.pyplot as plt
    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
    # ax[0].imshow(img[:, :, ::-1])  # base
    # ax[1].imshow(img2[:, :, ::-1])  # warped

    # Transform label coordinates
    n = len(targets)
    if n:
        # warp points
        xy = np.ones((n * 4, 3))
        xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
        xy = xy @ M.T  # transform
        if perspective:
            xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)  # rescale
        else:  # affine
            xy = xy[:, :2].reshape(n, 8)

        # create new boxes
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

        # # apply angle-based reduction of bounding boxes
        # radians = a * math.pi / 180
        # reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
        # x = (xy[:, 2] + xy[:, 0]) / 2
        # y = (xy[:, 3] + xy[:, 1]) / 2
        # w = (xy[:, 2] - xy[:, 0]) * reduction
        # h = (xy[:, 3] - xy[:, 1]) * reduction
        # xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T

        # clip boxes
        xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
        xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)

        # filter candidates
        i = box_candidates(box1=targets[:, 1:5].T * s, box2=xy.T)
        targets = targets[i]
        targets[:, 1:5] = xy[i]

    return img, targets


def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1):  # box1(4,n), box2(4,n)
    # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16))  # aspect ratio
    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + 1e-16) > area_thr) & (ar < ar_thr)  # candidates
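

# Illustrative check (a sketch): a 100x100 box that shrinks to 40x4 after augmentation
# fails the area ratio test (40*4 / 10000 = 0.016 < 0.1) and is dropped.
#
#   box1 = np.array([[0], [0], [100], [100]])  # (4, n) xyxy before augment
#   box2 = np.array([[0], [0], [40], [4]])     # (4, n) xyxy after augment
#   box_candidates(box1, box2)  # -> array([False])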


def cutout(image, labels):
    # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
    h, w = image.shape[:2]

    def bbox_ioa(box1, box2):
        # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
        box2 = box2.transpose()

        # Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]

        # Intersection area
        inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
                     (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)

        # box2 area
        box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16

        # Intersection over box2 area
        return inter_area / box2_area

    # create random masks
    scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
    for s in scales:
        mask_h = random.randint(1, int(h * s))
        mask_w = random.randint(1, int(w * s))

        # box
        xmin = max(0, random.randint(0, w) - mask_w // 2)
        ymin = max(0, random.randint(0, h) - mask_h // 2)
        xmax = min(w, xmin + mask_w)
        ymax = min(h, ymin + mask_h)

        # apply random color mask
        image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]

        # return unobscured labels
        if len(labels) and s > 0.03:
            box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            labels = labels[ioa < 0.60]  # remove >60% obscured labels

    return labels


def reduce_img_size(path='path/images', img_size=1024):  # from utils.datasets import *; reduce_img_size()
    # creates a new ./images_reduced folder with reduced size images of maximum size img_size
    path_new = path + '_reduced'  # reduced images path
    create_folder(path_new)
    for f in tqdm(glob.glob('%s/*.*' % path)):
        try:
            img = cv2.imread(f)
            h, w = img.shape[:2]
            r = img_size / max(h, w)  # size ratio
            if r < 1.0:
                img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_AREA)  # _LINEAR fastest
            fnew = f.replace(path, path_new)  # .replace(Path(f).suffix, '.jpg')
            cv2.imwrite(fnew, img)
        except Exception:
            print('WARNING: image failure %s' % f)


def recursive_dataset2bmp(dataset='path/dataset_bmp'):  # from utils.datasets import *; recursive_dataset2bmp()
    # Converts dataset to bmp (for faster training)
    formats = [x.lower() for x in img_formats] + [x.upper() for x in img_formats]
    for a, b, files in os.walk(dataset):
        for file in tqdm(files, desc=a):
            p = a + '/' + file
            s = Path(file).suffix
            if s == '.txt':  # replace text
                with open(p, 'r') as f:
                    lines = f.read()
                for f in formats:
                    lines = lines.replace(f, '.bmp')
                with open(p, 'w') as f:
                    f.write(lines)
            elif s in formats:  # replace image
                cv2.imwrite(p.replace(s, '.bmp'), cv2.imread(p))
                if s != '.bmp':
                    os.system("rm '%s'" % p)


def imagelist2folder(path='path/images.txt'):  # from utils.datasets import *; imagelist2folder()
    # Copies all the images in a text file (list of images) into a folder
    create_folder(path[:-4])
    with open(path, 'r') as f:
        for line in f.read().splitlines():
            os.system('cp "%s" %s' % (line, path[:-4]))
            print(line)


def create_folder(path='./new'):
    # Create folder
    if os.path.exists(path):
        shutil.rmtree(path)  # delete output folder
    os.makedirs(path)  # make new output folder