# datasets.py
import glob
import math
import os
import random
import shutil
import time
from pathlib import Path
from threading import Thread

import cv2
import numpy as np
import torch
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm

from utils.utils import xyxy2xywh, xywh2xyxy

help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.dng']
vid_formats = ['.mov', '.avi', '.mp4']

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break


def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    try:
        rotation = dict(img._getexif().items())[orientation]
        if rotation == 6:  # rotation 270
            s = (s[1], s[0])
        elif rotation == 8:  # rotation 90
            s = (s[1], s[0])
    except:
        pass
    return s
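# Example (a sketch; 'photo.jpg' is a hypothetical file): EXIF orientations 6 and 8
# mean the camera was rotated 90 degrees, so the stored (width, height) is swapped:
#   img = Image.open('photo.jpg')  # img.size -> e.g. (4032, 3024)
#   w, h = exif_size(img)          # -> (3024, 4032) if the orientation tag is 6 or 8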
class LoadImages:  # for inference
    def __init__(self, path, img_size=416):
        path = str(Path(path))  # os-agnostic
        files = []
        if os.path.isdir(path):
            files = sorted(glob.glob(os.path.join(path, '*.*')))
        elif os.path.isfile(path):
            files = [path]

        images = [x for x in files if os.path.splitext(x)[-1].lower() in img_formats]
        videos = [x for x in files if os.path.splitext(x)[-1].lower() in vid_formats]
        nI, nV = len(images), len(videos)

        self.img_size = img_size
        self.files = images + videos
        self.nF = nI + nV  # number of files
        self.video_flag = [False] * nI + [True] * nV
        self.mode = 'images'
        if any(videos):
            self.new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nF > 0, 'No images or videos found in ' + path

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nF:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            if not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nF:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()

            self.frame += 1
            print('video %g/%g (%g/%g) %s: ' % (self.count + 1, self.nF, self.frame, self.nframes, path), end='')

        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path
            print('image %g/%g %s: ' % (self.count, self.nF, path), end='')

        # Padded resize
        img = letterbox(img0, new_shape=self.img_size)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        # cv2.imwrite(path + '.letterbox.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1])  # save letterbox image
        return path, img, img0, self.cap

    def new_video(self, path):
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nF  # number of files
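# Example usage (a sketch; 'data/samples' is a hypothetical folder of images and/or videos):
#   dataset = LoadImages('data/samples', img_size=416)
#   for path, img, img0, cap in dataset:
#       pass  # img: letterboxed 3xHxW RGB array, img0: original BGR image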
class LoadWebcam:  # for inference
    def __init__(self, pipe=0, img_size=416):
        self.img_size = img_size

        if pipe == '0':
            pipe = 0  # local camera
        # pipe = 'rtsp://192.168.1.64/1'  # IP camera
        # pipe = 'rtsp://username:password@192.168.1.64/1'  # IP camera with login
        # pipe = 'rtsp://170.93.143.139/rtplive/470011e600ef003a004ee33696235daa'  # IP traffic camera
        # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg'  # IP golf camera

        # https://answers.opencv.org/question/215996/changing-gstreamer-pipeline-to-opencv-in-pythonsolved/
        # pipe = '"rtspsrc location="rtsp://username:password@192.168.1.64/1" latency=10 ! appsink'  # GStreamer

        # https://answers.opencv.org/question/200787/video-acceleration-gstremer-pipeline-in-videocapture/
        # https://stackoverflow.com/questions/54095699/install-gstreamer-support-for-opencv-python-package  # install help
        # pipe = "rtspsrc location=rtsp://root:root@192.168.0.91:554/axis-media/media.amp?videocodec=h264&resolution=3840x2160 protocols=GST_RTSP_LOWER_TRANS_TCP ! rtph264depay ! queue ! vaapih264dec ! videoconvert ! appsink"  # GStreamer

        self.pipe = pipe
        self.cap = cv2.VideoCapture(pipe)  # video capture object
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)  # set buffer size

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if cv2.waitKey(1) == ord('q'):  # q to quit
            self.cap.release()
            cv2.destroyAllWindows()
            raise StopIteration

        # Read frame
        if self.pipe == 0:  # local camera
            ret_val, img0 = self.cap.read()
            img0 = cv2.flip(img0, 1)  # flip left-right
        else:  # IP camera
            n = 0
            while True:
                n += 1
                self.cap.grab()
                if n % 30 == 0:  # skip frames
                    ret_val, img0 = self.cap.retrieve()
                    if ret_val:
                        break

        # Print
        assert ret_val, 'Camera Error %s' % self.pipe
        img_path = 'webcam.jpg'
        print('webcam %g: ' % self.count, end='')

        # Padded resize
        img = letterbox(img0, new_shape=self.img_size)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)
        return img_path, img, img0, None

    def __len__(self):
        return 0
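# Example usage (a sketch): pipe=0 opens the default local camera; an RTSP URL opens an IP camera.
#   dataset = LoadWebcam(pipe=0, img_size=416)
#   for img_path, img, img0, _ in dataset:
#       break  # press 'q' in an OpenCV window to stop iteration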
class LoadStreams:  # multiple IP or RTSP cameras
    def __init__(self, sources='streams.txt', img_size=416):
        self.mode = 'images'
        self.img_size = img_size

        if os.path.isfile(sources):
            with open(sources, 'r') as f:
                sources = [x.strip() for x in f.read().splitlines() if len(x.strip())]
        else:
            sources = [sources]

        n = len(sources)
        self.imgs = [None] * n
        self.sources = sources
        for i, s in enumerate(sources):
            # Start the thread to read frames from the video stream
            print('%g/%g: %s... ' % (i + 1, n, s), end='')
            cap = cv2.VideoCapture(0 if s == '0' else s)
            assert cap.isOpened(), 'Failed to open %s' % s
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS) % 100
            _, self.imgs[i] = cap.read()  # guarantee first frame
            thread = Thread(target=self.update, args=([i, cap]), daemon=True)
            print(' success (%gx%g at %.2f FPS).' % (w, h, fps))
            thread.start()
        print('')  # newline

        # check for common shapes
        s = np.stack([letterbox(x, new_shape=self.img_size)[0].shape for x in self.imgs], 0)  # inference shapes
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        if not self.rect:
            print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')

    def update(self, index, cap):
        # Read next stream frame in a daemon thread
        n = 0
        while cap.isOpened():
            n += 1
            # _, self.imgs[index] = cap.read()
            cap.grab()
            if n == 4:  # read every 4th frame
                _, self.imgs[index] = cap.retrieve()
                n = 0
            time.sleep(0.01)  # wait time

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        img0 = self.imgs.copy()
        if cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        # Letterbox
        img = [letterbox(x, new_shape=self.img_size, auto=self.rect)[0] for x in img0]

        # Stack
        img = np.stack(img, 0)

        # Convert
        img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, to bsx3x416x416
        img = np.ascontiguousarray(img)

        return self.sources, img, img0, None

    def __len__(self):
        return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years
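# Example usage (a sketch; 'streams.txt' is a hypothetical file with one stream URL per line):
#   dataset = LoadStreams('streams.txt', img_size=416)
#   for sources, img, img0, _ in dataset:
#       pass  # img: (n_streams, 3, H, W) batch, img0: list of original BGR frames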
class LoadImagesAndLabels(Dataset):  # for training/testing
    def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, pad=0.0):
        try:
            path = str(Path(path))  # os-agnostic
            parent = str(Path(path).parent) + os.sep
            if os.path.isfile(path):  # file
                with open(path, 'r') as f:
                    f = f.read().splitlines()
                f = [x.replace('./', parent) if x.startswith('./') else x for x in f]  # local to global path
            elif os.path.isdir(path):  # folder
                f = glob.iglob(path + os.sep + '*.*')
            else:
                raise Exception('%s does not exist' % path)
            self.img_files = [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats]
        except:
            raise Exception('Error loading data from %s. See %s' % (path, help_url))

        n = len(self.img_files)
        assert n > 0, 'No images found in %s. See %s' % (path, help_url)
        bi = np.floor(np.arange(n) / batch_size).astype(np.int)  # batch index
        nb = bi[-1] + 1  # number of batches

        self.n = n  # number of images
        self.batch = bi  # batch index of image
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)

        # Define labels
        self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt')
                            for x in self.img_files]

        # Rectangular Training  https://github.com/ultralytics/yolov3/issues/232
        if self.rect:
            # Read image shapes (wh)
            sp = path.replace('.txt', '') + '.shapes'  # shapefile path
            try:
                with open(sp, 'r') as f:  # read existing shapefile
                    s = [x.split() for x in f.read().splitlines()]
                    assert len(s) == n, 'Shapefile out of sync'
            except:
                s = [exif_size(Image.open(f)) for f in tqdm(self.img_files, desc='Reading image shapes')]
                np.savetxt(sp, s, fmt='%g')  # overwrites existing (if any)

            # Sort by aspect ratio
            s = np.array(s, dtype=np.float64)
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / 32. + pad).astype(np.int) * 32

        # Cache labels
        self.imgs = [None] * n
        self.labels = [np.zeros((0, 5), dtype=np.float32)] * n
        create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False
        nm, nf, ne, ns, nd = 0, 0, 0, 0, 0  # number missing, found, empty, datasubset, duplicate
        np_labels_path = str(Path(self.label_files[0]).parent) + '.npy'  # saved labels in *.npy file
        if os.path.isfile(np_labels_path):
            s = np_labels_path  # print string
            x = np.load(np_labels_path, allow_pickle=True)
            if len(x) == n:
                self.labels = x
                labels_loaded = True
        else:
            s = path.replace('images', 'labels')

        pbar = tqdm(self.label_files)
        for i, file in enumerate(pbar):
            if labels_loaded:
                l = self.labels[i]
                # np.savetxt(file, l, '%g')  # save *.txt from *.npy file
            else:
                try:
                    with open(file, 'r') as f:
                        l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
                except:
                    nm += 1  # print('missing labels for image %s' % self.img_files[i])  # file missing
                    continue

            if l.shape[0]:
                assert l.shape[1] == 5, '> 5 label columns: %s' % file
                assert (l >= 0).all(), 'negative labels: %s' % file
                assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels: %s' % file
                if np.unique(l, axis=0).shape[0] < l.shape[0]:  # duplicate rows
                    nd += 1  # print('WARNING: duplicate rows in %s' % self.label_files[i])  # duplicate rows
                if single_cls:
                    l[:, 0] = 0  # force dataset into single-class mode
                self.labels[i] = l
                nf += 1  # file found

                # Create subdataset (a smaller dataset)
                if create_datasubset and ns < 1E4:
                    if ns == 0:
                        create_folder(path='./datasubset')
                        os.makedirs('./datasubset/images')
                    exclude_classes = 43
                    if exclude_classes not in l[:, 0]:
                        ns += 1
                        # shutil.copy(src=self.img_files[i], dst='./datasubset/images/')  # copy image
                        with open('./datasubset/images.txt', 'a') as f:
                            f.write(self.img_files[i] + '\n')

                # Extract object detection boxes for a second stage classifier
                if extract_bounding_boxes:
                    p = Path(self.img_files[i])
                    img = cv2.imread(str(p))
                    h, w = img.shape[:2]
                    for j, x in enumerate(l):
                        f = '%s%sclassifier%s%g_%g_%s' % (p.parent.parent, os.sep, os.sep, x[0], j, p.name)
                        if not os.path.exists(Path(f).parent):
                            os.makedirs(Path(f).parent)  # make new output folder

                        b = x[1:] * [w, h, w, h]  # box
                        b[2:] = b[2:].max()  # rectangle to square
                        b[2:] = b[2:] * 1.3 + 30  # pad
                        b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(np.int)

                        b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                        b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                        assert cv2.imwrite(f, img[b[1]:b[3], b[0]:b[2]]), 'Failure extracting classifier boxes'
            else:
                ne += 1  # print('empty labels for image %s' % self.img_files[i])  # file empty
                # os.system("rm '%s' '%s'" % (self.img_files[i], self.label_files[i]))  # remove

            pbar.desc = 'Caching labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
                s, nf, nm, ne, nd, n)
        assert nf > 0 or n == 20288, 'No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
        if not labels_loaded and n > 1000:
            print('Saving labels to %s for faster future loading' % np_labels_path)
            np.save(np_labels_path, self.labels)  # save for next time

        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        if cache_images:  # if training
            gb = 0  # Gigabytes of cached images
            pbar = tqdm(range(len(self.img_files)), desc='Caching images')
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            for i in pbar:  # max 10k images
                self.imgs[i], self.img_hw0[i], self.img_hw[i] = load_image(self, i)  # img, hw_original, hw_resized
                gb += self.imgs[i].nbytes
                pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9)

        # Detect corrupted images https://medium.com/joelthchao/programmatically-detect-corrupted-image-8c1b2006c3d3
        detect_corrupted_images = False
        if detect_corrupted_images:
            from skimage import io  # conda install -c conda-forge scikit-image
            for file in tqdm(self.img_files, desc='Detecting corrupted images'):
                try:
                    _ = io.imread(file)
                except:
                    print('Corrupted image detected: %s' % file)

    def __len__(self):
        return len(self.img_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     # self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        if self.image_weights:
            index = self.indices[index]

        hyp = self.hyp
        if self.mosaic:
            # Load mosaic
            img, labels = load_mosaic(self, index)
            shapes = None

        else:
            # Load image
            img, (h0, w0), (h, w) = load_image(self, index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            # Load labels
            labels = []
            x = self.labels[index]
            if x.size > 0:
                # Normalized xywh to pixel xyxy format
                labels = x.copy()
                labels[:, 1] = ratio[0] * w * (x[:, 1] - x[:, 3] / 2) + pad[0]  # pad width
                labels[:, 2] = ratio[1] * h * (x[:, 2] - x[:, 4] / 2) + pad[1]  # pad height
                labels[:, 3] = ratio[0] * w * (x[:, 1] + x[:, 3] / 2) + pad[0]
                labels[:, 4] = ratio[1] * h * (x[:, 2] + x[:, 4] / 2) + pad[1]

        if self.augment:
            # Augment imagespace
            if not self.mosaic:
                img, labels = random_affine(img, labels,
                                            degrees=hyp['degrees'],
                                            translate=hyp['translate'],
                                            scale=hyp['scale'],
                                            shear=hyp['shear'])

            # Augment colorspace
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Apply cutouts
            # if random.random() < 0.9:
            #     labels = cutout(img, labels)

        nL = len(labels)  # number of labels
        if nL:
            # convert xyxy to xywh
            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])

            # Normalize coordinates 0 - 1
            labels[:, [2, 4]] /= img.shape[0]  # height
            labels[:, [1, 3]] /= img.shape[1]  # width

        if self.augment:
            # random left-right flip
            lr_flip = True
            if lr_flip and random.random() < 0.5:
                img = np.fliplr(img)
                if nL:
                    labels[:, 1] = 1 - labels[:, 1]

            # random up-down flip
            ud_flip = False
            if ud_flip and random.random() < 0.5:
                img = np.flipud(img)
                if nL:
                    labels[:, 2] = 1 - labels[:, 2]

        labels_out = torch.zeros((nL, 6))
        if nL:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.img_files[index], shapes

    @staticmethod
    def collate_fn(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes
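# Example usage (a sketch; the path and hyp values below are hypothetical):
#   hyp = {'degrees': 0, 'translate': 0, 'scale': 0, 'shear': 0,
#          'hsv_h': 0.014, 'hsv_s': 0.68, 'hsv_v': 0.36}
#   dataset = LoadImagesAndLabels('data/train.txt', img_size=416, batch_size=16, augment=True, hyp=hyp)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=16,
#                                        collate_fn=LoadImagesAndLabels.collate_fn)
#   for imgs, targets, paths, shapes in loader:
#       pass  # targets: (n, 6) rows of [batch_index, class, x, y, w, h] (normalized xywh)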
def load_image(self, index):
    # loads 1 image from dataset, returns img, original hw, resized hw
    img = self.imgs[index]
    if img is None:  # not cached
        path = self.img_files[index]
        img = cv2.imread(path)  # BGR
        assert img is not None, 'Image Not Found ' + path
        h0, w0 = img.shape[:2]  # orig hw
        r = self.img_size / max(h0, w0)  # resize image to img_size
        if r != 1:  # always resize down, only resize up if training with augmentation
            interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
        return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
    else:
        return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized


def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
    r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
    hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    dtype = img.dtype  # uint8

    x = np.arange(0, 256, dtype=np.int16)
    lut_hue = ((x * r[0]) % 180).astype(dtype)
    lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
    lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
    cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed

    # Histogram equalization
    # if random.random() < 0.2:
    #     for i in range(3):
    #         img[:, :, i] = cv2.equalizeHist(img[:, :, i])
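# Example (a sketch; the path and gain values are illustrative): augment_hsv modifies img in place.
#   img = cv2.imread('photo.jpg')  # hypothetical path
#   augment_hsv(img, hgain=0.014, sgain=0.68, vgain=0.36)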
def load_mosaic(self, index):
    # loads images in a mosaic

    labels4 = []
    s = self.img_size
    xc, yc = [int(random.uniform(s * 0.5, s * 1.5)) for _ in range(2)]  # mosaic center x, y
    indices = [index] + [random.randint(0, len(self.labels) - 1) for _ in range(3)]  # 3 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, max(xc, w), min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        x = self.labels[index]
        labels = x.copy()
        if x.size > 0:  # Normalized xywh to pixel xyxy format
            labels[:, 1] = w * (x[:, 1] - x[:, 3] / 2) + padw
            labels[:, 2] = h * (x[:, 2] - x[:, 4] / 2) + padh
            labels[:, 3] = w * (x[:, 1] + x[:, 3] / 2) + padw
            labels[:, 4] = h * (x[:, 2] + x[:, 4] / 2) + padh
        labels4.append(labels)

    # Concat/clip labels
    if len(labels4):
        labels4 = np.concatenate(labels4, 0)
        # np.clip(labels4[:, 1:] - s / 2, 0, s, out=labels4[:, 1:])  # use with center crop
        np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:])  # use with random_affine

    # Augment
    # img4 = img4[s // 2: int(s * 1.5), s // 2:int(s * 1.5)]  # center crop (WARNING, requires box pruning)
    img4, labels4 = random_affine(img4, labels4,
                                  degrees=self.hyp['degrees'],
                                  translate=self.hyp['translate'],
                                  scale=self.hyp['scale'],
                                  shear=self.hyp['shear'],
                                  border=-s // 2)  # border to remove

    return img4, labels4
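# Example (a sketch): one image is chosen by index, three more at random; all four are
# tiled around a random center in a 2s x 2s canvas, then random_affine crops back to s x s.
#   img4, labels4 = load_mosaic(dataset, index=0)  # dataset: a LoadImagesAndLabels instance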
def letterbox(img, new_shape=(416, 416), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
    # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, 64), np.mod(dh, 64)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = new_shape
        ratio = new_shape[0] / shape[1], new_shape[1] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
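# Worked example (a sketch): a 1080x1920 (h x w) frame with new_shape=416 and auto=True:
#   r = min(416 / 1080, 416 / 1920) = 0.2167 -> new_unpad = (416, 234)  (w, h)
#   dw = 0, dh = 416 - 234 = 182 -> np.mod(182, 64) = 54 -> 27 px top + 27 px bottom,
#   giving a 416x288 output (288 = 234 + 54, a multiple of 32).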
def random_affine(img, targets=(), degrees=10, translate=.1, scale=.1, shear=10, border=0):
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
    # targets = [cls, xyxy]

    height = img.shape[0] + border * 2
    width = img.shape[1] + border * 2

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(img.shape[1] / 2, img.shape[0] / 2), scale=s)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(-translate, translate) * img.shape[0] + border  # x translation (pixels)
    T[1, 2] = random.uniform(-translate, translate) * img.shape[1] + border  # y translation (pixels)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Combined rotation matrix
    M = S @ T @ R  # ORDER IS IMPORTANT HERE!!
    if (border != 0) or (M != np.eye(3)).any():  # image changed
        img = cv2.warpAffine(img, M[:2], dsize=(width, height), flags=cv2.INTER_LINEAR, borderValue=(114, 114, 114))

    # Transform label coordinates
    n = len(targets)
    if n:
        # warp points
        xy = np.ones((n * 4, 3))
        xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
        xy = (xy @ M.T)[:, :2].reshape(n, 8)

        # create new boxes
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

        # # apply angle-based reduction of bounding boxes
        # radians = a * math.pi / 180
        # reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
        # x = (xy[:, 2] + xy[:, 0]) / 2
        # y = (xy[:, 3] + xy[:, 1]) / 2
        # w = (xy[:, 2] - xy[:, 0]) * reduction
        # h = (xy[:, 3] - xy[:, 1]) * reduction
        # xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T

        # reject warped points outside of image
        xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
        xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
        w = xy[:, 2] - xy[:, 0]
        h = xy[:, 3] - xy[:, 1]
        area = w * h
        area0 = (targets[:, 3] - targets[:, 1]) * (targets[:, 4] - targets[:, 2])
        ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16))  # aspect ratio
        i = (w > 4) & (h > 4) & (area / (area0 * s + 1e-16) > 0.2) & (ar < 10)

        targets = targets[i]
        targets[:, 1:5] = xy[i]

    return img, targets
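# Example (a sketch; the image path and box are hypothetical):
#   img = cv2.imread('photo.jpg')
#   targets = np.array([[0, 50., 60., 200., 220.]], dtype=np.float32)  # [cls, x1, y1, x2, y2]
#   img, targets = random_affine(img, targets, degrees=5, translate=0.1, scale=0.1, shear=5)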
def cutout(image, labels):
    # https://arxiv.org/abs/1708.04552
    # https://github.com/hysts/pytorch_cutout/blob/master/dataloader.py
    # https://towardsdatascience.com/when-conventional-wisdom-fails-revisiting-data-augmentation-for-self-driving-cars-4831998c5509
    h, w = image.shape[:2]

    def bbox_ioa(box1, box2):
        # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
        box2 = box2.transpose()

        # Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]

        # Intersection area
        inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
                     (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)

        # box2 area
        box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16

        # Intersection over box2 area
        return inter_area / box2_area

    # create random masks
    scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
    for s in scales:
        mask_h = random.randint(1, int(h * s))
        mask_w = random.randint(1, int(w * s))

        # box
        xmin = max(0, random.randint(0, w) - mask_w // 2)
        ymin = max(0, random.randint(0, h) - mask_h // 2)
        xmax = min(w, xmin + mask_w)
        ymax = min(h, ymin + mask_h)

        # apply random color mask
        image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]

        # return unobscured labels
        if len(labels) and s > 0.03:
            box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            labels = labels[ioa < 0.60]  # remove >60% obscured labels

    return labels
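# Example (a sketch; path and box are hypothetical): cutout paints random gray
# rectangles onto the image in place and drops labels that become >60% obscured.
#   img = cv2.imread('photo.jpg')
#   labels = np.array([[0, 30., 40., 300., 400.]], dtype=np.float32)  # [cls, x1, y1, x2, y2]
#   labels = cutout(img, labels)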
def reduce_img_size(path='../data/sm4/images', img_size=1024):  # from utils.datasets import *; reduce_img_size()
    # creates a new ./images_reduced folder with reduced size images of maximum size img_size
    path_new = path + '_reduced'  # reduced images path
    create_folder(path_new)
    for f in tqdm(glob.glob('%s/*.*' % path)):
        try:
            img = cv2.imread(f)
            h, w = img.shape[:2]
            r = img_size / max(h, w)  # size ratio
            if r < 1.0:
                img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_AREA)  # _LINEAR fastest
            fnew = f.replace(path, path_new)  # .replace(Path(f).suffix, '.jpg')
            cv2.imwrite(fnew, img)
        except:
            print('WARNING: image failure %s' % f)


def convert_images2bmp():  # from utils.datasets import *; convert_images2bmp()
    # Save images
    formats = [x.lower() for x in img_formats] + [x.upper() for x in img_formats]
    # for path in ['../coco/images/val2014', '../coco/images/train2014']:
    for path in ['../data/sm4/images', '../data/sm4/background']:
        create_folder(path + 'bmp')
        for ext in formats:  # ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.dng']
            for f in tqdm(glob.glob('%s/*%s' % (path, ext)), desc='Converting %s' % ext):
                cv2.imwrite(f.replace(ext.lower(), '.bmp').replace(path, path + 'bmp'), cv2.imread(f))

    # Save labels
    # for path in ['../coco/trainvalno5k.txt', '../coco/5k.txt']:
    for file in ['../data/sm4/out_train.txt', '../data/sm4/out_test.txt']:
        with open(file, 'r') as f:
            lines = f.read()
            # lines = f.read().replace('2014/', '2014bmp/')  # coco
            lines = lines.replace('/images', '/imagesbmp')
            lines = lines.replace('/background', '/backgroundbmp')
            for ext in formats:
                lines = lines.replace(ext, '.bmp')
        with open(file.replace('.txt', 'bmp.txt'), 'w') as f:
            f.write(lines)


def recursive_dataset2bmp(dataset='../data/sm4_bmp'):  # from utils.datasets import *; recursive_dataset2bmp()
    # Converts dataset to bmp (for faster training)
    formats = [x.lower() for x in img_formats] + [x.upper() for x in img_formats]
    for a, b, files in os.walk(dataset):
        for file in tqdm(files, desc=a):
            p = a + '/' + file
            s = Path(file).suffix
            if s == '.txt':  # replace text
                with open(p, 'r') as f:
                    lines = f.read()
                for f in formats:
                    lines = lines.replace(f, '.bmp')
                with open(p, 'w') as f:
                    f.write(lines)
            elif s in formats:  # replace image
                cv2.imwrite(p.replace(s, '.bmp'), cv2.imread(p))
                if s != '.bmp':
                    os.system("rm '%s'" % p)


def imagelist2folder(path='data/coco_64img.txt'):  # from utils.datasets import *; imagelist2folder()
    # Copies all the images in a text file (list of images) into a folder
    create_folder(path[:-4])
    with open(path, 'r') as f:
        for line in f.read().splitlines():
            os.system('cp "%s" %s' % (line, path[:-4]))
            print(line)


def create_folder(path='./new_folder'):
    # Create folder
    if os.path.exists(path):
        shutil.rmtree(path)  # delete output folder
    os.makedirs(path)  # make new output folder