
datasets.py

import glob
import math
import os
import random
import shutil
import time
from pathlib import Path
from threading import Thread

import cv2
import numpy as np
import torch
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm

from utils.utils import xyxy2xywh, xywh2xyxy

help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.dng']
vid_formats = ['.mov', '.avi', '.mp4', '.mpg', '.mpeg', '.m4v', '.wmv', '.mkv']

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break


def get_hash(files):
    # Returns a single hash value of a list of files
    return sum(os.path.getsize(f) for f in files if os.path.isfile(f))


def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    try:
        rotation = dict(img._getexif().items())[orientation]
        if rotation == 6:  # rotation 270
            s = (s[1], s[0])
        elif rotation == 8:  # rotation 90
            s = (s[1], s[0])
    except Exception:  # image has no EXIF data
        pass

    return s


def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False):
    dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                  augment=augment,  # augment images
                                  hyp=hyp,  # augmentation hyperparameters
                                  rect=rect,  # rectangular training
                                  cache_images=cache,
                                  single_cls=opt.single_cls,
                                  stride=int(stride),
                                  pad=pad)

    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             num_workers=nw,
                                             pin_memory=True,
                                             collate_fn=LoadImagesAndLabels.collate_fn)
    return dataloader, dataset
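

# Example usage of create_dataloader (a minimal sketch, not taken from train.py; assumes a
# COCO128-style image directory and an `opt` namespace exposing the `single_cls` attribute read
# above; the hyp values are illustrative placeholders):
#   from argparse import Namespace
#   opt = Namespace(single_cls=False)
#   hyp = {'degrees': 0.0, 'translate': 0.0, 'scale': 0.5, 'shear': 0.0,
#          'hsv_h': 0.015, 'hsv_s': 0.7, 'hsv_v': 0.4}
#   dataloader, dataset = create_dataloader('../coco128/images/train2017', 640, 16, 32, opt,
#                                           hyp=hyp, augment=True)
#   for imgs, targets, paths, shapes in dataloader:
#       ...  # imgs: uint8 BxCxHxW tensor; targets: Nx6 rows of [batch_index, class, x, y, w, h]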


class LoadImages:  # for inference
    def __init__(self, path, img_size=640):
        p = str(Path(path))  # os-agnostic
        p = os.path.abspath(p)  # absolute path
        if '*' in p:
            files = sorted(glob.glob(p))  # glob
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
        elif os.path.isfile(p):
            files = [p]  # files
        else:
            raise Exception('ERROR: %s does not exist' % p)

        images = [x for x in files if os.path.splitext(x)[-1].lower() in img_formats]
        videos = [x for x in files if os.path.splitext(x)[-1].lower() in vid_formats]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'images'
        if any(videos):
            self.new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, 'No images or videos found in %s. Supported formats are:\nimages: %s\nvideos: %s' % \
                            (p, img_formats, vid_formats)

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            if not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()

            self.frame += 1
            print('video %g/%g (%g/%g) %s: ' % (self.count + 1, self.nf, self.frame, self.nframes, path), end='')

        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path
            print('image %g/%g %s: ' % (self.count, self.nf, path), end='')

        # Padded resize
        img = letterbox(img0, new_shape=self.img_size)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        # cv2.imwrite(path + '.letterbox.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1])  # save letterbox image
        return path, img, img0, self.cap

    def new_video(self, path):
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nf  # number of files
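

# Example usage of LoadImages (a minimal sketch; 'inference/images' is a hypothetical source
# directory of images and/or videos):
#   for path, img, img0, vid_cap in LoadImages('inference/images', img_size=640):
#       ...  # img: letterboxed CHW RGB uint8 array for the model; img0: original BGR frame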


class LoadWebcam:  # for inference
    def __init__(self, pipe=0, img_size=640):
        self.img_size = img_size

        if pipe == '0':
            pipe = 0  # local camera
        # pipe = 'rtsp://192.168.1.64/1'  # IP camera
        # pipe = 'rtsp://username:password@192.168.1.64/1'  # IP camera with login
        # pipe = 'rtsp://170.93.143.139/rtplive/470011e600ef003a004ee33696235daa'  # IP traffic camera
        # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg'  # IP golf camera

        # https://answers.opencv.org/question/215996/changing-gstreamer-pipeline-to-opencv-in-pythonsolved/
        # pipe = '"rtspsrc location="rtsp://username:password@192.168.1.64/1" latency=10 ! appsink'  # GStreamer

        # https://answers.opencv.org/question/200787/video-acceleration-gstremer-pipeline-in-videocapture/
        # https://stackoverflow.com/questions/54095699/install-gstreamer-support-for-opencv-python-package  # install help
        # pipe = "rtspsrc location=rtsp://root:root@192.168.0.91:554/axis-media/media.amp?videocodec=h264&resolution=3840x2160 protocols=GST_RTSP_LOWER_TRANS_TCP ! rtph264depay ! queue ! vaapih264dec ! videoconvert ! appsink"  # GStreamer

        self.pipe = pipe
        self.cap = cv2.VideoCapture(pipe)  # video capture object
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)  # set buffer size

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if cv2.waitKey(1) == ord('q'):  # q to quit
            self.cap.release()
            cv2.destroyAllWindows()
            raise StopIteration

        # Read frame
        if self.pipe == 0:  # local camera
            ret_val, img0 = self.cap.read()
            img0 = cv2.flip(img0, 1)  # flip left-right
        else:  # IP camera
            n = 0
            while True:
                n += 1
                self.cap.grab()
                if n % 30 == 0:  # skip frames
                    ret_val, img0 = self.cap.retrieve()
                    if ret_val:
                        break

        # Print
        assert ret_val, 'Camera Error %s' % self.pipe
        img_path = 'webcam.jpg'
        print('webcam %g: ' % self.count, end='')

        # Padded resize
        img = letterbox(img0, new_shape=self.img_size)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        return img_path, img, img0, None

    def __len__(self):
        return 0


class LoadStreams:  # multiple IP or RTSP cameras
    def __init__(self, sources='streams.txt', img_size=640):
        self.mode = 'images'
        self.img_size = img_size

        if os.path.isfile(sources):
            with open(sources, 'r') as f:
                sources = [x.strip() for x in f.read().splitlines() if len(x.strip())]
        else:
            sources = [sources]

        n = len(sources)
        self.imgs = [None] * n
        self.sources = sources
        for i, s in enumerate(sources):
            # Start the thread to read frames from the video stream
            print('%g/%g: %s... ' % (i + 1, n, s), end='')
            cap = cv2.VideoCapture(0 if s == '0' else s)
            assert cap.isOpened(), 'Failed to open %s' % s
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS) % 100
            _, self.imgs[i] = cap.read()  # guarantee first frame
            thread = Thread(target=self.update, args=([i, cap]), daemon=True)
            print(' success (%gx%g at %.2f FPS).' % (w, h, fps))
            thread.start()
        print('')  # newline

        # check for common shapes
        s = np.stack([letterbox(x, new_shape=self.img_size)[0].shape for x in self.imgs], 0)  # inference shapes
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        if not self.rect:
            print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')

    def update(self, index, cap):
        # Read next stream frame in a daemon thread
        n = 0
        while cap.isOpened():
            n += 1
            # _, self.imgs[index] = cap.read()
            cap.grab()
            if n == 4:  # read every 4th frame
                _, self.imgs[index] = cap.retrieve()
                n = 0
            time.sleep(0.01)  # wait time

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        img0 = self.imgs.copy()
        if cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        # Letterbox
        img = [letterbox(x, new_shape=self.img_size, auto=self.rect)[0] for x in img0]

        # Stack
        img = np.stack(img, 0)

        # Convert
        img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, to bsx3x416x416
        img = np.ascontiguousarray(img)

        return self.sources, img, img0, None

    def __len__(self):
        return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years
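

# Example usage of LoadStreams (a minimal sketch; 'streams.txt' lists one source per line,
# e.g. rtsp:// or http:// URLs, or '0' for a local camera):
#   for sources, img, img0, _ in LoadStreams('streams.txt', img_size=640):
#       ...  # img: Bx3xHxW batch (one letterboxed frame per stream); img0: list of raw BGR frames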


class LoadImagesAndLabels(Dataset):  # for training/testing
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0):
        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = str(Path(p))  # os-agnostic
                parent = str(Path(p).parent) + os.sep
                if os.path.isfile(p):  # file
                    with open(p, 'r') as t:
                        t = t.read().splitlines()
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                elif os.path.isdir(p):  # folder
                    f += glob.iglob(p + os.sep + '*.*')
                else:
                    raise Exception('%s does not exist' % p)
            self.img_files = [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats]
        except Exception as e:
            raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))

        n = len(self.img_files)
        assert n > 0, 'No images found in %s. See %s' % (path, help_url)
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches

        self.n = n  # number of images
        self.batch = bi  # batch index of image
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride

        # Define labels
        self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt') for x in
                            self.img_files]

        # Check cache
        cache_path = str(Path(self.label_files[0]).parent) + '.cache'  # cached labels
        if os.path.isfile(cache_path):
            cache = torch.load(cache_path)  # load
            if cache['hash'] != get_hash(self.label_files + self.img_files):  # dataset changed
                cache = self.cache_labels(cache_path)  # re-cache
        else:
            cache = self.cache_labels(cache_path)  # cache

        # Get labels
        labels, shapes = zip(*[cache[x] for x in self.img_files])
        self.shapes = np.array(shapes, dtype=np.float64)
        self.labels = list(labels)

        # Rectangular Training  https://github.com/ultralytics/yolov3/issues/232
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

        # Cache labels
        create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False
        nm, nf, ne, ns, nd = 0, 0, 0, 0, 0  # number missing, found, empty, datasubset, duplicate
        pbar = tqdm(self.label_files)
        for i, file in enumerate(pbar):
            l = self.labels[i]  # label
            if l.shape[0]:
                assert l.shape[1] == 5, '> 5 label columns: %s' % file
                assert (l >= 0).all(), 'negative labels: %s' % file
                assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels: %s' % file
                if np.unique(l, axis=0).shape[0] < l.shape[0]:  # duplicate rows
                    nd += 1  # print('WARNING: duplicate rows in %s' % self.label_files[i])  # duplicate rows
                if single_cls:
                    l[:, 0] = 0  # force dataset into single-class mode
                self.labels[i] = l
                nf += 1  # file found

                # Create subdataset (a smaller dataset)
                if create_datasubset and ns < 1E4:
                    if ns == 0:
                        create_folder(path='./datasubset')
                        os.makedirs('./datasubset/images')
                    exclude_classes = 43
                    if exclude_classes not in l[:, 0]:
                        ns += 1
                        # shutil.copy(src=self.img_files[i], dst='./datasubset/images/')  # copy image
                        with open('./datasubset/images.txt', 'a') as f:
                            f.write(self.img_files[i] + '\n')

                # Extract object detection boxes for a second stage classifier
                if extract_bounding_boxes:
                    p = Path(self.img_files[i])
                    img = cv2.imread(str(p))
                    h, w = img.shape[:2]
                    for j, x in enumerate(l):
                        f = '%s%sclassifier%s%g_%g_%s' % (p.parent.parent, os.sep, os.sep, x[0], j, p.name)
                        if not os.path.exists(Path(f).parent):
                            os.makedirs(Path(f).parent)  # make new output folder

                        b = x[1:] * [w, h, w, h]  # box
                        b[2:] = b[2:].max()  # rectangle to square
                        b[2:] = b[2:] * 1.3 + 30  # pad
                        b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)

                        b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                        b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                        assert cv2.imwrite(f, img[b[1]:b[3], b[0]:b[2]]), 'Failure extracting classifier boxes'
            else:
                ne += 1  # print('empty labels for image %s' % self.img_files[i])  # file empty
                # os.system("rm '%s' '%s'" % (self.img_files[i], self.label_files[i]))  # remove

            pbar.desc = 'Scanning labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
                cache_path, nf, nm, ne, nd, n)
        if nf == 0:
            s = 'WARNING: No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
            print(s)
            assert not augment, '%s. Can not train without labels.' % s

        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        self.imgs = [None] * n
        if cache_images:
            gb = 0  # Gigabytes of cached images
            pbar = tqdm(range(len(self.img_files)), desc='Caching images')
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            for i in pbar:  # max 10k images
                self.imgs[i], self.img_hw0[i], self.img_hw[i] = load_image(self, i)  # img, hw_original, hw_resized
                gb += self.imgs[i].nbytes
                pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9)

    def cache_labels(self, path='labels.cache'):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
        for (img, label) in pbar:
            try:
                l = []
                image = Image.open(img)
                image.verify()  # PIL verify
                # _ = io.imread(img)  # skimage verify (from skimage import io)
                shape = exif_size(image)  # image size
                assert (shape[0] > 9) & (shape[1] > 9), 'image size <10 pixels'
                if os.path.isfile(label):
                    with open(label, 'r') as f:
                        l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)  # labels
                if len(l) == 0:
                    l = np.zeros((0, 5), dtype=np.float32)
                x[img] = [l, shape]
            except Exception as e:
                x[img] = None
                print('WARNING: %s: %s' % (img, e))

        x['hash'] = get_hash(self.label_files + self.img_files)
        torch.save(x, path)  # save for next time
        return x
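
    # Cache layout (for reference): the saved object is a plain dict keyed by image path,
    #   cache[img_path] = [labels (nx5 float32: class, x_center, y_center, w, h, normalized 0-1),
    #                      (width, height)]
    # plus cache['hash'], the summed file sizes that __init__ compares to detect dataset changes.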

    def __len__(self):
        return len(self.img_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        if self.image_weights:
            index = self.indices[index]

        hyp = self.hyp
        if self.mosaic:
            # Load mosaic
            img, labels = load_mosaic(self, index)
            shapes = None

            # MixUp https://arxiv.org/pdf/1710.09412.pdf
            # if random.random() < 0.5:
            #     img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
            #     r = np.random.beta(0.3, 0.3)  # mixup ratio, alpha=beta=0.3
            #     img = (img * r + img2 * (1 - r)).astype(np.uint8)
            #     labels = np.concatenate((labels, labels2), 0)

        else:
            # Load image
            img, (h0, w0), (h, w) = load_image(self, index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            # Load labels
            labels = []
            x = self.labels[index]
            if x.size > 0:
                # Normalized xywh to pixel xyxy format
                labels = x.copy()
                labels[:, 1] = ratio[0] * w * (x[:, 1] - x[:, 3] / 2) + pad[0]  # pad width
                labels[:, 2] = ratio[1] * h * (x[:, 2] - x[:, 4] / 2) + pad[1]  # pad height
                labels[:, 3] = ratio[0] * w * (x[:, 1] + x[:, 3] / 2) + pad[0]
                labels[:, 4] = ratio[1] * h * (x[:, 2] + x[:, 4] / 2) + pad[1]

        if self.augment:
            # Augment imagespace
            if not self.mosaic:
                img, labels = random_affine(img, labels,
                                            degrees=hyp['degrees'],
                                            translate=hyp['translate'],
                                            scale=hyp['scale'],
                                            shear=hyp['shear'])

            # Augment colorspace
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Apply cutouts
            # if random.random() < 0.9:
            #     labels = cutout(img, labels)

        nL = len(labels)  # number of labels
        if nL:
            # convert xyxy to xywh
            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])

            # Normalize coordinates 0 - 1
            labels[:, [2, 4]] /= img.shape[0]  # height
            labels[:, [1, 3]] /= img.shape[1]  # width

        if self.augment:
            # random left-right flip
            lr_flip = True
            if lr_flip and random.random() < 0.5:
                img = np.fliplr(img)
                if nL:
                    labels[:, 1] = 1 - labels[:, 1]

            # random up-down flip
            ud_flip = False
            if ud_flip and random.random() < 0.5:
                img = np.flipud(img)
                if nL:
                    labels[:, 2] = 1 - labels[:, 2]

        labels_out = torch.zeros((nL, 6))
        if nL:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.img_files[index], shapes

    @staticmethod
    def collate_fn(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes
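
    # collate_fn example (illustrative): a batch of two images with 1 and 2 labels returns a 3x6
    # label tensor such as
    #   [[0, cls, x, y, w, h],
    #    [1, cls, x, y, w, h],
    #    [1, cls, x, y, w, h]]
    # where column 0 is the within-batch image index consumed by build_targets().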


def load_image(self, index):
    # loads 1 image from dataset, returns img, original hw, resized hw
    img = self.imgs[index]
    if img is None:  # not cached
        path = self.img_files[index]
        img = cv2.imread(path)  # BGR
        assert img is not None, 'Image Not Found ' + path
        h0, w0 = img.shape[:2]  # orig hw
        r = self.img_size / max(h0, w0)  # resize image to img_size
        if r != 1:  # always resize down, only resize up if training with augmentation
            interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
        return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
    else:
        return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized


def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
    r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
    hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    dtype = img.dtype  # uint8

    x = np.arange(0, 256, dtype=np.int16)
    lut_hue = ((x * r[0]) % 180).astype(dtype)
    lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
    lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
    cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed

    # Histogram equalization
    # if random.random() < 0.2:
    #     for i in range(3):
    #         img[:, :, i] = cv2.equalizeHist(img[:, :, i])
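

# Worked example for augment_hsv (illustrative numbers): if the sampled value gain is r[2] = 1.2,
# lut_val maps v -> clip(1.2 * v, 0, 255), so V=100 becomes 120 and V=240 saturates at 255; hue is
# taken modulo 180 because OpenCV stores uint8 hue in [0, 180).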


def load_mosaic(self, index):
    # loads images in a mosaic

    labels4 = []
    s = self.img_size
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
    indices = [index] + [random.randint(0, len(self.labels) - 1) for _ in range(3)]  # 3 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, max(xc, w), min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        x = self.labels[index]
        labels = x.copy()
        if x.size > 0:  # Normalized xywh to pixel xyxy format
            labels[:, 1] = w * (x[:, 1] - x[:, 3] / 2) + padw
            labels[:, 2] = h * (x[:, 2] - x[:, 4] / 2) + padh
            labels[:, 3] = w * (x[:, 1] + x[:, 3] / 2) + padw
            labels[:, 4] = h * (x[:, 2] + x[:, 4] / 2) + padh
        labels4.append(labels)

    # Concat/clip labels
    if len(labels4):
        labels4 = np.concatenate(labels4, 0)
        # np.clip(labels4[:, 1:] - s / 2, 0, s, out=labels4[:, 1:])  # use with center crop
        np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:])  # use with random_affine

    # Replicate
    # img4, labels4 = replicate(img4, labels4)

    # Augment
    # img4 = img4[s // 2: int(s * 1.5), s // 2:int(s * 1.5)]  # center crop (WARNING, requires box pruning)
    img4, labels4 = random_affine(img4, labels4,
                                  degrees=self.hyp['degrees'],
                                  translate=self.hyp['translate'],
                                  scale=self.hyp['scale'],
                                  shear=self.hyp['shear'],
                                  border=self.mosaic_border)  # border to remove

    return img4, labels4
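

# Mosaic geometry (for reference): img4 is a 2s x 2s canvas split at the random center (xc, yc);
# each source image is cropped to its quadrant, and padw/padh record the offset from source to
# canvas coordinates, so labels are remapped by a simple addition before the final random_affine.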


def replicate(img, labels):
    # Replicate labels
    h, w = img.shape[:2]
    boxes = labels[:, 1:].astype(int)
    x1, y1, x2, y2 = boxes.T
    s = ((x2 - x1) + (y2 - y1)) / 2  # side length (pixels)
    for i in s.argsort()[:round(s.size * 0.5)]:  # smallest indices
        x1b, y1b, x2b, y2b = boxes[i]
        bh, bw = y2b - y1b, x2b - x1b
        yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw))  # offset x, y
        x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
        img[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)

    return img, labels


def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
    # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, 64), np.mod(dh, 64)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
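

# Worked example for letterbox (illustrative numbers): a 720x1280 (HxW) frame with new_shape=640
# and auto=True gives r = 0.5, new_unpad = (640, 360), dh = 280 -> mod 64 -> 24, split as 12 px on
# top and bottom; the function returns ratio = (0.5, 0.5) and padding (dw, dh) = (0.0, 12.0),
# i.e. a 384x640 (HxW) output.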


def random_affine(img, targets=(), degrees=10, translate=.1, scale=.1, shear=10, border=(0, 0)):
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # https://medium.com/uruvideo/dataset-augmentation-with-random-homographies-a8f4b44830d4
    # targets = [cls, xyxy]

    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(img.shape[1] / 2, img.shape[0] / 2), scale=s)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(-translate, translate) * img.shape[1] + border[1]  # x translation (pixels)
    T[1, 2] = random.uniform(-translate, translate) * img.shape[0] + border[0]  # y translation (pixels)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Combined rotation matrix
    M = S @ T @ R  # ORDER IS IMPORTANT HERE!!
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        img = cv2.warpAffine(img, M[:2], dsize=(width, height), flags=cv2.INTER_LINEAR, borderValue=(114, 114, 114))

    # Transform label coordinates
    n = len(targets)
    if n:
        # warp points
        xy = np.ones((n * 4, 3))
        xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
        xy = (xy @ M.T)[:, :2].reshape(n, 8)

        # create new boxes
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

        # # apply angle-based reduction of bounding boxes
        # radians = a * math.pi / 180
        # reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
        # x = (xy[:, 2] + xy[:, 0]) / 2
        # y = (xy[:, 3] + xy[:, 1]) / 2
        # w = (xy[:, 2] - xy[:, 0]) * reduction
        # h = (xy[:, 3] - xy[:, 1]) * reduction
        # xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T

        # reject warped points outside of image
        xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
        xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
        w = xy[:, 2] - xy[:, 0]
        h = xy[:, 3] - xy[:, 1]
        area = w * h
        area0 = (targets[:, 3] - targets[:, 1]) * (targets[:, 4] - targets[:, 2])
        ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16))  # aspect ratio
        i = (w > 2) & (h > 2) & (area / (area0 * s + 1e-16) > 0.2) & (ar < 20)

        targets = targets[i]
        targets[:, 1:5] = xy[i]

    return img, targets
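

# Note on random_affine (for reference): M = S @ T @ R composes rotation/scale, then translation,
# then shear; boxes are warped by transforming all four corners and taking axis-aligned extents,
# and the final mask keeps boxes that stay over 2 px per side, retain >20% of their scaled area,
# and have aspect ratio under 20.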


def cutout(image, labels):
    # https://arxiv.org/abs/1708.04552
    # https://github.com/hysts/pytorch_cutout/blob/master/dataloader.py
    # https://towardsdatascience.com/when-conventional-wisdom-fails-revisiting-data-augmentation-for-self-driving-cars-4831998c5509
    h, w = image.shape[:2]

    def bbox_ioa(box1, box2):
        # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
        box2 = box2.transpose()

        # Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]

        # Intersection area
        inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
                     (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)

        # box2 area
        box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16

        # Intersection over box2 area
        return inter_area / box2_area

    # create random masks
    scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
    for s in scales:
        mask_h = random.randint(1, int(h * s))
        mask_w = random.randint(1, int(w * s))

        # box
        xmin = max(0, random.randint(0, w) - mask_w // 2)
        ymin = max(0, random.randint(0, h) - mask_h // 2)
        xmax = min(w, xmin + mask_w)
        ymax = min(h, ymin + mask_h)

        # apply random color mask
        image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]

        # return unobscured labels
        if len(labels) and s > 0.03:
            box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            labels = labels[ioa < 0.60]  # remove >60% obscured labels

    return labels


def reduce_img_size(path='../data/sm4/images', img_size=1024):  # from utils.datasets import *; reduce_img_size()
    # creates a new ./images_reduced folder with reduced size images of maximum size img_size
    path_new = path + '_reduced'  # reduced images path
    create_folder(path_new)
    for f in tqdm(glob.glob('%s/*.*' % path)):
        try:
            img = cv2.imread(f)
            h, w = img.shape[:2]
            r = img_size / max(h, w)  # size ratio
            if r < 1.0:
                img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_AREA)  # _LINEAR fastest
            fnew = f.replace(path, path_new)  # .replace(Path(f).suffix, '.jpg')
            cv2.imwrite(fnew, img)
        except Exception:
            print('WARNING: image failure %s' % f)


def convert_images2bmp():  # from utils.datasets import *; convert_images2bmp()
    # Save images
    formats = [x.lower() for x in img_formats] + [x.upper() for x in img_formats]
    # for path in ['../coco/images/val2014', '../coco/images/train2014']:
    for path in ['../data/sm4/images', '../data/sm4/background']:
        create_folder(path + 'bmp')
        for ext in formats:  # ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.dng']
            for f in tqdm(glob.glob('%s/*%s' % (path, ext)), desc='Converting %s' % ext):
                cv2.imwrite(f.replace(ext.lower(), '.bmp').replace(path, path + 'bmp'), cv2.imread(f))

    # Save labels
    # for path in ['../coco/trainvalno5k.txt', '../coco/5k.txt']:
    for file in ['../data/sm4/out_train.txt', '../data/sm4/out_test.txt']:
        with open(file, 'r') as f:
            lines = f.read()
            # lines = f.read().replace('2014/', '2014bmp/')  # coco
        lines = lines.replace('/images', '/imagesbmp')
        lines = lines.replace('/background', '/backgroundbmp')
        for ext in formats:
            lines = lines.replace(ext, '.bmp')
        with open(file.replace('.txt', 'bmp.txt'), 'w') as f:
            f.write(lines)


def recursive_dataset2bmp(dataset='../data/sm4_bmp'):  # from utils.datasets import *; recursive_dataset2bmp()
    # Converts dataset to bmp (for faster training)
    formats = [x.lower() for x in img_formats] + [x.upper() for x in img_formats]
    for a, b, files in os.walk(dataset):
        for file in tqdm(files, desc=a):
            p = a + '/' + file
            s = Path(file).suffix
            if s == '.txt':  # replace text
                with open(p, 'r') as f:
                    lines = f.read()
                for f in formats:
                    lines = lines.replace(f, '.bmp')
                with open(p, 'w') as f:
                    f.write(lines)
            elif s in formats:  # replace image
                cv2.imwrite(p.replace(s, '.bmp'), cv2.imread(p))
                if s != '.bmp':
                    os.system("rm '%s'" % p)


def imagelist2folder(path='data/coco_64img.txt'):  # from utils.datasets import *; imagelist2folder()
    # Copies all the images in a text file (list of images) into a folder
    create_folder(path[:-4])
    with open(path, 'r') as f:
        for line in f.read().splitlines():
            os.system('cp "%s" %s' % (line, path[:-4]))
            print(line)


def create_folder(path='./new_folder'):
    # Create folder
    if os.path.exists(path):
        shutil.rmtree(path)  # delete output folder
    os.makedirs(path)  # make new output folder