Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

941 linhas
38KB

  1. import glob
  2. import math
  3. import os
  4. import random
  5. import shutil
  6. import time
  7. from pathlib import Path
  8. from threading import Thread
  9. import cv2
  10. import numpy as np
  11. import torch
  12. from PIL import Image, ExifTags
  13. from torch.utils.data import Dataset
  14. from tqdm import tqdm
  15. from utils.general import xyxy2xywh, xywh2xyxy, torch_distributed_zero_first
  16. help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
  17. img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.dng']
  18. vid_formats = ['.mov', '.avi', '.mp4', '.mpg', '.mpeg', '.m4v', '.wmv', '.mkv']
  19. # Get orientation exif tag
  20. for orientation in ExifTags.TAGS.keys():
  21. if ExifTags.TAGS[orientation] == 'Orientation':
  22. break
  23. def get_hash(files):
  24. # Returns a single hash value of a list of files
  25. return sum(os.path.getsize(f) for f in files if os.path.isfile(f))
  26. def exif_size(img):
  27. # Returns exif-corrected PIL size
  28. s = img.size # (width, height)
  29. try:
  30. rotation = dict(img._getexif().items())[orientation]
  31. if rotation == 6: # rotation 270
  32. s = (s[1], s[0])
  33. elif rotation == 8: # rotation 90
  34. s = (s[1], s[0])
  35. except:
  36. pass
  37. return s
  38. def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
  39. rank=-1, world_size=1, workers=8):
  40. # Make sure only the first process in DDP process the dataset first, and the following others can use the cache.
  41. with torch_distributed_zero_first(rank):
  42. dataset = LoadImagesAndLabels(path, imgsz, batch_size,
  43. augment=augment, # augment images
  44. hyp=hyp, # augmentation hyperparameters
  45. rect=rect, # rectangular training
  46. cache_images=cache,
  47. single_cls=opt.single_cls,
  48. stride=int(stride),
  49. pad=pad,
  50. rank=rank)
  51. batch_size = min(batch_size, len(dataset))
  52. nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
  53. sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
  54. dataloader = InfiniteDataLoader(dataset,
  55. batch_size=batch_size,
  56. num_workers=nw,
  57. sampler=sampler,
  58. pin_memory=True,
  59. collate_fn=LoadImagesAndLabels.collate_fn)
  60. return dataloader, dataset
  61. class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
  62. """ Dataloader that reuses workers.
  63. Uses same syntax as vanilla DataLoader.
  64. """
  65. def __init__(self, *args, **kwargs):
  66. super().__init__(*args, **kwargs)
  67. object.__setattr__(self, 'batch_sampler', self._RepeatSampler(self.batch_sampler))
  68. self.iterator = super().__iter__()
  69. def __len__(self):
  70. return len(self.batch_sampler.sampler)
  71. def __iter__(self):
  72. for i in range(len(self)):
  73. yield next(self.iterator)
  74. class _RepeatSampler(object):
  75. """ Sampler that repeats forever.
  76. Args:
  77. sampler (Sampler)
  78. """
  79. def __init__(self, sampler):
  80. self.sampler = sampler
  81. def __iter__(self):
  82. while True:
  83. yield from iter(self.sampler)
  84. class LoadImages: # for inference
  85. def __init__(self, path, img_size=640):
  86. p = str(Path(path)) # os-agnostic
  87. p = os.path.abspath(p) # absolute path
  88. if '*' in p:
  89. files = sorted(glob.glob(p, recursive=True)) # glob
  90. elif os.path.isdir(p):
  91. files = sorted(glob.glob(os.path.join(p, '*.*'))) # dir
  92. elif os.path.isfile(p):
  93. files = [p] # files
  94. else:
  95. raise Exception('ERROR: %s does not exist' % p)
  96. images = [x for x in files if os.path.splitext(x)[-1].lower() in img_formats]
  97. videos = [x for x in files if os.path.splitext(x)[-1].lower() in vid_formats]
  98. ni, nv = len(images), len(videos)
  99. self.img_size = img_size
  100. self.files = images + videos
  101. self.nf = ni + nv # number of files
  102. self.video_flag = [False] * ni + [True] * nv
  103. self.mode = 'images'
  104. if any(videos):
  105. self.new_video(videos[0]) # new video
  106. else:
  107. self.cap = None
  108. assert self.nf > 0, 'No images or videos found in %s. Supported formats are:\nimages: %s\nvideos: %s' % \
  109. (p, img_formats, vid_formats)
  110. def __iter__(self):
  111. self.count = 0
  112. return self
  113. def __next__(self):
  114. if self.count == self.nf:
  115. raise StopIteration
  116. path = self.files[self.count]
  117. if self.video_flag[self.count]:
  118. # Read video
  119. self.mode = 'video'
  120. ret_val, img0 = self.cap.read()
  121. if not ret_val:
  122. self.count += 1
  123. self.cap.release()
  124. if self.count == self.nf: # last video
  125. raise StopIteration
  126. else:
  127. path = self.files[self.count]
  128. self.new_video(path)
  129. ret_val, img0 = self.cap.read()
  130. self.frame += 1
  131. print('video %g/%g (%g/%g) %s: ' % (self.count + 1, self.nf, self.frame, self.nframes, path), end='')
  132. else:
  133. # Read image
  134. self.count += 1
  135. img0 = cv2.imread(path) # BGR
  136. assert img0 is not None, 'Image Not Found ' + path
  137. print('image %g/%g %s: ' % (self.count, self.nf, path), end='')
  138. # Padded resize
  139. img = letterbox(img0, new_shape=self.img_size)[0]
  140. # Convert
  141. img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  142. img = np.ascontiguousarray(img)
  143. # cv2.imwrite(path + '.letterbox.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1]) # save letterbox image
  144. return path, img, img0, self.cap
  145. def new_video(self, path):
  146. self.frame = 0
  147. self.cap = cv2.VideoCapture(path)
  148. self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
  149. def __len__(self):
  150. return self.nf # number of files
  151. class LoadWebcam: # for inference
  152. def __init__(self, pipe=0, img_size=640):
  153. self.img_size = img_size
  154. if pipe == '0':
  155. pipe = 0 # local camera
  156. # pipe = 'rtsp://192.168.1.64/1' # IP camera
  157. # pipe = 'rtsp://username:password@192.168.1.64/1' # IP camera with login
  158. # pipe = 'rtsp://170.93.143.139/rtplive/470011e600ef003a004ee33696235daa' # IP traffic camera
  159. # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg' # IP golf camera
  160. # https://answers.opencv.org/question/215996/changing-gstreamer-pipeline-to-opencv-in-pythonsolved/
  161. # pipe = '"rtspsrc location="rtsp://username:password@192.168.1.64/1" latency=10 ! appsink' # GStreamer
  162. # https://answers.opencv.org/question/200787/video-acceleration-gstremer-pipeline-in-videocapture/
  163. # https://stackoverflow.com/questions/54095699/install-gstreamer-support-for-opencv-python-package # install help
  164. # pipe = "rtspsrc location=rtsp://root:root@192.168.0.91:554/axis-media/media.amp?videocodec=h264&resolution=3840x2160 protocols=GST_RTSP_LOWER_TRANS_TCP ! rtph264depay ! queue ! vaapih264dec ! videoconvert ! appsink" # GStreamer
  165. self.pipe = pipe
  166. self.cap = cv2.VideoCapture(pipe) # video capture object
  167. self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3) # set buffer size
  168. def __iter__(self):
  169. self.count = -1
  170. return self
  171. def __next__(self):
  172. self.count += 1
  173. if cv2.waitKey(1) == ord('q'): # q to quit
  174. self.cap.release()
  175. cv2.destroyAllWindows()
  176. raise StopIteration
  177. # Read frame
  178. if self.pipe == 0: # local camera
  179. ret_val, img0 = self.cap.read()
  180. img0 = cv2.flip(img0, 1) # flip left-right
  181. else: # IP camera
  182. n = 0
  183. while True:
  184. n += 1
  185. self.cap.grab()
  186. if n % 30 == 0: # skip frames
  187. ret_val, img0 = self.cap.retrieve()
  188. if ret_val:
  189. break
  190. # Print
  191. assert ret_val, 'Camera Error %s' % self.pipe
  192. img_path = 'webcam.jpg'
  193. print('webcam %g: ' % self.count, end='')
  194. # Padded resize
  195. img = letterbox(img0, new_shape=self.img_size)[0]
  196. # Convert
  197. img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  198. img = np.ascontiguousarray(img)
  199. return img_path, img, img0, None
  200. def __len__(self):
  201. return 0
  202. class LoadStreams: # multiple IP or RTSP cameras
  203. def __init__(self, sources='streams.txt', img_size=640):
  204. self.mode = 'images'
  205. self.img_size = img_size
  206. if os.path.isfile(sources):
  207. with open(sources, 'r') as f:
  208. sources = [x.strip() for x in f.read().splitlines() if len(x.strip())]
  209. else:
  210. sources = [sources]
  211. n = len(sources)
  212. self.imgs = [None] * n
  213. self.sources = sources
  214. for i, s in enumerate(sources):
  215. # Start the thread to read frames from the video stream
  216. print('%g/%g: %s... ' % (i + 1, n, s), end='')
  217. cap = cv2.VideoCapture(eval(s) if s.isnumeric() else s)
  218. assert cap.isOpened(), 'Failed to open %s' % s
  219. w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  220. h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  221. fps = cap.get(cv2.CAP_PROP_FPS) % 100
  222. _, self.imgs[i] = cap.read() # guarantee first frame
  223. thread = Thread(target=self.update, args=([i, cap]), daemon=True)
  224. print(' success (%gx%g at %.2f FPS).' % (w, h, fps))
  225. thread.start()
  226. print('') # newline
  227. # check for common shapes
  228. s = np.stack([letterbox(x, new_shape=self.img_size)[0].shape for x in self.imgs], 0) # inference shapes
  229. self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
  230. if not self.rect:
  231. print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')
  232. def update(self, index, cap):
  233. # Read next stream frame in a daemon thread
  234. n = 0
  235. while cap.isOpened():
  236. n += 1
  237. # _, self.imgs[index] = cap.read()
  238. cap.grab()
  239. if n == 4: # read every 4th frame
  240. _, self.imgs[index] = cap.retrieve()
  241. n = 0
  242. time.sleep(0.01) # wait time
  243. def __iter__(self):
  244. self.count = -1
  245. return self
  246. def __next__(self):
  247. self.count += 1
  248. img0 = self.imgs.copy()
  249. if cv2.waitKey(1) == ord('q'): # q to quit
  250. cv2.destroyAllWindows()
  251. raise StopIteration
  252. # Letterbox
  253. img = [letterbox(x, new_shape=self.img_size, auto=self.rect)[0] for x in img0]
  254. # Stack
  255. img = np.stack(img, 0)
  256. # Convert
  257. img = img[:, :, :, ::-1].transpose(0, 3, 1, 2) # BGR to RGB, to bsx3x416x416
  258. img = np.ascontiguousarray(img)
  259. return self.sources, img, img0, None
  260. def __len__(self):
  261. return 0 # 1E12 frames = 32 streams at 30 FPS for 30 years
  262. class LoadImagesAndLabels(Dataset): # for training/testing
  263. def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
  264. cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1):
  265. try:
  266. f = [] # image files
  267. for p in path if isinstance(path, list) else [path]:
  268. p = str(Path(p)) # os-agnostic
  269. parent = str(Path(p).parent) + os.sep
  270. if os.path.isfile(p): # file
  271. with open(p, 'r') as t:
  272. t = t.read().splitlines()
  273. f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
  274. elif os.path.isdir(p): # folder
  275. f += glob.iglob(p + os.sep + '*.*')
  276. else:
  277. raise Exception('%s does not exist' % p)
  278. self.img_files = sorted(
  279. [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats])
  280. except Exception as e:
  281. raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))
  282. n = len(self.img_files)
  283. assert n > 0, 'No images found in %s. See %s' % (path, help_url)
  284. bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
  285. nb = bi[-1] + 1 # number of batches
  286. self.n = n # number of images
  287. self.batch = bi # batch index of image
  288. self.img_size = img_size
  289. self.augment = augment
  290. self.hyp = hyp
  291. self.image_weights = image_weights
  292. self.rect = False if image_weights else rect
  293. self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
  294. self.mosaic_border = [-img_size // 2, -img_size // 2]
  295. self.stride = stride
  296. # Define labels
  297. self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt') for x in
  298. self.img_files]
  299. # Check cache
  300. cache_path = str(Path(self.label_files[0]).parent) + '.cache' # cached labels
  301. if os.path.isfile(cache_path):
  302. cache = torch.load(cache_path) # load
  303. if cache['hash'] != get_hash(self.label_files + self.img_files): # dataset changed
  304. cache = self.cache_labels(cache_path) # re-cache
  305. else:
  306. cache = self.cache_labels(cache_path) # cache
  307. # Get labels
  308. labels, shapes = zip(*[cache[x] for x in self.img_files])
  309. self.shapes = np.array(shapes, dtype=np.float64)
  310. self.labels = list(labels)
  311. # Rectangular Training https://github.com/ultralytics/yolov3/issues/232
  312. if self.rect:
  313. # Sort by aspect ratio
  314. s = self.shapes # wh
  315. ar = s[:, 1] / s[:, 0] # aspect ratio
  316. irect = ar.argsort()
  317. self.img_files = [self.img_files[i] for i in irect]
  318. self.label_files = [self.label_files[i] for i in irect]
  319. self.labels = [self.labels[i] for i in irect]
  320. self.shapes = s[irect] # wh
  321. ar = ar[irect]
  322. # Set training image shapes
  323. shapes = [[1, 1]] * nb
  324. for i in range(nb):
  325. ari = ar[bi == i]
  326. mini, maxi = ari.min(), ari.max()
  327. if maxi < 1:
  328. shapes[i] = [maxi, 1]
  329. elif mini > 1:
  330. shapes[i] = [1, 1 / mini]
  331. self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
  332. # Cache labels
  333. create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False
  334. nm, nf, ne, ns, nd = 0, 0, 0, 0, 0 # number missing, found, empty, datasubset, duplicate
  335. pbar = enumerate(self.label_files)
  336. if rank in [-1, 0]:
  337. pbar = tqdm(pbar)
  338. for i, file in pbar:
  339. l = self.labels[i] # label
  340. if l is not None and l.shape[0]:
  341. assert l.shape[1] == 5, '> 5 label columns: %s' % file
  342. assert (l >= 0).all(), 'negative labels: %s' % file
  343. assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels: %s' % file
  344. if np.unique(l, axis=0).shape[0] < l.shape[0]: # duplicate rows
  345. nd += 1 # print('WARNING: duplicate rows in %s' % self.label_files[i]) # duplicate rows
  346. if single_cls:
  347. l[:, 0] = 0 # force dataset into single-class mode
  348. self.labels[i] = l
  349. nf += 1 # file found
  350. # Create subdataset (a smaller dataset)
  351. if create_datasubset and ns < 1E4:
  352. if ns == 0:
  353. create_folder(path='./datasubset')
  354. os.makedirs('./datasubset/images')
  355. exclude_classes = 43
  356. if exclude_classes not in l[:, 0]:
  357. ns += 1
  358. # shutil.copy(src=self.img_files[i], dst='./datasubset/images/') # copy image
  359. with open('./datasubset/images.txt', 'a') as f:
  360. f.write(self.img_files[i] + '\n')
  361. # Extract object detection boxes for a second stage classifier
  362. if extract_bounding_boxes:
  363. p = Path(self.img_files[i])
  364. img = cv2.imread(str(p))
  365. h, w = img.shape[:2]
  366. for j, x in enumerate(l):
  367. f = '%s%sclassifier%s%g_%g_%s' % (p.parent.parent, os.sep, os.sep, x[0], j, p.name)
  368. if not os.path.exists(Path(f).parent):
  369. os.makedirs(Path(f).parent) # make new output folder
  370. b = x[1:] * [w, h, w, h] # box
  371. b[2:] = b[2:].max() # rectangle to square
  372. b[2:] = b[2:] * 1.3 + 30 # pad
  373. b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(np.int)
  374. b[[0, 2]] = np.clip(b[[0, 2]], 0, w) # clip boxes outside of image
  375. b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
  376. assert cv2.imwrite(f, img[b[1]:b[3], b[0]:b[2]]), 'Failure extracting classifier boxes'
  377. else:
  378. ne += 1 # print('empty labels for image %s' % self.img_files[i]) # file empty
  379. # os.system("rm '%s' '%s'" % (self.img_files[i], self.label_files[i])) # remove
  380. if rank in [-1, 0]:
  381. pbar.desc = 'Scanning labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
  382. cache_path, nf, nm, ne, nd, n)
  383. if nf == 0:
  384. s = 'WARNING: No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
  385. print(s)
  386. assert not augment, '%s. Can not train without labels.' % s
  387. # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
  388. self.imgs = [None] * n
  389. if cache_images:
  390. gb = 0 # Gigabytes of cached images
  391. pbar = tqdm(range(len(self.img_files)), desc='Caching images')
  392. self.img_hw0, self.img_hw = [None] * n, [None] * n
  393. for i in pbar: # max 10k images
  394. self.imgs[i], self.img_hw0[i], self.img_hw[i] = load_image(self, i) # img, hw_original, hw_resized
  395. gb += self.imgs[i].nbytes
  396. pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9)
  397. def cache_labels(self, path='labels.cache'):
  398. # Cache dataset labels, check images and read shapes
  399. x = {} # dict
  400. pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
  401. for (img, label) in pbar:
  402. try:
  403. l = []
  404. image = Image.open(img)
  405. image.verify() # PIL verify
  406. # _ = io.imread(img) # skimage verify (from skimage import io)
  407. shape = exif_size(image) # image size
  408. assert (shape[0] > 9) & (shape[1] > 9), 'image size <10 pixels'
  409. if os.path.isfile(label):
  410. with open(label, 'r') as f:
  411. l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32) # labels
  412. if len(l) == 0:
  413. l = np.zeros((0, 5), dtype=np.float32)
  414. x[img] = [l, shape]
  415. except Exception as e:
  416. x[img] = [None, None]
  417. print('WARNING: %s: %s' % (img, e))
  418. x['hash'] = get_hash(self.label_files + self.img_files)
  419. torch.save(x, path) # save for next time
  420. return x
  421. def __len__(self):
  422. return len(self.img_files)
  423. # def __iter__(self):
  424. # self.count = -1
  425. # print('ran dataset iter')
  426. # #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
  427. # return self
  428. def __getitem__(self, index):
  429. if self.image_weights:
  430. index = self.indices[index]
  431. hyp = self.hyp
  432. if self.mosaic:
  433. # Load mosaic
  434. img, labels = load_mosaic(self, index)
  435. shapes = None
  436. # MixUp https://arxiv.org/pdf/1710.09412.pdf
  437. if random.random() < hyp['mixup']:
  438. img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
  439. r = np.random.beta(8.0, 8.0) # mixup ratio, alpha=beta=8.0
  440. img = (img * r + img2 * (1 - r)).astype(np.uint8)
  441. labels = np.concatenate((labels, labels2), 0)
  442. else:
  443. # Load image
  444. img, (h0, w0), (h, w) = load_image(self, index)
  445. # Letterbox
  446. shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
  447. img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
  448. shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
  449. # Load labels
  450. labels = []
  451. x = self.labels[index]
  452. if x.size > 0:
  453. # Normalized xywh to pixel xyxy format
  454. labels = x.copy()
  455. labels[:, 1] = ratio[0] * w * (x[:, 1] - x[:, 3] / 2) + pad[0] # pad width
  456. labels[:, 2] = ratio[1] * h * (x[:, 2] - x[:, 4] / 2) + pad[1] # pad height
  457. labels[:, 3] = ratio[0] * w * (x[:, 1] + x[:, 3] / 2) + pad[0]
  458. labels[:, 4] = ratio[1] * h * (x[:, 2] + x[:, 4] / 2) + pad[1]
  459. if self.augment:
  460. # Augment imagespace
  461. if not self.mosaic:
  462. img, labels = random_perspective(img, labels,
  463. degrees=hyp['degrees'],
  464. translate=hyp['translate'],
  465. scale=hyp['scale'],
  466. shear=hyp['shear'],
  467. perspective=hyp['perspective'])
  468. # Augment colorspace
  469. augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
  470. # Apply cutouts
  471. # if random.random() < 0.9:
  472. # labels = cutout(img, labels)
  473. nL = len(labels) # number of labels
  474. if nL:
  475. labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) # convert xyxy to xywh
  476. labels[:, [2, 4]] /= img.shape[0] # normalized height 0-1
  477. labels[:, [1, 3]] /= img.shape[1] # normalized width 0-1
  478. if self.augment:
  479. # flip up-down
  480. if random.random() < hyp['flipud']:
  481. img = np.flipud(img)
  482. if nL:
  483. labels[:, 2] = 1 - labels[:, 2]
  484. # flip left-right
  485. if random.random() < hyp['fliplr']:
  486. img = np.fliplr(img)
  487. if nL:
  488. labels[:, 1] = 1 - labels[:, 1]
  489. labels_out = torch.zeros((nL, 6))
  490. if nL:
  491. labels_out[:, 1:] = torch.from_numpy(labels)
  492. # Convert
  493. img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  494. img = np.ascontiguousarray(img)
  495. return torch.from_numpy(img), labels_out, self.img_files[index], shapes
  496. @staticmethod
  497. def collate_fn(batch):
  498. img, label, path, shapes = zip(*batch) # transposed
  499. for i, l in enumerate(label):
  500. l[:, 0] = i # add target image index for build_targets()
  501. return torch.stack(img, 0), torch.cat(label, 0), path, shapes
  502. # Ancillary functions --------------------------------------------------------------------------------------------------
  503. def load_image(self, index):
  504. # loads 1 image from dataset, returns img, original hw, resized hw
  505. img = self.imgs[index]
  506. if img is None: # not cached
  507. path = self.img_files[index]
  508. img = cv2.imread(path) # BGR
  509. assert img is not None, 'Image Not Found ' + path
  510. h0, w0 = img.shape[:2] # orig hw
  511. r = self.img_size / max(h0, w0) # resize image to img_size
  512. if r != 1: # always resize down, only resize up if training with augmentation
  513. interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
  514. img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
  515. return img, (h0, w0), img.shape[:2] # img, hw_original, hw_resized
  516. else:
  517. return self.imgs[index], self.img_hw0[index], self.img_hw[index] # img, hw_original, hw_resized
  518. def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
  519. r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains
  520. hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
  521. dtype = img.dtype # uint8
  522. x = np.arange(0, 256, dtype=np.int16)
  523. lut_hue = ((x * r[0]) % 180).astype(dtype)
  524. lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
  525. lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
  526. img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
  527. cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed
  528. # Histogram equalization
  529. # if random.random() < 0.2:
  530. # for i in range(3):
  531. # img[:, :, i] = cv2.equalizeHist(img[:, :, i])
