
datasets.py

import glob
import math
import os
import random
import shutil
import time
from pathlib import Path
from threading import Thread

import cv2
import numpy as np
import torch
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm

from utils.utils import xyxy2xywh, xywh2xyxy

help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.dng']
vid_formats = ['.mov', '.avi', '.mp4', '.mpg', '.mpeg', '.m4v', '.wmv', '.mkv']

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break


def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    try:
        rotation = dict(img._getexif().items())[orientation]
        if rotation == 6:  # rotation 270
            s = (s[1], s[0])
        elif rotation == 8:  # rotation 90
            s = (s[1], s[0])
    except:
        pass

    return s
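
# Usage sketch (illustrative, not part of the original file): exif_size swaps
# width/height for EXIF orientation tags 6 and 8, so a portrait phone photo is
# reported in its displayed shape. 'photo.jpg' is a hypothetical path.
#
#   im = Image.open('photo.jpg')
#   print(im.size, exif_size(im))  # e.g. (4032, 3024) vs (3024, 4032)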


class LoadImages:  # for inference
    def __init__(self, path, img_size=416):
        path = str(Path(path))  # os-agnostic
        files = []
        if os.path.isdir(path):
            files = sorted(glob.glob(os.path.join(path, '*.*')))
        elif os.path.isfile(path):
            files = [path]

        images = [x for x in files if os.path.splitext(x)[-1].lower() in img_formats]
        videos = [x for x in files if os.path.splitext(x)[-1].lower() in vid_formats]
        nI, nV = len(images), len(videos)

        self.img_size = img_size
        self.files = images + videos
        self.nF = nI + nV  # number of files
        self.video_flag = [False] * nI + [True] * nV
        self.mode = 'images'
        if any(videos):
            self.new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nF > 0, 'No images or videos found in %s. Supported formats are:\nimages: %s\nvideos: %s' % \
                            (path, img_formats, vid_formats)

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nF:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            if not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nF:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()

            self.frame += 1
            print('video %g/%g (%g/%g) %s: ' % (self.count + 1, self.nF, self.frame, self.nframes, path), end='')

        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path
            print('image %g/%g %s: ' % (self.count, self.nF, path), end='')

        # Padded resize
        img = letterbox(img0, new_shape=self.img_size)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        # cv2.imwrite(path + '.letterbox.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1])  # save letterbox image
        return path, img, img0, self.cap

    def new_video(self, path):
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nF  # number of files
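
# Usage sketch (illustrative, not part of the original file): LoadImages yields
# (path, img, img0, cap) per file, where img is the letterboxed CHW RGB array
# and img0 the original BGR frame. 'data/samples' is a hypothetical folder of
# images and/or videos.
#
#   dataset = LoadImages('data/samples', img_size=416)
#   for path, img, img0, vid_cap in dataset:
#       x = torch.from_numpy(img).float() / 255.0  # to 0.0-1.0 before inference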


class LoadWebcam:  # for inference
    def __init__(self, pipe=0, img_size=416):
        self.img_size = img_size

        if pipe == '0':
            pipe = 0  # local camera
        # pipe = 'rtsp://192.168.1.64/1'  # IP camera
        # pipe = 'rtsp://username:password@192.168.1.64/1'  # IP camera with login
        # pipe = 'rtsp://170.93.143.139/rtplive/470011e600ef003a004ee33696235daa'  # IP traffic camera
        # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg'  # IP golf camera

        # https://answers.opencv.org/question/215996/changing-gstreamer-pipeline-to-opencv-in-pythonsolved/
        # pipe = '"rtspsrc location="rtsp://username:password@192.168.1.64/1" latency=10 ! appsink'  # GStreamer

        # https://answers.opencv.org/question/200787/video-acceleration-gstremer-pipeline-in-videocapture/
        # https://stackoverflow.com/questions/54095699/install-gstreamer-support-for-opencv-python-package  # install help
        # pipe = "rtspsrc location=rtsp://root:root@192.168.0.91:554/axis-media/media.amp?videocodec=h264&resolution=3840x2160 protocols=GST_RTSP_LOWER_TRANS_TCP ! rtph264depay ! queue ! vaapih264dec ! videoconvert ! appsink"  # GStreamer

        self.pipe = pipe
        self.cap = cv2.VideoCapture(pipe)  # video capture object
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)  # set buffer size

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if cv2.waitKey(1) == ord('q'):  # q to quit
            self.cap.release()
            cv2.destroyAllWindows()
            raise StopIteration

        # Read frame
        if self.pipe == 0:  # local camera
            ret_val, img0 = self.cap.read()
            img0 = cv2.flip(img0, 1)  # flip left-right
        else:  # IP camera
            n = 0
            while True:
                n += 1
                self.cap.grab()
                if n % 30 == 0:  # skip frames
                    ret_val, img0 = self.cap.retrieve()
                    if ret_val:
                        break

        # Print
        assert ret_val, 'Camera Error %s' % self.pipe
        img_path = 'webcam.jpg'
        print('webcam %g: ' % self.count, end='')

        # Padded resize
        img = letterbox(img0, new_shape=self.img_size)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        return img_path, img, img0, None

    def __len__(self):
        return 0
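
# Usage sketch (illustrative, not part of the original file): pipe=0 opens the
# default local camera with frames mirrored left-right; any other string is
# treated as an IP-camera pipe and read on every 30th grab to stay near real
# time.
#
#   for img_path, img, img0, _ in LoadWebcam(pipe=0, img_size=416):
#       cv2.imshow('webcam', img0)  # press q to stop (handled in __next__)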


class LoadStreams:  # multiple IP or RTSP cameras
    def __init__(self, sources='streams.txt', img_size=416):
        self.mode = 'images'
        self.img_size = img_size

        if os.path.isfile(sources):
            with open(sources, 'r') as f:
                sources = [x.strip() for x in f.read().splitlines() if len(x.strip())]
        else:
            sources = [sources]

        n = len(sources)
        self.imgs = [None] * n
        self.sources = sources
        for i, s in enumerate(sources):
            # Start the thread to read frames from the video stream
            print('%g/%g: %s... ' % (i + 1, n, s), end='')
            cap = cv2.VideoCapture(0 if s == '0' else s)
            assert cap.isOpened(), 'Failed to open %s' % s
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS) % 100
            _, self.imgs[i] = cap.read()  # guarantee first frame
            thread = Thread(target=self.update, args=([i, cap]), daemon=True)
            print(' success (%gx%g at %.2f FPS).' % (w, h, fps))
            thread.start()
        print('')  # newline

        # check for common shapes
        s = np.stack([letterbox(x, new_shape=self.img_size)[0].shape for x in self.imgs], 0)  # inference shapes
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        if not self.rect:
            print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')

    def update(self, index, cap):
        # Read next stream frame in a daemon thread
        n = 0
        while cap.isOpened():
            n += 1
            # _, self.imgs[index] = cap.read()
            cap.grab()
            if n == 4:  # read every 4th frame
                _, self.imgs[index] = cap.retrieve()
                n = 0
            time.sleep(0.01)  # wait time

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        img0 = self.imgs.copy()
        if cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        # Letterbox
        img = [letterbox(x, new_shape=self.img_size, auto=self.rect)[0] for x in img0]

        # Stack
        img = np.stack(img, 0)

        # Convert
        img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, to bsx3x416x416
        img = np.ascontiguousarray(img)

        return self.sources, img, img0, None

    def __len__(self):
        return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years
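
# Usage sketch (illustrative, not part of the original file): each line of
# streams.txt is one source ('0' for the local camera, or an RTSP/HTTP URL),
# so a batch of B streams yields img of shape (B, 3, H, W) while the daemon
# threads keep self.imgs fresh in the background.
#
#   for sources, img, img0, _ in LoadStreams('streams.txt', img_size=416):
#       pass  # img0 is the list of B latest original BGR frames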


class LoadImagesAndLabels(Dataset):  # for training/testing
    def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0):
        try:
            path = str(Path(path))  # os-agnostic
            parent = str(Path(path).parent) + os.sep
            if os.path.isfile(path):  # file
                with open(path, 'r') as f:
                    f = f.read().splitlines()
                    f = [x.replace('./', parent) if x.startswith('./') else x for x in f]  # local to global path
            elif os.path.isdir(path):  # folder
                f = glob.iglob(path + os.sep + '*.*')
            else:
                raise Exception('%s does not exist' % path)
            self.img_files = [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats]
        except:
            raise Exception('Error loading data from %s. See %s' % (path, help_url))

        n = len(self.img_files)
        assert n > 0, 'No images found in %s. See %s' % (path, help_url)
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index (np.int removed in newer NumPy)
        nb = bi[-1] + 1  # number of batches

        self.n = n  # number of images
        self.batch = bi  # batch index of image
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)

        # Define labels
        self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt')
                            for x in self.img_files]

        # Read image shapes (wh)
        sp = path.replace('.txt', '') + '.shapes'  # shapefile path
        try:
            with open(sp, 'r') as f:  # read existing shapefile
                s = [x.split() for x in f.read().splitlines()]
                assert len(s) == n, 'Shapefile out of sync'
        except:
            s = [exif_size(Image.open(f)) for f in tqdm(self.img_files, desc='Reading image shapes')]
            np.savetxt(sp, s, fmt='%g')  # overwrites existing (if any)

        self.shapes = np.array(s, dtype=np.float64)

        # Rectangular Training  https://github.com/ultralytics/yolov3/issues/232
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

        # Cache labels
        self.imgs = [None] * n
        self.labels = [np.zeros((0, 5), dtype=np.float32)] * n
        create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False
        nm, nf, ne, ns, nd = 0, 0, 0, 0, 0  # number missing, found, empty, datasubset, duplicate
        np_labels_path = str(Path(self.label_files[0]).parent) + '.npy'  # saved labels in *.npy file
        if os.path.isfile(np_labels_path):
            s = np_labels_path  # print string
            x = np.load(np_labels_path, allow_pickle=True)
            if len(x) == n:
                self.labels = x
                labels_loaded = True
        else:
            s = path.replace('images', 'labels')

        pbar = tqdm(self.label_files)
        for i, file in enumerate(pbar):
            if labels_loaded:
                l = self.labels[i]
                # np.savetxt(file, l, '%g')  # save *.txt from *.npy file
            else:
                try:
                    with open(file, 'r') as f:
                        l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
                except:
                    nm += 1  # print('missing labels for image %s' % self.img_files[i])  # file missing
                    continue

            if l.shape[0]:
                assert l.shape[1] == 5, '> 5 label columns: %s' % file
                assert (l >= 0).all(), 'negative labels: %s' % file
                assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels: %s' % file
                if np.unique(l, axis=0).shape[0] < l.shape[0]:  # duplicate rows
                    nd += 1  # print('WARNING: duplicate rows in %s' % self.label_files[i])  # duplicate rows
                if single_cls:
                    l[:, 0] = 0  # force dataset into single-class mode
                self.labels[i] = l
                nf += 1  # file found

                # Create subdataset (a smaller dataset)
                if create_datasubset and ns < 1E4:
                    if ns == 0:
                        create_folder(path='./datasubset')
                        os.makedirs('./datasubset/images')
                    exclude_classes = 43
                    if exclude_classes not in l[:, 0]:
                        ns += 1
                        # shutil.copy(src=self.img_files[i], dst='./datasubset/images/')  # copy image
                        with open('./datasubset/images.txt', 'a') as f:
                            f.write(self.img_files[i] + '\n')

                # Extract object detection boxes for a second stage classifier
                if extract_bounding_boxes:
                    p = Path(self.img_files[i])
                    img = cv2.imread(str(p))
                    h, w = img.shape[:2]
                    for j, x in enumerate(l):
                        f = '%s%sclassifier%s%g_%g_%s' % (p.parent.parent, os.sep, os.sep, x[0], j, p.name)
                        if not os.path.exists(Path(f).parent):
                            os.makedirs(Path(f).parent)  # make new output folder

                        b = x[1:] * [w, h, w, h]  # box
                        b[2:] = b[2:].max()  # rectangle to square
                        b[2:] = b[2:] * 1.3 + 30  # pad
                        b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)

                        b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                        b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                        assert cv2.imwrite(f, img[b[1]:b[3], b[0]:b[2]]), 'Failure extracting classifier boxes'
            else:
                ne += 1  # print('empty labels for image %s' % self.img_files[i])  # file empty
                # os.system("rm '%s' '%s'" % (self.img_files[i], self.label_files[i]))  # remove

            pbar.desc = 'Caching labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
                s, nf, nm, ne, nd, n)
        assert nf > 0 or n == 20288, 'No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
        if not labels_loaded and n > 1000:
            print('Saving labels to %s for faster future loading' % np_labels_path)
            np.save(np_labels_path, self.labels)  # save for next time

        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        if cache_images:  # if training
            gb = 0  # Gigabytes of cached images
            pbar = tqdm(range(len(self.img_files)), desc='Caching images')
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            for i in pbar:  # max 10k images
                self.imgs[i], self.img_hw0[i], self.img_hw[i] = load_image(self, i)  # img, hw_original, hw_resized
                gb += self.imgs[i].nbytes
                pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9)

        # Detect corrupted images https://medium.com/joelthchao/programmatically-detect-corrupted-image-8c1b2006c3d3
        detect_corrupted_images = False
        if detect_corrupted_images:
            from skimage import io  # conda install -c conda-forge scikit-image
            for file in tqdm(self.img_files, desc='Detecting corrupted images'):
                try:
                    _ = io.imread(file)
                except:
                    print('Corrupted image detected: %s' % file)

    def __len__(self):
        return len(self.img_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     # self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        if self.image_weights:
            index = self.indices[index]

        hyp = self.hyp
        if self.mosaic:
            # Load mosaic
            img, labels = load_mosaic(self, index)
            shapes = None

        else:
            # Load image
            img, (h0, w0), (h, w) = load_image(self, index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            # Load labels
            labels = []
            x = self.labels[index]
            if x.size > 0:
                # Normalized xywh to pixel xyxy format
                labels = x.copy()
                labels[:, 1] = ratio[0] * w * (x[:, 1] - x[:, 3] / 2) + pad[0]  # pad width
                labels[:, 2] = ratio[1] * h * (x[:, 2] - x[:, 4] / 2) + pad[1]  # pad height
                labels[:, 3] = ratio[0] * w * (x[:, 1] + x[:, 3] / 2) + pad[0]
                labels[:, 4] = ratio[1] * h * (x[:, 2] + x[:, 4] / 2) + pad[1]

        if self.augment:
            # Augment imagespace
            if not self.mosaic:
                img, labels = random_affine(img, labels,
                                            degrees=hyp['degrees'],
                                            translate=hyp['translate'],
                                            scale=hyp['scale'],
                                            shear=hyp['shear'])

            # Augment colorspace
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Apply cutouts
            # if random.random() < 0.9:
            #     labels = cutout(img, labels)

        nL = len(labels)  # number of labels
        if nL:
            # convert xyxy to xywh
            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])

            # Normalize coordinates 0 - 1
            labels[:, [2, 4]] /= img.shape[0]  # height
            labels[:, [1, 3]] /= img.shape[1]  # width

        if self.augment:
            # random left-right flip
            lr_flip = True
            if lr_flip and random.random() < 0.5:
                img = np.fliplr(img)
                if nL:
                    labels[:, 1] = 1 - labels[:, 1]

            # random up-down flip
            ud_flip = False
            if ud_flip and random.random() < 0.5:
                img = np.flipud(img)
                if nL:
                    labels[:, 2] = 1 - labels[:, 2]

        labels_out = torch.zeros((nL, 6))
        if nL:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.img_files[index], shapes

    @staticmethod
    def collate_fn(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes
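
# Usage sketch (illustrative, not part of the original file): wrap the dataset
# in a torch DataLoader; collate_fn must be passed explicitly because each
# image contributes a variable number of labels. The path and hyp values below
# are hypothetical placeholders.
#
#   from torch.utils.data import DataLoader
#   hyp = {'degrees': 0.0, 'translate': 0.0, 'scale': 0.5, 'shear': 0.0,
#          'hsv_h': 0.014, 'hsv_s': 0.68, 'hsv_v': 0.36}
#   dataset = LoadImagesAndLabels('data/train.txt', img_size=416,
#                                 batch_size=16, augment=True, hyp=hyp)
#   loader = DataLoader(dataset, batch_size=16, shuffle=True,
#                       collate_fn=LoadImagesAndLabels.collate_fn)
#   imgs, targets, paths, shapes = next(iter(loader))
#   # imgs: (16, 3, H, W) uint8; targets: (n, 6) = [batch_index, cls, x, y, w, h]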


def load_image(self, index):
    # loads 1 image from dataset, returns img, original hw, resized hw
    img = self.imgs[index]
    if img is None:  # not cached
        path = self.img_files[index]
        img = cv2.imread(path)  # BGR
        assert img is not None, 'Image Not Found ' + path
        h0, w0 = img.shape[:2]  # orig hw
        r = self.img_size / max(h0, w0)  # resize image to img_size
        if r != 1:  # always resize down, only resize up if training with augmentation
            interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
        return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
    else:
        return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized


def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
    r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
    hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    dtype = img.dtype  # uint8

    x = np.arange(0, 256, dtype=np.int16)
    lut_hue = ((x * r[0]) % 180).astype(dtype)
    lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
    lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
    cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed

    # Histogram equalization
    # if random.random() < 0.2:
    #     for i in range(3):
    #         img[:, :, i] = cv2.equalizeHist(img[:, :, i])
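
# Usage sketch (illustrative, not part of the original file): augment_hsv
# modifies img in place via cv2.cvtColor(..., dst=img), so no assignment is
# needed. The gain values mirror typical hyp settings but are hypothetical.
#
#   img = cv2.imread('train.jpg')  # hypothetical path, BGR uint8
#   augment_hsv(img, hgain=0.014, sgain=0.68, vgain=0.36)  # img is now jittered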


def load_mosaic(self, index):
    # loads images in a mosaic

    labels4 = []
    s = self.img_size
    xc, yc = [int(random.uniform(s * 0.5, s * 1.5)) for _ in range(2)]  # mosaic center x, y
    indices = [index] + [random.randint(0, len(self.labels) - 1) for _ in range(3)]  # 3 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, max(xc, w), min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        x = self.labels[index]
        labels = x.copy()
        if x.size > 0:  # Normalized xywh to pixel xyxy format
            labels[:, 1] = w * (x[:, 1] - x[:, 3] / 2) + padw
            labels[:, 2] = h * (x[:, 2] - x[:, 4] / 2) + padh
            labels[:, 3] = w * (x[:, 1] + x[:, 3] / 2) + padw
            labels[:, 4] = h * (x[:, 2] + x[:, 4] / 2) + padh
        labels4.append(labels)

    # Concat/clip labels
    if len(labels4):
        labels4 = np.concatenate(labels4, 0)
        # np.clip(labels4[:, 1:] - s / 2, 0, s, out=labels4[:, 1:])  # use with center crop
        np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:])  # use with random_affine

    # Augment
    # img4 = img4[s // 2: int(s * 1.5), s // 2:int(s * 1.5)]  # center crop (WARNING, requires box pruning)
    img4, labels4 = random_affine(img4, labels4,
                                  degrees=self.hyp['degrees'],
                                  translate=self.hyp['translate'],
                                  scale=self.hyp['scale'],
                                  shear=self.hyp['shear'],
                                  border=-s // 2)  # border to remove

    return img4, labels4
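
# Note (illustrative, not part of the original file): load_mosaic builds a
# 2s x 2s canvas from 4 images around a random center, then random_affine with
# border=-s // 2 crops it back to s x s, so the returned mosaic matches
# self.img_size. It is called from __getitem__ with a populated dataset, e.g.:
#
#   dataset = LoadImagesAndLabels('data/train.txt', img_size=416,
#                                 augment=True, hyp=hyp)  # hypothetical inputs
#   img4, labels4 = load_mosaic(dataset, 0)  # img4.shape == (416, 416, 3)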


def letterbox(img, new_shape=(416, 416), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
    # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, 64), np.mod(dh, 64)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = new_shape
        ratio = new_shape[0] / shape[1], new_shape[1] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
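
# Usage sketch (illustrative, not part of the original file): letterbox returns
# the resize ratio and the per-side (dw, dh) padding, which is exactly what is
# needed to map detections back to original image coordinates.
#
#   img0 = cv2.imread('test.jpg')  # hypothetical path
#   img, ratio, (dw, dh) = letterbox(img0, new_shape=416, auto=False)
#   # undo for a box x-coordinate in letterboxed space: x0 = (x - dw) / ratio[0]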


def random_affine(img, targets=(), degrees=10, translate=.1, scale=.1, shear=10, border=0):
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
    # targets = [cls, xyxy]

    height = img.shape[0] + border * 2
    width = img.shape[1] + border * 2

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(img.shape[1] / 2, img.shape[0] / 2), scale=s)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(-translate, translate) * img.shape[0] + border  # x translation (pixels)
    T[1, 2] = random.uniform(-translate, translate) * img.shape[1] + border  # y translation (pixels)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Combined rotation matrix
    M = S @ T @ R  # ORDER IS IMPORTANT HERE!!
    if (border != 0) or (M != np.eye(3)).any():  # image changed
        img = cv2.warpAffine(img, M[:2], dsize=(width, height), flags=cv2.INTER_LINEAR, borderValue=(114, 114, 114))

    # Transform label coordinates
    n = len(targets)
    if n:
        # warp points
        xy = np.ones((n * 4, 3))
        xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
        xy = (xy @ M.T)[:, :2].reshape(n, 8)

        # create new boxes
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

        # # apply angle-based reduction of bounding boxes
        # radians = a * math.pi / 180
        # reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
        # x = (xy[:, 2] + xy[:, 0]) / 2
        # y = (xy[:, 3] + xy[:, 1]) / 2
        # w = (xy[:, 2] - xy[:, 0]) * reduction
        # h = (xy[:, 3] - xy[:, 1]) * reduction
        # xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T

        # reject warped points outside of image
        xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
        xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
        w = xy[:, 2] - xy[:, 0]
        h = xy[:, 3] - xy[:, 1]
        area = w * h
        area0 = (targets[:, 3] - targets[:, 1]) * (targets[:, 4] - targets[:, 2])
        ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16))  # aspect ratio
        i = (w > 4) & (h > 4) & (area / (area0 * s + 1e-16) > 0.2) & (ar < 10)

        targets = targets[i]
        targets[:, 1:5] = xy[i]

    return img, targets
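
# Usage sketch (illustrative, not part of the original file): random_affine
# expects targets in pixel [cls, x1, y1, x2, y2] format (as produced in
# __getitem__ / load_mosaic) and drops boxes that end up too small, too
# elongated (ar >= 10) or with under 20% of their scaled area after warping.
#
#   img, targets = random_affine(img, targets, degrees=10, translate=0.1,
#                                scale=0.1, shear=10)  # example magnitudes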


def cutout(image, labels):
    # https://arxiv.org/abs/1708.04552
    # https://github.com/hysts/pytorch_cutout/blob/master/dataloader.py
    # https://towardsdatascience.com/when-conventional-wisdom-fails-revisiting-data-augmentation-for-self-driving-cars-4831998c5509
    h, w = image.shape[:2]

    def bbox_ioa(box1, box2):
        # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
        box2 = box2.transpose()

        # Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]

        # Intersection area
        inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
                     (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)

        # box2 area
        box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16

        # Intersection over box2 area
        return inter_area / box2_area

    # create random masks
    scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
    for s in scales:
        mask_h = random.randint(1, int(h * s))
        mask_w = random.randint(1, int(w * s))

        # box
        xmin = max(0, random.randint(0, w) - mask_w // 2)
        ymin = max(0, random.randint(0, h) - mask_h // 2)
        xmax = min(w, xmin + mask_w)
        ymax = min(h, ymin + mask_h)

        # apply random color mask
        image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]

        # return unobscured labels
        if len(labels) and s > 0.03:
            box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            labels = labels[ioa < 0.60]  # remove >60% obscured labels

    return labels
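
# Usage sketch (illustrative, not part of the original file): cutout paints
# random gray boxes over the image in place and returns only the labels that
# remain less than 60% occluded. It is wired (commented out) into __getitem__:
#
#   if random.random() < 0.9:
#       labels = cutout(img, labels)  # labels: pixel [cls, x1, y1, x2, y2]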


def reduce_img_size(path='../data/sm4/images', img_size=1024):  # from utils.datasets import *; reduce_img_size()
    # creates a new ./images_reduced folder with reduced size images of maximum size img_size
    path_new = path + '_reduced'  # reduced images path
    create_folder(path_new)
    for f in tqdm(glob.glob('%s/*.*' % path)):
        try:
            img = cv2.imread(f)
            h, w = img.shape[:2]
            r = img_size / max(h, w)  # size ratio
            if r < 1.0:
                img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_AREA)  # _LINEAR fastest
            fnew = f.replace(path, path_new)  # .replace(Path(f).suffix, '.jpg')
            cv2.imwrite(fnew, img)
        except:
            print('WARNING: image failure %s' % f)


def convert_images2bmp():  # from utils.datasets import *; convert_images2bmp()
    # Save images
    formats = [x.lower() for x in img_formats] + [x.upper() for x in img_formats]
    # for path in ['../coco/images/val2014', '../coco/images/train2014']:
    for path in ['../data/sm4/images', '../data/sm4/background']:
        create_folder(path + 'bmp')
        for ext in formats:  # ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.dng']
            for f in tqdm(glob.glob('%s/*%s' % (path, ext)), desc='Converting %s' % ext):
                cv2.imwrite(f.replace(ext.lower(), '.bmp').replace(path, path + 'bmp'), cv2.imread(f))

    # Save labels
    # for path in ['../coco/trainvalno5k.txt', '../coco/5k.txt']:
    for file in ['../data/sm4/out_train.txt', '../data/sm4/out_test.txt']:
        with open(file, 'r') as f:
            lines = f.read()
            # lines = f.read().replace('2014/', '2014bmp/')  # coco
        lines = lines.replace('/images', '/imagesbmp')
        lines = lines.replace('/background', '/backgroundbmp')
        for ext in formats:
            lines = lines.replace(ext, '.bmp')
        with open(file.replace('.txt', 'bmp.txt'), 'w') as f:
            f.write(lines)


def recursive_dataset2bmp(dataset='../data/sm4_bmp'):  # from utils.datasets import *; recursive_dataset2bmp()
    # Converts dataset to bmp (for faster training)
    formats = [x.lower() for x in img_formats] + [x.upper() for x in img_formats]
    for a, b, files in os.walk(dataset):
        for file in tqdm(files, desc=a):
            p = a + '/' + file
            s = Path(file).suffix
            if s == '.txt':  # replace text
                with open(p, 'r') as f:
                    lines = f.read()
                for f in formats:
                    lines = lines.replace(f, '.bmp')
                with open(p, 'w') as f:
                    f.write(lines)
            elif s in formats:  # replace image
                cv2.imwrite(p.replace(s, '.bmp'), cv2.imread(p))
                if s != '.bmp':
                    os.system("rm '%s'" % p)


def imagelist2folder(path='data/coco_64img.txt'):  # from utils.datasets import *; imagelist2folder()
    # Copies all the images in a text file (list of images) into a folder
    create_folder(path[:-4])
    with open(path, 'r') as f:
        for line in f.read().splitlines():
            os.system('cp "%s" %s' % (line, path[:-4]))
            print(line)


def create_folder(path='./new_folder'):
    # Create folder
    if os.path.exists(path):
        shutil.rmtree(path)  # delete output folder
    os.makedirs(path)  # make new output folder