import glob
import math
import os
import random
import shutil
import time
from pathlib import Path
from threading import Thread

import cv2
import numpy as np
import torch
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm

from utils.utils import xyxy2xywh, xywh2xyxy

help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.dng']
vid_formats = ['.mov', '.avi', '.mp4', '.mpg', '.mpeg', '.m4v', '.wmv', '.mkv']

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break


def get_hash(files):
    # Returns a single hash value of a list of files
    return sum(os.path.getsize(f) for f in files)


def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    try:
        rotation = dict(img._getexif().items())[orientation]
        if rotation == 6:  # rotation 270
            s = (s[1], s[0])
        elif rotation == 8:  # rotation 90
            s = (s[1], s[0])
    except:
        pass

    return s


def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False):
    dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                  augment=augment,  # augment images
                                  hyp=hyp,  # augmentation hyperparameters
                                  rect=rect,  # rectangular training
                                  cache_images=cache,
                                  single_cls=opt.single_cls,
                                  stride=int(stride),
                                  pad=pad)

    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             num_workers=nw,
                                             pin_memory=True,
                                             collate_fn=LoadImagesAndLabels.collate_fn)
    return dataloader, dataset
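

# A minimal usage sketch for create_dataloader, assuming an `opt` namespace that
# carries a `single_cls` attribute and a hyperparameter dict like the one train.py
# passes in; the path and hyperparameter values below are illustrative only.
#
#   from argparse import Namespace
#   opt = Namespace(single_cls=False)  # hypothetical options object
#   hyp = {'degrees': 0.0, 'translate': 0.0, 'scale': 0.5, 'shear': 0.0,
#          'hsv_h': 0.014, 'hsv_s': 0.68, 'hsv_v': 0.36}  # example augmentation values
#   dataloader, dataset = create_dataloader('../coco128/images/train2017', 640, 16, 32, opt,
#                                           hyp=hyp, augment=True)
#   for imgs, targets, paths, shapes in dataloader:
#       pass  # imgs: (bs, 3, h, w) uint8 tensor; targets: (n, 6) rows of [image_idx, cls, x, y, w, h]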


class LoadImages:  # for inference
    def __init__(self, path, img_size=640):
        path = str(Path(path))  # os-agnostic
        files = []
        if os.path.isdir(path):
            files = sorted(glob.glob(os.path.join(path, '*.*')))
        elif os.path.isfile(path):
            files = [path]

        images = [x for x in files if os.path.splitext(x)[-1].lower() in img_formats]
        videos = [x for x in files if os.path.splitext(x)[-1].lower() in vid_formats]
        nI, nV = len(images), len(videos)

        self.img_size = img_size
        self.files = images + videos
        self.nF = nI + nV  # number of files
        self.video_flag = [False] * nI + [True] * nV
        self.mode = 'images'
        if any(videos):
            self.new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nF > 0, 'No images or videos found in %s. Supported formats are:\nimages: %s\nvideos: %s' % \
                            (path, img_formats, vid_formats)

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nF:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            if not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nF:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()

            self.frame += 1
            print('video %g/%g (%g/%g) %s: ' % (self.count + 1, self.nF, self.frame, self.nframes, path), end='')

        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path
            print('image %g/%g %s: ' % (self.count, self.nF, path), end='')

        # Padded resize
        img = letterbox(img0, new_shape=self.img_size)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        # cv2.imwrite(path + '.letterbox.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1])  # save letterbox image
        return path, img, img0, self.cap

    def new_video(self, path):
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nF  # number of files
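

# A short inference-loop sketch for LoadImages, assuming a directory of images;
# `model` is a hypothetical detector stand-in, shown only to illustrate the
# iterator contract (path, CHW RGB uint8 array, original BGR frame, video handle).
#
#   dataset = LoadImages('inference/images', img_size=640)
#   for path, img, img0, vid_cap in dataset:
#       img = torch.from_numpy(img).float() / 255.0  # uint8 CHW -> float in [0, 1]
#       # pred = model(img.unsqueeze(0))  # hypothetical forward pass on a 1x3xHxW batch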


class LoadWebcam:  # for inference
    def __init__(self, pipe=0, img_size=640):
        self.img_size = img_size

        if pipe == '0':
            pipe = 0  # local camera
        # pipe = 'rtsp://192.168.1.64/1'  # IP camera
        # pipe = 'rtsp://username:password@192.168.1.64/1'  # IP camera with login
        # pipe = 'rtsp://170.93.143.139/rtplive/470011e600ef003a004ee33696235daa'  # IP traffic camera
        # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg'  # IP golf camera

        # https://answers.opencv.org/question/215996/changing-gstreamer-pipeline-to-opencv-in-pythonsolved/
        # pipe = '"rtspsrc location="rtsp://username:password@192.168.1.64/1" latency=10 ! appsink'  # GStreamer

        # https://answers.opencv.org/question/200787/video-acceleration-gstremer-pipeline-in-videocapture/
        # https://stackoverflow.com/questions/54095699/install-gstreamer-support-for-opencv-python-package  # install help
        # pipe = "rtspsrc location=rtsp://root:root@192.168.0.91:554/axis-media/media.amp?videocodec=h264&resolution=3840x2160 protocols=GST_RTSP_LOWER_TRANS_TCP ! rtph264depay ! queue ! vaapih264dec ! videoconvert ! appsink"  # GStreamer

        self.pipe = pipe
        self.cap = cv2.VideoCapture(pipe)  # video capture object
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)  # set buffer size

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if cv2.waitKey(1) == ord('q'):  # q to quit
            self.cap.release()
            cv2.destroyAllWindows()
            raise StopIteration

        # Read frame
        if self.pipe == 0:  # local camera
            ret_val, img0 = self.cap.read()
            img0 = cv2.flip(img0, 1)  # flip left-right
        else:  # IP camera
            n = 0
            while True:
                n += 1
                self.cap.grab()
                if n % 30 == 0:  # skip frames
                    ret_val, img0 = self.cap.retrieve()
                    if ret_val:
                        break

        # Print
        assert ret_val, 'Camera Error %s' % self.pipe
        img_path = 'webcam.jpg'
        print('webcam %g: ' % self.count, end='')

        # Padded resize
        img = letterbox(img0, new_shape=self.img_size)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        return img_path, img, img0, None

    def __len__(self):
        return 0


class LoadStreams:  # multiple IP or RTSP cameras
    def __init__(self, sources='streams.txt', img_size=640):
        self.mode = 'images'
        self.img_size = img_size

        if os.path.isfile(sources):
            with open(sources, 'r') as f:
                sources = [x.strip() for x in f.read().splitlines() if len(x.strip())]
        else:
            sources = [sources]

        n = len(sources)
        self.imgs = [None] * n
        self.sources = sources
        for i, s in enumerate(sources):
            # Start the thread to read frames from the video stream
            print('%g/%g: %s... ' % (i + 1, n, s), end='')
            cap = cv2.VideoCapture(0 if s == '0' else s)
            assert cap.isOpened(), 'Failed to open %s' % s
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS) % 100
            _, self.imgs[i] = cap.read()  # guarantee first frame
            thread = Thread(target=self.update, args=([i, cap]), daemon=True)
            print(' success (%gx%g at %.2f FPS).' % (w, h, fps))
            thread.start()
        print('')  # newline

        # check for common shapes
        s = np.stack([letterbox(x, new_shape=self.img_size)[0].shape for x in self.imgs], 0)  # inference shapes
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        if not self.rect:
            print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')

    def update(self, index, cap):
        # Read next stream frame in a daemon thread
        n = 0
        while cap.isOpened():
            n += 1
            # _, self.imgs[index] = cap.read()
            cap.grab()
            if n == 4:  # read every 4th frame
                _, self.imgs[index] = cap.retrieve()
                n = 0
            time.sleep(0.01)  # wait time

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        img0 = self.imgs.copy()
        if cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        # Letterbox
        img = [letterbox(x, new_shape=self.img_size, auto=self.rect)[0] for x in img0]

        # Stack
        img = np.stack(img, 0)

        # Convert
        img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, to bsx3x416x416
        img = np.ascontiguousarray(img)

        return self.sources, img, img0, None

    def __len__(self):
        return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years


class LoadImagesAndLabels(Dataset):  # for training/testing
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0):
        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = str(Path(p))  # os-agnostic
                parent = str(Path(p).parent) + os.sep
                if os.path.isfile(p):  # file
                    with open(p, 'r') as t:
                        t = t.read().splitlines()
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                elif os.path.isdir(p):  # folder
                    f += glob.iglob(p + os.sep + '*.*')
                else:
                    raise Exception('%s does not exist' % p)
            self.img_files = [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats]
        except Exception as e:
            raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))

        n = len(self.img_files)
        assert n > 0, 'No images found in %s. See %s' % (path, help_url)
        bi = np.floor(np.arange(n) / batch_size).astype(np.int)  # batch index
        nb = bi[-1] + 1  # number of batches

        self.n = n  # number of images
        self.batch = bi  # batch index of image
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride

        # Define labels
        self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt') for x in
                            self.img_files]

        # Check cache
        cache_path = str(Path(self.label_files[0]).parent) + '.cache'  # cached labels
        if os.path.isfile(cache_path):
            cache = torch.load(cache_path)  # load
            if cache['hash'] != get_hash(self.label_files + self.img_files):  # dataset changed
                cache = self.cache_labels(cache_path)  # re-cache
        else:
            cache = self.cache_labels(cache_path)  # cache

        # Get labels
        labels, shapes = zip(*[cache[x] for x in self.img_files])
        self.shapes = np.array(shapes, dtype=np.float64)
        self.labels = list(labels)

        # Rectangular Training  https://github.com/ultralytics/yolov3/issues/232
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride

        # Cache labels
        create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False
        nm, nf, ne, ns, nd = 0, 0, 0, 0, 0  # number missing, found, empty, datasubset, duplicate
        pbar = tqdm(self.label_files)
        for i, file in enumerate(pbar):
            l = self.labels[i]  # label
            if l.shape[0]:
                assert l.shape[1] == 5, '> 5 label columns: %s' % file
                assert (l >= 0).all(), 'negative labels: %s' % file
                assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels: %s' % file
                if np.unique(l, axis=0).shape[0] < l.shape[0]:  # duplicate rows
                    nd += 1  # print('WARNING: duplicate rows in %s' % self.label_files[i])  # duplicate rows
                if single_cls:
                    l[:, 0] = 0  # force dataset into single-class mode
                self.labels[i] = l
                nf += 1  # file found

                # Create subdataset (a smaller dataset)
                if create_datasubset and ns < 1E4:
                    if ns == 0:
                        create_folder(path='./datasubset')
                        os.makedirs('./datasubset/images')
                    exclude_classes = 43
                    if exclude_classes not in l[:, 0]:
                        ns += 1
                        # shutil.copy(src=self.img_files[i], dst='./datasubset/images/')  # copy image
                        with open('./datasubset/images.txt', 'a') as f:
                            f.write(self.img_files[i] + '\n')

                # Extract object detection boxes for a second stage classifier
                if extract_bounding_boxes:
                    p = Path(self.img_files[i])
                    img = cv2.imread(str(p))
                    h, w = img.shape[:2]
                    for j, x in enumerate(l):
                        f = '%s%sclassifier%s%g_%g_%s' % (p.parent.parent, os.sep, os.sep, x[0], j, p.name)
                        if not os.path.exists(Path(f).parent):
                            os.makedirs(Path(f).parent)  # make new output folder

                        b = x[1:] * [w, h, w, h]  # box
                        b[2:] = b[2:].max()  # rectangle to square
                        b[2:] = b[2:] * 1.3 + 30  # pad
                        b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(np.int)

                        b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                        b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                        assert cv2.imwrite(f, img[b[1]:b[3], b[0]:b[2]]), 'Failure extracting classifier boxes'
            else:
                ne += 1  # print('empty labels for image %s' % self.img_files[i])  # file empty
                # os.system("rm '%s' '%s'" % (self.img_files[i], self.label_files[i]))  # remove

            pbar.desc = 'Scanning labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
                cache_path, nf, nm, ne, nd, n)
        assert nf > 0, 'No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)

        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        self.imgs = [None] * n
        if cache_images:
            gb = 0  # Gigabytes of cached images
            pbar = tqdm(range(len(self.img_files)), desc='Caching images')
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            for i in pbar:  # max 10k images
                self.imgs[i], self.img_hw0[i], self.img_hw[i] = load_image(self, i)  # img, hw_original, hw_resized
                gb += self.imgs[i].nbytes
                pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9)

    def cache_labels(self, path='labels.cache'):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
        for (img, label) in pbar:
            try:
                l = []
                image = Image.open(img)
                image.verify()  # PIL verify
                # _ = io.imread(img)  # skimage verify (from skimage import io)
                shape = exif_size(image)  # image size
                if os.path.isfile(label):
                    with open(label, 'r') as f:
                        l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)  # labels
                if len(l) == 0:
                    l = np.zeros((0, 5), dtype=np.float32)
                x[img] = [l, shape]
            except Exception as e:
                x[img] = None
                print('WARNING: %s: %s' % (img, e))

        x['hash'] = get_hash(self.label_files + self.img_files)
        torch.save(x, path)  # save for next time
        return x
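
    # Label files parsed above are expected in the standard YOLO format: one
    # `class x_center y_center width height` row per object, whitespace-separated,
    # with all coordinates normalized to 0-1. An illustrative two-object file:
    #
    #   0 0.481719 0.634028 0.690625 0.713278
    #   3 0.208250 0.415000 0.055000 0.106000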

    def __len__(self):
        return len(self.img_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     # self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        if self.image_weights:
            index = self.indices[index]

        hyp = self.hyp
        if self.mosaic:
            # Load mosaic
            img, labels = load_mosaic(self, index)
            shapes = None

        else:
            # Load image
            img, (h0, w0), (h, w) = load_image(self, index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            # Load labels
            labels = []
            x = self.labels[index]
            if x.size > 0:
                # Normalized xywh to pixel xyxy format
                labels = x.copy()
                labels[:, 1] = ratio[0] * w * (x[:, 1] - x[:, 3] / 2) + pad[0]  # pad width
                labels[:, 2] = ratio[1] * h * (x[:, 2] - x[:, 4] / 2) + pad[1]  # pad height
                labels[:, 3] = ratio[0] * w * (x[:, 1] + x[:, 3] / 2) + pad[0]
                labels[:, 4] = ratio[1] * h * (x[:, 2] + x[:, 4] / 2) + pad[1]

        if self.augment:
            # Augment imagespace
            if not self.mosaic:
                img, labels = random_affine(img, labels,
                                            degrees=hyp['degrees'],
                                            translate=hyp['translate'],
                                            scale=hyp['scale'],
                                            shear=hyp['shear'])

            # Augment colorspace
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Apply cutouts
            # if random.random() < 0.9:
            #     labels = cutout(img, labels)

        nL = len(labels)  # number of labels
        if nL:
            # convert xyxy to xywh
            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])

            # Normalize coordinates 0 - 1
            labels[:, [2, 4]] /= img.shape[0]  # height
            labels[:, [1, 3]] /= img.shape[1]  # width

        if self.augment:
            # random left-right flip
            lr_flip = True
            if lr_flip and random.random() < 0.5:
                img = np.fliplr(img)
                if nL:
                    labels[:, 1] = 1 - labels[:, 1]

            # random up-down flip
            ud_flip = False
            if ud_flip and random.random() < 0.5:
                img = np.flipud(img)
                if nL:
                    labels[:, 2] = 1 - labels[:, 2]

        labels_out = torch.zeros((nL, 6))
        if nL:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.img_files[index], shapes

    @staticmethod
    def collate_fn(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes


def load_image(self, index):
    # loads 1 image from dataset, returns img, original hw, resized hw
    img = self.imgs[index]
    if img is None:  # not cached
        path = self.img_files[index]
        img = cv2.imread(path)  # BGR
        assert img is not None, 'Image Not Found ' + path
        h0, w0 = img.shape[:2]  # orig hw
        r = self.img_size / max(h0, w0)  # resize image to img_size
        if r != 1:  # always resize down, only resize up if training with augmentation
            interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
        return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
    else:
        return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized


def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
    r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
    hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    dtype = img.dtype  # uint8

    x = np.arange(0, 256, dtype=np.int16)
    lut_hue = ((x * r[0]) % 180).astype(dtype)
    lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
    lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
    cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed

    # Histogram equalization
    # if random.random() < 0.2:
    #     for i in range(3):
    #         img[:, :, i] = cv2.equalizeHist(img[:, :, i])


def load_mosaic(self, index):
    # loads images in a mosaic

    labels4 = []
    s = self.img_size
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
    indices = [index] + [random.randint(0, len(self.labels) - 1) for _ in range(3)]  # 3 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, max(xc, w), min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        x = self.labels[index]
        labels = x.copy()
        if x.size > 0:  # Normalized xywh to pixel xyxy format
            labels[:, 1] = w * (x[:, 1] - x[:, 3] / 2) + padw
            labels[:, 2] = h * (x[:, 2] - x[:, 4] / 2) + padh
            labels[:, 3] = w * (x[:, 1] + x[:, 3] / 2) + padw
            labels[:, 4] = h * (x[:, 2] + x[:, 4] / 2) + padh
        labels4.append(labels)

    # Concat/clip labels
    if len(labels4):
        labels4 = np.concatenate(labels4, 0)
        # np.clip(labels4[:, 1:] - s / 2, 0, s, out=labels4[:, 1:])  # use with center crop
        np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:])  # use with random_affine

        # Replicate
        # img4, labels4 = replicate(img4, labels4)

    # Augment
    # img4 = img4[s // 2: int(s * 1.5), s // 2:int(s * 1.5)]  # center crop (WARNING, requires box pruning)
    img4, labels4 = random_affine(img4, labels4,
                                  degrees=self.hyp['degrees'],
                                  translate=self.hyp['translate'],
                                  scale=self.hyp['scale'],
                                  shear=self.hyp['shear'],
                                  border=self.mosaic_border)  # border to remove

    return img4, labels4
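

# A minimal sketch of drawing one mosaic sample, assuming `dataset` is a
# LoadImagesAndLabels instance built with augment=True (so self.hyp is set).
# The four tiles land on a 2s x 2s canvas around the random center (xc, yc);
# random_affine then warps and crops back to s x s via the negative
# mosaic_border offsets, so the output is img_size x img_size.
#
#   img4, labels4 = load_mosaic(dataset, 0)
#   # img4: (img_size, img_size, 3) BGR uint8 array
#   # labels4: (n, 5) pixel-space rows of [cls, x1, y1, x2, y2]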


def replicate(img, labels):
    # Replicate labels
    h, w = img.shape[:2]
    boxes = labels[:, 1:].astype(int)
    x1, y1, x2, y2 = boxes.T
    s = ((x2 - x1) + (y2 - y1)) / 2  # side length (pixels)
    for i in s.argsort()[:round(s.size * 0.5)]:  # smallest indices
        x1b, y1b, x2b, y2b = boxes[i]
        bh, bw = y2b - y1b, x2b - x1b
        yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw))  # offset x, y
        x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
        img[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)

    return img, labels


def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
    # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, 64), np.mod(dh, 64)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
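

# A quick letterbox sketch on a synthetic frame, under the defaults above: note
# that with auto=True the padded dimensions are rounded to multiples of 64, not
# necessarily the full new_shape square.
#
#   im = np.zeros((720, 1280, 3), dtype=np.uint8)  # dummy HxWxC frame
#   out, ratio, (dw, dh) = letterbox(im, new_shape=640)
#   # out.shape -> (384, 640, 3): 1280x720 scaled by r=0.5 to 640x360,
#   # then padded in height to the next 64-multiple; ratio == (0.5, 0.5)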


def random_affine(img, targets=(), degrees=10, translate=.1, scale=.1, shear=10, border=(0, 0)):
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
    # targets = [cls, xyxy]

    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(img.shape[1] / 2, img.shape[0] / 2), scale=s)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(-translate, translate) * img.shape[1] + border[1]  # x translation (pixels)
    T[1, 2] = random.uniform(-translate, translate) * img.shape[0] + border[0]  # y translation (pixels)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Combined rotation matrix
    M = S @ T @ R  # ORDER IS IMPORTANT HERE!!
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        img = cv2.warpAffine(img, M[:2], dsize=(width, height), flags=cv2.INTER_LINEAR, borderValue=(114, 114, 114))

    # Transform label coordinates
    n = len(targets)
    if n:
        # warp points
        xy = np.ones((n * 4, 3))
        xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
        xy = (xy @ M.T)[:, :2].reshape(n, 8)

        # create new boxes
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

        # # apply angle-based reduction of bounding boxes
        # radians = a * math.pi / 180
        # reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
        # x = (xy[:, 2] + xy[:, 0]) / 2
        # y = (xy[:, 3] + xy[:, 1]) / 2
        # w = (xy[:, 2] - xy[:, 0]) * reduction
        # h = (xy[:, 3] - xy[:, 1]) * reduction
        # xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T

        # reject warped points outside of image
        xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
        xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
        w = xy[:, 2] - xy[:, 0]
        h = xy[:, 3] - xy[:, 1]
        area = w * h
        area0 = (targets[:, 3] - targets[:, 1]) * (targets[:, 4] - targets[:, 2])
        ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16))  # aspect ratio
        i = (w > 2) & (h > 2) & (area / (area0 * s + 1e-16) > 0.2) & (ar < 20)

        targets = targets[i]
        targets[:, 1:5] = xy[i]

    return img, targets


def cutout(image, labels):
    # https://arxiv.org/abs/1708.04552
    # https://github.com/hysts/pytorch_cutout/blob/master/dataloader.py
    # https://towardsdatascience.com/when-conventional-wisdom-fails-revisiting-data-augmentation-for-self-driving-cars-4831998c5509
    h, w = image.shape[:2]

    def bbox_ioa(box1, box2):
        # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
        box2 = box2.transpose()

        # Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]

        # Intersection area
        inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
                     (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)

        # box2 area
        box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16

        # Intersection over box2 area
        return inter_area / box2_area

    # create random masks
    scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
    for s in scales:
        mask_h = random.randint(1, int(h * s))
        mask_w = random.randint(1, int(w * s))

        # box
        xmin = max(0, random.randint(0, w) - mask_w // 2)
        ymin = max(0, random.randint(0, h) - mask_h // 2)
        xmax = min(w, xmin + mask_w)
        ymax = min(h, ymin + mask_h)

        # apply random color mask
        image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]

        # return unobscured labels
        if len(labels) and s > 0.03:
            box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            labels = labels[ioa < 0.60]  # remove >60% obscured labels

    return labels


def reduce_img_size(path='../data/sm4/images', img_size=1024):  # from utils.datasets import *; reduce_img_size()
    # creates a new ./images_reduced folder with reduced size images of maximum size img_size
    path_new = path + '_reduced'  # reduced images path
    create_folder(path_new)
    for f in tqdm(glob.glob('%s/*.*' % path)):
        try:
            img = cv2.imread(f)
            h, w = img.shape[:2]
            r = img_size / max(h, w)  # size ratio
            if r < 1.0:
                img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_AREA)  # _LINEAR fastest
            fnew = f.replace(path, path_new)  # .replace(Path(f).suffix, '.jpg')
            cv2.imwrite(fnew, img)
        except:
            print('WARNING: image failure %s' % f)


def convert_images2bmp():  # from utils.datasets import *; convert_images2bmp()
    # Save images
    formats = [x.lower() for x in img_formats] + [x.upper() for x in img_formats]
    # for path in ['../coco/images/val2014', '../coco/images/train2014']:
    for path in ['../data/sm4/images', '../data/sm4/background']:
        create_folder(path + 'bmp')
        for ext in formats:  # ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.dng']
            for f in tqdm(glob.glob('%s/*%s' % (path, ext)), desc='Converting %s' % ext):
                cv2.imwrite(f.replace(ext.lower(), '.bmp').replace(path, path + 'bmp'), cv2.imread(f))

    # Save labels
    # for path in ['../coco/trainvalno5k.txt', '../coco/5k.txt']:
    for file in ['../data/sm4/out_train.txt', '../data/sm4/out_test.txt']:
        with open(file, 'r') as f:
            lines = f.read()
            # lines = f.read().replace('2014/', '2014bmp/')  # coco
            lines = lines.replace('/images', '/imagesbmp')
            lines = lines.replace('/background', '/backgroundbmp')
            for ext in formats:
                lines = lines.replace(ext, '.bmp')
        with open(file.replace('.txt', 'bmp.txt'), 'w') as f:
            f.write(lines)


def recursive_dataset2bmp(dataset='../data/sm4_bmp'):  # from utils.datasets import *; recursive_dataset2bmp()
    # Converts dataset to bmp (for faster training)
    formats = [x.lower() for x in img_formats] + [x.upper() for x in img_formats]
    for a, b, files in os.walk(dataset):
        for file in tqdm(files, desc=a):
            p = a + '/' + file
            s = Path(file).suffix
            if s == '.txt':  # replace text
                with open(p, 'r') as f:
                    lines = f.read()
                for f in formats:
                    lines = lines.replace(f, '.bmp')
                with open(p, 'w') as f:
                    f.write(lines)
            elif s in formats:  # replace image
                cv2.imwrite(p.replace(s, '.bmp'), cv2.imread(p))
                if s != '.bmp':
                    os.system("rm '%s'" % p)


def imagelist2folder(path='data/coco_64img.txt'):  # from utils.datasets import *; imagelist2folder()
    # Copies all the images in a text file (list of images) into a folder
    create_folder(path[:-4])
    with open(path, 'r') as f:
        for line in f.read().splitlines():
            os.system('cp "%s" %s' % (line, path[:-4]))
            print(line)


def create_folder(path='./new_folder'):
    # Create folder
    if os.path.exists(path):
        shutil.rmtree(path)  # delete output folder
    os.makedirs(path)  # make new output folder