def load_mosaic(self, index):
    """Build a 4-image mosaic around sample ``index`` and augment it.

    Sample ``index`` and three others picked with ``random.randint`` are loaded
    with ``load_image``. They are pasted as the top-left, top-right,
    bottom-left and bottom-right tiles of a ``(2*s, 2*s)`` canvas filled with
    114, where ``s = self.img_size``. The tiles meet at a random centre drawn
    from ``[-b, 2*s + b]`` for each ``b`` in ``self.mosaic_border``.

    Labels (normalised ``[cls, x, y, w, h]`` rows) are converted to pixel
    ``[cls, x1, y1, x2, y2]`` in canvas coordinates and clipped to the canvas.
    Canvas and labels then go through ``random_perspective`` with
    ``border=self.mosaic_border``; with the default border of ``-img_size // 2``
    this brings the output back to about ``img_size`` per side.

    Returns:
        (img4, labels4): the augmented mosaic and its pixel xyxy labels.
    """
    # loads images in a mosaic

    labels4 = []
    s = self.img_size
    # Both entries of mosaic_border are equal by default, so the y/x order here is immaterial.
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
    indices = [index] + [random.randint(0, len(self.labels) - 1) for _ in range(3)]  # 3 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4
        # (x1a, y1a, x2a, y2a): destination region on the canvas;
        # (x1b, y1b, x2b, y2b): source region of the tile, sized to match.
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            # max(xc, w) behaves like w: when xc > w the slice end is truncated to w by NumPy.
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, max(xc, w), min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        # Offset from tile coordinates to canvas coordinates
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        x = self.labels[index]
        labels = x.copy()
        if x.size > 0:  # Normalized xywh to pixel xyxy format
            labels[:, 1] = w * (x[:, 1] - x[:, 3] / 2) + padw
            labels[:, 2] = h * (x[:, 2] - x[:, 4] / 2) + padh
            labels[:, 3] = w * (x[:, 1] + x[:, 3] / 2) + padw
            labels[:, 4] = h * (x[:, 2] + x[:, 4] / 2) + padh
        labels4.append(labels)

    # Concat/clip labels
    if len(labels4):
        labels4 = np.concatenate(labels4, 0)
        np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:])  # use with random_perspective
        # img4, labels4 = replicate(img4, labels4)  # replicate

    # Augment
    img4, labels4 = random_perspective(img4, labels4,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img4, labels4
  581. def replicate(img, labels):
  582. # Replicate labels
  583. h, w = img.shape[:2]
  584. boxes = labels[:, 1:].astype(int)
  585. x1, y1, x2, y2 = boxes.T
  586. s = ((x2 - x1) + (y2 - y1)) / 2 # side length (pixels)
  587. for i in s.argsort()[:round(s.size * 0.5)]: # smallest indices
  588. x1b, y1b, x2b, y2b = boxes[i]
  589. bh, bw = y2b - y1b, x2b - x1b
  590. yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw)) # offset x, y
  591. x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
  592. img[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
  593. labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)
  594. return img, labels
  595. def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
  596. # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
  597. shape = img.shape[:2] # current shape [height, width]
  598. if isinstance(new_shape, int):
  599. new_shape = (new_shape, new_shape)
  600. # Scale ratio (new / old)
  601. r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
  602. if not scaleup: # only scale down, do not scale up (for better test mAP)
  603. r = min(r, 1.0)
  604. # Compute padding
  605. ratio = r, r # width, height ratios
  606. new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
  607. dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
  608. if auto: # minimum rectangle
  609. dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding
  610. elif scaleFill: # stretch
  611. dw, dh = 0.0, 0.0
  612. new_unpad = (new_shape[1], new_shape[0])
  613. ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
  614. dw /= 2 # divide padding into 2 sides
  615. dh /= 2
  616. if shape[::-1] != new_unpad: # resize
  617. img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
  618. top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
  619. left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
  620. img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
  621. return img, ratio, (dw, dh)
def random_perspective(img, targets=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0, border=(0, 0)):
    """Apply a random perspective/affine warp to ``img`` and its box labels.

    The transform is ``M = T @ S @ R @ P @ C``, applied right to left:
        C: move the image centre to the origin.
        P: random perspective terms in U(-perspective, perspective).
        R: rotation by U(-degrees, degrees) with isotropic scale U(1 - scale, 1 + scale).
        S: x and y shear, each U(-shear, shear) degrees.
        T: translation placing the centre at U(0.5 - translate, 0.5 + translate)
           of the output width/height.
    The sequence of ``random.uniform`` calls is fixed, so seeded runs are reproducible.

    Args:
        img: HxWxC image array.
        targets: (n, 5) array of ``[cls, x1, y1, x2, y2]`` pixel boxes, or an
            empty sequence.
        degrees, translate, scale, shear, perspective: augmentation ranges as above.
        border: (y, x) amounts added twice to the input height/width to give the
            output size. Negative values shrink (crop) the output.

    Returns:
        (img, targets): the warped image, and the transformed boxes clipped to
        the output and filtered by ``box_candidates``. The image is left
        untouched when M is exactly the identity and ``border`` is zero.
        Borders are filled with 114.
    """
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]

    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Center
    C = np.eye(3)
    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

    # Perspective
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined rotation matrix
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    # Exact float comparison: warp only when something actually changes
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    # Visualize
    # import matplotlib.pyplot as plt
    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
    # ax[0].imshow(img[:, :, ::-1])  # base
    # ax[1].imshow(img2[:, :, ::-1])  # warped

    # Transform label coordinates
    n = len(targets)
    if n:
        # warp points: all 4 corners of every box, in homogeneous coordinates
        xy = np.ones((n * 4, 3))
        xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
        xy = xy @ M.T  # transform
        if perspective:
            xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)  # rescale
        else:  # affine
            xy = xy[:, :2].reshape(n, 8)

        # create new boxes: axis-aligned bounds of the 4 warped corners
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

        # # apply angle-based reduction of bounding boxes
        # radians = a * math.pi / 180
        # reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
        # x = (xy[:, 2] + xy[:, 0]) / 2
        # y = (xy[:, 3] + xy[:, 1]) / 2
        # w = (xy[:, 2] - xy[:, 0]) * reduction
        # h = (xy[:, 3] - xy[:, 1]) * reduction
        # xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T

        # clip boxes
        xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
        xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)

        # filter candidates: original boxes (scaled by the random scale s only) vs warped boxes
        i = box_candidates(box1=targets[:, 1:5].T * s, box2=xy.T)
        targets = targets[i]
        targets[:, 1:5] = xy[i]

    return img, targets
  693. def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1): # box1(4,n), box2(4,n)
  694. # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
  695. w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
  696. w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
  697. ar = np.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16)) # aspect ratio
  698. return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + 1e-16) > area_thr) & (ar < ar_thr) # candidates
  699. def cutout(image, labels):
  700. # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
  701. h, w = image.shape[:2]
  702. def bbox_ioa(box1, box2):
  703. # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
  704. box2 = box2.transpose()
  705. # Get the coordinates of bounding boxes
  706. b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
  707. b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
  708. # Intersection area
  709. inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
  710. (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)
  711. # box2 area
  712. box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16
  713. # Intersection over box2 area
  714. return inter_area / box2_area
  715. # create random masks
  716. scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction
  717. for s in scales:
  718. mask_h = random.randint(1, int(h * s))
  719. mask_w = random.randint(1, int(w * s))
  720. # box
  721. xmin = max(0, random.randint(0, w) - mask_w // 2)
  722. ymin = max(0, random.randint(0, h) - mask_h // 2)
  723. xmax = min(w, xmin + mask_w)
  724. ymax = min(h, ymin + mask_h)
  725. # apply random color mask
  726. image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]
  727. # return unobscured labels
  728. if len(labels) and s > 0.03:
  729. box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
  730. ioa = bbox_ioa(box, labels[:, 1:5]) # intersection over area
  731. labels = labels[ioa < 0.60] # remove >60% obscured labels
  732. return labels
  733. def reduce_img_size(path='path/images', img_size=1024): # from utils.datasets import *; reduce_img_size()
  734. # creates a new ./images_reduced folder with reduced size images of maximum size img_size
  735. path_new = path + '_reduced' # reduced images path
  736. create_folder(path_new)
  737. for f in tqdm(glob.glob('%s/*.*' % path)):
  738. try:
  739. img = cv2.imread(f)
  740. h, w = img.shape[:2]
  741. r = img_size / max(h, w) # size ratio
  742. if r < 1.0:
  743. img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_AREA) # _LINEAR fastest
  744. fnew = f.replace(path, path_new) # .replace(Path(f).suffix, '.jpg')
  745. cv2.imwrite(fnew, img)
  746. except:
  747. print('WARNING: image failure %s' % f)
  748. def recursive_dataset2bmp(dataset='path/dataset_bmp'): # from utils.datasets import *; recursive_dataset2bmp()
  749. # Converts dataset to bmp (for faster training)
  750. formats = [x.lower() for x in img_formats] + [x.upper() for x in img_formats]
  751. for a, b, files in os.walk(dataset):
  752. for file in tqdm(files, desc=a):
  753. p = a + '/' + file
  754. s = Path(file).suffix
  755. if s == '.txt': # replace text
  756. with open(p, 'r') as f:
  757. lines = f.read()
  758. for f in formats:
  759. lines = lines.replace(f, '.bmp')
  760. with open(p, 'w') as f:
  761. f.write(lines)
  762. elif s in formats: # replace image
  763. cv2.imwrite(p.replace(s, '.bmp'), cv2.imread(p))
  764. if s != '.bmp':
  765. os.system("rm '%s'" % p)
  766. def imagelist2folder(path='path/images.txt'): # from utils.datasets import *; imagelist2folder()
  767. # Copies all the images in a text file (list of images) into a folder
  768. create_folder(path[:-4])
  769. with open(path, 'r') as f:
  770. for line in f.read().splitlines():
  771. os.system('cp "%s" %s' % (line, path[:-4]))
  772. print(line)
  773. def create_folder(path='./new'):
  774. # Create folder
  775. if os.path.exists(path):
  776. shutil.rmtree(path) # delete output folder
  777. os.makedirs(path) # make new output folder