# Dataset utils and dataloaders

import glob
import hashlib
import json
import logging
import math
import os
import random
import shutil
import time
from itertools import repeat
from multiprocessing.pool import ThreadPool, Pool
from pathlib import Path
from threading import Thread

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import yaml
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm

from utils.general import check_requirements, check_file, check_dataset, xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, \
    segment2box, segments2boxes, resample_segments, clean_str
from utils.torch_utils import torch_distributed_zero_first

# Parameters
help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']  # acceptable image suffixes
vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']  # acceptable video suffixes
num_threads = min(8, os.cpu_count())  # number of multiprocessing threads
logger = logging.getLogger(__name__)

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break


def get_hash(paths):
    # Returns a single hash value of a list of paths (files or dirs)
    size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))  # sizes
    h = hashlib.md5(str(size).encode())  # hash sizes
    h.update(''.join(paths).encode())  # hash paths
    return h.hexdigest()  # return hash

def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    try:
        rotation = dict(img._getexif().items())[orientation]
        if rotation == 6:  # rotation 270
            s = (s[1], s[0])
        elif rotation == 8:  # rotation 90
            s = (s[1], s[0])
    except Exception:  # image has no EXIF data
        pass
    return s

def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
                      rect=False, rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
    # Make sure only the first process in DDP scans the dataset first, so the other processes can use its cache
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,  # augmentation hyperparameters
                                      rect=rect,  # rectangular training
                                      cache_images=cache,
                                      single_cls=single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)

    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
    loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
    # Use torch.utils.data.DataLoader() if dataset properties will update during training, else InfiniteDataLoader()
    dataloader = loader(dataset,
                        batch_size=batch_size,
                        num_workers=nw,
                        sampler=sampler,
                        pin_memory=True,
                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
    return dataloader, dataset
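
# Usage sketch (illustrative only, not executed on import). Assumes a YOLOv5-style dataset directory and a
# hyperparameter dict `hyp` loaded from one of the hyp*.yaml files; paths and prefix are placeholders:
# dataloader, dataset = create_dataloader('data/images', imgsz=640, batch_size=16, stride=32,
#                                         hyp=hyp, augment=True, workers=8, prefix='train: ')
# for imgs, targets, paths, shapes in dataloader:
#     pass  # imgs: (bs, 3, h, w) uint8 tensor; targets: (n, 6) rows of (image_index, class, x, y, w, h)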

class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
    """ Dataloader that reuses workers

    Uses same syntax as vanilla DataLoader
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """ Sampler that repeats forever

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)
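
# Note: a vanilla DataLoader tears down and re-spawns its worker processes at the end of every epoch; wrapping
# the batch_sampler in _RepeatSampler keeps one iterator (and its workers) alive across epochs, while the
# __len__/__iter__ overrides above preserve ordinary one-epoch semantics for the caller.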

class LoadImages:  # for inference
    def __init__(self, path, img_size=640, stride=32):
        p = str(Path(path).absolute())  # os-agnostic absolute path
        if '*' in p:
            files = sorted(glob.glob(p, recursive=True))  # glob
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
        elif os.path.isfile(p):
            files = [p]  # files
        else:
            raise Exception(f'ERROR: {p} does not exist')

        images = [x for x in files if x.split('.')[-1].lower() in img_formats]
        videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        if any(videos):
            self.new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' \
                            f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}'

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            if not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()

            self.frame += 1
            print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: ', end='')

        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path
            print(f'image {self.count}/{self.nf} {path}: ', end='')

        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        img = np.ascontiguousarray(img)

        return path, img, img0, self.cap

    def new_video(self, path):
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nf  # number of files
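
# Usage sketch (illustrative only; 'data/images' is a placeholder path):
# for path, img, img0, cap in LoadImages('data/images', img_size=640):
#     pass  # img: letterboxed CHW RGB uint8 array ready for the model; img0: original BGR image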

class LoadWebcam:  # for inference
    def __init__(self, pipe='0', img_size=640, stride=32):
        self.img_size = img_size
        self.stride = stride

        if pipe.isnumeric():
            pipe = int(pipe)  # local camera index
        # pipe = 'rtsp://192.168.1.64/1'  # IP camera
        # pipe = 'rtsp://username:password@192.168.1.64/1'  # IP camera with login
        # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg'  # IP golf camera

        self.pipe = pipe
        self.cap = cv2.VideoCapture(pipe)  # video capture object
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)  # set buffer size

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if cv2.waitKey(1) == ord('q'):  # q to quit
            self.cap.release()
            cv2.destroyAllWindows()
            raise StopIteration

        # Read frame
        if self.pipe == 0:  # local camera
            ret_val, img0 = self.cap.read()
            img0 = cv2.flip(img0, 1)  # flip left-right
        else:  # IP camera
            n = 0
            while True:
                n += 1
                self.cap.grab()
                if n % 30 == 0:  # skip frames, retrieve every 30th
                    ret_val, img0 = self.cap.retrieve()
                    if ret_val:
                        break

        # Print
        assert ret_val, f'Camera Error {self.pipe}'
        img_path = 'webcam.jpg'
        print(f'webcam {self.count}: ', end='')

        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        img = np.ascontiguousarray(img)

        return img_path, img, img0, None

    def __len__(self):
        return 0

class LoadStreams:  # multiple IP or RTSP cameras
    def __init__(self, sources='streams.txt', img_size=640, stride=32):
        self.mode = 'stream'
        self.img_size = img_size
        self.stride = stride

        if os.path.isfile(sources):
            with open(sources, 'r') as f:
                sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
        else:
            sources = [sources]

        n = len(sources)
        self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
        self.sources = [clean_str(x) for x in sources]  # clean source names for later
        for i, s in enumerate(sources):  # index, source
            # Start thread to read frames from video stream
            print(f'{i + 1}/{n}: {s}... ', end='')
            if 'youtube.com/' in s or 'youtu.be/' in s:  # if source is YouTube video
                check_requirements(('pafy', 'youtube_dl'))
                import pafy
                s = pafy.new(s).getbest(preftype="mp4").url  # YouTube URL
            s = int(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
            cap = cv2.VideoCapture(s)
            assert cap.isOpened(), f'Failed to open {s}'
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            self.fps[i] = max(cap.get(cv2.CAP_PROP_FPS) % 100, 0) or 30.0  # 30 FPS fallback
            self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf')  # infinite stream fallback

            _, self.imgs[i] = cap.read()  # guarantee first frame
            self.threads[i] = Thread(target=self.update, args=([i, cap]), daemon=True)
            print(f" success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
            self.threads[i].start()
        print('')  # newline

        # check for common shapes
        s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0)  # shapes
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        if not self.rect:
            print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')

    def update(self, i, cap):
        # Read stream `i` frames in daemon thread
        n, f = 0, self.frames[i]  # frame number, frame limit
        while cap.isOpened() and n < f:
            n += 1
            # _, self.imgs[index] = cap.read()
            cap.grab()
            if n % 4 == 0:  # retrieve every 4th frame
                success, im = cap.retrieve()
                self.imgs[i] = im if success else self.imgs[i] * 0
            time.sleep(1 / self.fps[i])  # wait time

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        # Letterbox
        img0 = self.imgs.copy()
        img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]

        # Stack
        img = np.stack(img, 0)

        # Convert
        img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, BHWC to BCHW
        img = np.ascontiguousarray(img)

        return self.sources, img, img0, None

    def __len__(self):
        return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years
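
# Usage sketch (illustrative only; 'streams.txt' lists one source URL or camera index per line):
# for sources, img, img0, _ in LoadStreams('streams.txt', img_size=640):
#     pass  # img: (n_streams, 3, h, w) batch of letterboxed RGB frames; img0: list of raw BGR frames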

def img2label_paths(img_paths):
    # Define label paths as a function of image paths
    sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
    return ['txt'.join(x.replace(sa, sb, 1).rsplit(x.split('.')[-1], 1)) for x in img_paths]
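
# Example of the expected directory convention (paths are illustrative):
# img2label_paths(['coco/images/train2017/000000000009.jpg'])
# -> ['coco/labels/train2017/000000000009.txt']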

class LoadImagesAndLabels(Dataset):  # for training/testing
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride
        self.path = path

        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('**/*.*'))  # pathlib
                elif p.is_file():  # file
                    with open(p, 'r') as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                else:
                    raise Exception(f'{prefix}{p} does not exist')
            self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats])  # pathlib
            assert self.img_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')

        # Check cache
        self.label_files = img2label_paths(self.img_files)  # labels
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')  # cached labels
        if cache_path.is_file():
            cache, exists = torch.load(cache_path), True  # load
            if cache['hash'] != get_hash(self.label_files + self.img_files) or 'version' not in cache:  # changed
                cache, exists = self.cache_labels(cache_path, prefix), False  # re-cache
        else:
            cache, exists = self.cache_labels(cache_path, prefix), False  # cache

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
        if exists:
            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
            tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Cannot train without labels. See {help_url}'

        # Read cache
        cache.pop('hash')  # remove hash
        cache.pop('version')  # remove version
        labels, shapes, self.segments = zip(*cache.values())
        self.labels = list(labels)
        self.shapes = np.array(shapes, dtype=np.float64)
        self.img_files = list(cache.keys())  # update
        self.label_files = img2label_paths(cache.keys())  # update
        if single_cls:
            for x in self.labels:
                x[:, 0] = 0

        n = len(shapes)  # number of images
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n
        self.indices = range(n)

        # Rectangular Training
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        self.imgs = [None] * n
        if cache_images:
            gb = 0  # Gigabytes of cached images
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            results = ThreadPool(num_threads).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
            pbar = tqdm(enumerate(results), total=n)
            for i, x in pbar:
                self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)
                gb += self.imgs[i].nbytes
                pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
            pbar.close()

    def cache_labels(self, path=Path('./labels.cache'), prefix=''):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, corrupt
        desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
        with Pool(num_threads) as pool:
            pbar = tqdm(pool.imap_unordered(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))),
                        desc=desc, total=len(self.img_files))
            for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x[im_file] = [l, shape, segments]
                pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
        pbar.close()

        if nf == 0:
            logging.info(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
        x['hash'] = get_hash(self.label_files + self.img_files)
        x['results'] = nf, nm, ne, nc, len(self.img_files)
        x['version'] = 0.2  # cache version
        try:
            torch.save(x, path)  # save cache for next time
            logging.info(f'{prefix}New cache created: {path}')
        except Exception as e:
            logging.info(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}')  # path not writeable
        return x

    def __len__(self):
        return len(self.img_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     # self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        if mosaic:
            # Load mosaic
            img, labels = load_mosaic(self, index)
            shapes = None

            # MixUp https://arxiv.org/pdf/1710.09412.pdf
            if random.random() < hyp['mixup']:
                img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1))
                r = np.random.beta(32.0, 32.0)  # mixup ratio, alpha=beta=32.0
                img = (img * r + img2 * (1 - r)).astype(np.uint8)
                labels = np.concatenate((labels, labels2), 0)

        else:
            # Load image
            img, (h0, w0), (h, w) = load_image(self, index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

        if self.augment:
            # Augment imagespace
            if not mosaic:
                img, labels = random_perspective(img, labels,
                                                 degrees=hyp['degrees'],
                                                 translate=hyp['translate'],
                                                 scale=hyp['scale'],
                                                 shear=hyp['shear'],
                                                 perspective=hyp['perspective'])

            # Augment colorspace
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Apply cutouts
            # if random.random() < 0.9:
            #     labels = cutout(img, labels)

        nL = len(labels)  # number of labels
        if nL:
            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])  # convert xyxy to xywh
            labels[:, [2, 4]] /= img.shape[0]  # normalized height 0-1
            labels[:, [1, 3]] /= img.shape[1]  # normalized width 0-1

        if self.augment:
            # flip up-down
            if random.random() < hyp['flipud']:
                img = np.flipud(img)
                if nL:
                    labels[:, 2] = 1 - labels[:, 2]

            # flip left-right
            if random.random() < hyp['fliplr']:
                img = np.fliplr(img)
                if nL:
                    labels[:, 1] = 1 - labels[:, 1]

        labels_out = torch.zeros((nL, 6))
        if nL:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.img_files[index], shapes

    @staticmethod
    def collate_fn(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes

    @staticmethod
    def collate_fn4(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        n = len(shapes) // 4
        img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

        ho = torch.tensor([[0., 0, 0, 1, 0, 0]])
        wo = torch.tensor([[0., 0, 1, 0, 0, 0]])
        s = torch.tensor([[1, 1, .5, .5, .5, .5]])  # scale
        for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
            i *= 4
            if random.random() < 0.5:
                im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[
                    0].type(img[i].type())
                l = label[i]
            else:
                im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
                l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
            img4.append(im)
            label4.append(l)

        for i, l in enumerate(label4):
            l[:, 0] = i  # add target image index for build_targets()

        return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4

# Ancillary functions --------------------------------------------------------------------------------------------------
def load_image(self, index):
    # loads 1 image from dataset, returns img, original hw, resized hw
    img = self.imgs[index]
    if img is None:  # not cached
        path = self.img_files[index]
        img = cv2.imread(path)  # BGR
        assert img is not None, 'Image Not Found ' + path
        h0, w0 = img.shape[:2]  # orig hw
        r = self.img_size / max(h0, w0)  # ratio
        if r != 1:  # if sizes are not equal
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)),
                             interpolation=cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR)
        return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
    else:
        return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized

def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
    # HSV color-space augmentation, applied in place on a BGR uint8 image
    r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
    hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    dtype = img.dtype  # uint8

    x = np.arange(0, 256, dtype=r.dtype)
    lut_hue = ((x * r[0]) % 180).astype(dtype)
    lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
    lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
    cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed
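
# Usage sketch (illustrative; mutates `im` in place, gain values mirror the typical hyp*.yaml defaults):
# im = cv2.imread('bus.jpg')  # BGR uint8
# augment_hsv(im, hgain=0.015, sgain=0.7, vgain=0.4)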

def hist_equalize(img, clahe=True, bgr=False):
    # Equalize histogram on BGR image 'img' with img.shape(n,m,3) and range 0-255
    yuv = cv2.cvtColor(img, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
    if clahe:
        c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        yuv[:, :, 0] = c.apply(yuv[:, :, 0])
    else:
        yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0])  # equalize Y channel histogram
    return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB)  # convert YUV image to RGB

def load_mosaic(self, index):
    # loads images in a 4-mosaic

    labels4, segments4 = [], []
    s = self.img_size
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
    indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
        labels4.append(labels)
        segments4.extend(segments)

    # Concat/clip labels
    labels4 = np.concatenate(labels4, 0)
    for x in (labels4[:, 1:], *segments4):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img4, labels4 = replicate(img4, labels4)  # replicate

    # Augment
    img4, labels4 = random_perspective(img4, labels4, segments4,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img4, labels4
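
# Geometry note: with img_size s, the mosaic canvas is 2s x 2s and the shared corner (xc, yc) is sampled from
# [s/2, 3s/2] (since mosaic_border = [-s/2, -s/2]), so every tile keeps a substantial visible area;
# random_perspective() with border=mosaic_border then crops the canvas back to s x s.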

def load_mosaic9(self, index):
    # loads images in a 9-mosaic

    labels9, segments9 = [], []
    s = self.img_size
    indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img9
        if i == 0:  # center
            img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 9 tiles
            h0, w0 = h, w
            c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
        elif i == 1:  # top
            c = s, s - h, s + w, s
        elif i == 2:  # top right
            c = s + wp, s - h, s + wp + w, s
        elif i == 3:  # right
            c = s + w0, s, s + w0 + w, s + h
        elif i == 4:  # bottom right
            c = s + w0, s + hp, s + w0 + w, s + hp + h
        elif i == 5:  # bottom
            c = s + w0 - w, s + h0, s + w0, s + h0 + h
        elif i == 6:  # bottom left
            c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
        elif i == 7:  # left
            c = s - w, s + h0 - h, s, s + h0
        elif i == 8:  # top left
            c = s - w, s + h0 - hp - h, s, s + h0 - hp

        padx, pady = c[:2]
        x1, y1, x2, y2 = [max(x, 0) for x in c]  # allocate coords

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
        labels9.append(labels)
        segments9.extend(segments)

        # Image
        img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
        hp, wp = h, w  # height, width previous

    # Offset
    yc, xc = [int(random.uniform(0, s)) for _ in self.mosaic_border]  # mosaic center x, y
    img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

    # Concat/clip labels
    labels9 = np.concatenate(labels9, 0)
    labels9[:, [1, 3]] -= xc
    labels9[:, [2, 4]] -= yc
    c = np.array([xc, yc])  # centers
    segments9 = [x - c for x in segments9]

    for x in (labels9[:, 1:], *segments9):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img9, labels9 = replicate(img9, labels9)  # replicate

    # Augment
    img9, labels9 = random_perspective(img9, labels9, segments9,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img9, labels9

def replicate(img, labels):
    # Replicate labels
    h, w = img.shape[:2]
    boxes = labels[:, 1:].astype(int)
    x1, y1, x2, y2 = boxes.T
    s = ((x2 - x1) + (y2 - y1)) / 2  # side length (pixels)
    for i in s.argsort()[:round(s.size * 0.5)]:  # smallest indices
        x1b, y1b, x2b, y2b = boxes[i]
        bh, bw = y2b - y1b, x2b - x1b
        yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw))  # offset x, y
        x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
        img[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img[ymin:ymax, xmin:xmax]
        labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)

    return img, labels

def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
    # Resize and pad image while meeting stride-multiple constraints
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
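
# Worked example (illustrative): a 1080x810 (h x w) image letterboxed to 640 with auto=True and stride=32:
# r = 640/1080 ~= 0.593, new_unpad = (480, 640) (w, h), dw = 640 - 480 = 160 -> mod 32 = 0, dh = 0,
# so the output is 640x480 with no gray border; with auto=False the same image pads out to a full 640x640.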

def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0,
                       border=(0, 0)):
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]

    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Center
    C = np.eye(3)
    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

    # Perspective
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined transformation matrix
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    # Visualize
    # import matplotlib.pyplot as plt
    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
    # ax[0].imshow(img[:, :, ::-1])  # base
    # ax[1].imshow(img2[:, :, ::-1])  # warped

    # Transform label coordinates
    n = len(targets)
    if n:
        use_segments = any(x.any() for x in segments)
        new = np.zeros((n, 4))
        if use_segments:  # warp segments
            segments = resample_segments(segments)  # upsample
            for i, segment in enumerate(segments):
                xy = np.ones((len(segment), 3))
                xy[:, :2] = segment
                xy = xy @ M.T  # transform
                xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]  # perspective rescale or affine

                # clip
                new[i] = segment2box(xy, width, height)

        else:  # warp boxes
            xy = np.ones((n * 4, 3))
            xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
            xy = xy @ M.T  # transform
            xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine

            # create new boxes
            x = xy[:, [0, 2, 4, 6]]
            y = xy[:, [1, 3, 5, 7]]
            new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

            # clip
            new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
            new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)

        # filter candidates
        i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
        targets = targets[i]
        targets[:, 1:5] = new[i]

    return img, targets
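
# Usage sketch (illustrative): applying the warp to a bare image/label pair outside the mosaic path:
# im = cv2.imread('bus.jpg')
# targets = np.array([[0, 100., 150., 300., 400.]])  # rows of (cls, x1, y1, x2, y2) in pixels
# im, targets = random_perspective(im, targets, degrees=10, translate=0.1, scale=0.5, shear=10)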

def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16):  # box1(4,n), box2(4,n)
    # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)  # candidates
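
# Example (illustrative): a box shrunk by the warp from 100x100 to 20x20 keeps only 4% of its area, which is
# below the default area_thr=0.1, so it is filtered out even though it passes the wh_thr and ar_thr checks:
# box_candidates(box1=np.array([[0.], [0.], [100.], [100.]]), box2=np.array([[0.], [0.], [20.], [20.]]))
# -> array([False])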

def cutout(image, labels):
    # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
    h, w = image.shape[:2]

    def bbox_ioa(box1, box2):
        # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
        box2 = box2.transpose()

        # Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]

        # Intersection area
        inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
                     (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)

        # box2 area
        box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16

        # Intersection over box2 area
        return inter_area / box2_area

    # create random masks
    scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
    for s in scales:
        mask_h = random.randint(1, int(h * s))
        mask_w = random.randint(1, int(w * s))

        # box
        xmin = max(0, random.randint(0, w) - mask_w // 2)
        ymin = max(0, random.randint(0, h) - mask_h // 2)
        xmax = min(w, xmin + mask_w)
        ymax = min(h, ymin + mask_h)

        # apply random color mask
        image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]

        # return unobscured labels
        if len(labels) and s > 0.03:
            box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            labels = labels[ioa < 0.60]  # remove >60% obscured labels

    return labels
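
# Note: cutout mutates `image` in place and returns only the labels that survive the masking; it is disabled
# by default (see the commented-out call in LoadImagesAndLabels.__getitem__ above).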

def create_folder(path='./new'):
    # Create folder
    if os.path.exists(path):
        shutil.rmtree(path)  # delete output folder
    os.makedirs(path)  # make new output folder

def flatten_recursive(path='../coco128'):
    # Flatten a recursive directory by bringing all files to top level
    new_path = Path(path + '_flat')
    create_folder(new_path)
    for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
        shutil.copyfile(file, new_path / Path(file).name)

def extract_boxes(path='../coco128/'):  # from utils.datasets import *; extract_boxes('../coco128')
    # Convert detection dataset into classification dataset, with one directory per class

    path = Path(path)  # images dir
    shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None  # remove existing
    files = list(path.rglob('*.*'))
    n = len(files)  # number of files
    for im_file in tqdm(files, total=n):
        if im_file.suffix[1:] in img_formats:
            # image
            im = cv2.imread(str(im_file))[..., ::-1]  # BGR to RGB
            h, w = im.shape[:2]

            # labels
            lb_file = Path(img2label_paths([str(im_file)])[0])
            if Path(lb_file).exists():
                with open(lb_file, 'r') as f:
                    lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32)  # labels

                for j, x in enumerate(lb):
                    c = int(x[0])  # class
                    f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg'  # new filename
                    if not f.parent.is_dir():
                        f.parent.mkdir(parents=True)

                    b = x[1:] * [w, h, w, h]  # box
                    # b[2:] = b[2:].max()  # rectangle to square
                    b[2:] = b[2:] * 1.2 + 3  # pad
                    b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)

                    b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                    b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                    assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'

def autosplit(path='../coco128', weights=(0.9, 0.1, 0.0), annotated_only=False):
    """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
    Usage: from utils.datasets import *; autosplit('../coco128')
    Arguments
        path:            Path to images directory
        weights:         Train, val, test weights (list)
        annotated_only:  Only use images with an annotated txt file
    """
    path = Path(path)  # images dir
    files = sum([list(path.rglob(f"*.{img_ext}")) for img_ext in img_formats], [])  # image files only
    n = len(files)  # number of files
    indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split

    txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']  # 3 txt files
    [(path / x).unlink() for x in txt if (path / x).exists()]  # remove existing

    print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
    for i, img in tqdm(zip(indices, files), total=n):
        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
            with open(path / txt[i], 'a') as f:
                f.write(str(img) + '\n')  # add image to txt file

def verify_image_label(args):
    # Verify one image-label pair
    im_file, lb_file, prefix = args
    nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, corrupt
    try:
        # verify images
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
        assert im.format.lower() in img_formats, f'invalid image format {im.format}'
        if im.format.lower() in ('jpg', 'jpeg'):
            with open(im_file, 'rb') as f:
                f.seek(-2, 2)
                assert f.read() == b'\xff\xd9', 'corrupted JPEG'

        # verify labels
        segments = []  # instance segments
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file, 'r') as f:
                l = [x.split() for x in f.read().strip().splitlines() if len(x)]
                if any([len(x) > 8 for x in l]):  # is segment
                    classes = np.array([x[0] for x in l], dtype=np.float32)
                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]  # (cls, xy1...)
                    l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                l = np.array(l, dtype=np.float32)
            if len(l):
                assert l.shape[1] == 5, 'labels require 5 columns each'
                assert (l >= 0).all(), 'negative labels'
                assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
                assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
            else:
                ne = 1  # label empty
                l = np.zeros((0, 5), dtype=np.float32)
        else:
            nm = 1  # label missing
            l = np.zeros((0, 5), dtype=np.float32)
        return im_file, l, shape, segments, nm, nf, ne, nc
    except Exception as e:
        nc = 1
        logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
        return [None, None, None, None, nm, nf, ne, nc]

def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False):
    """ Return dataset statistics dictionary with images and instances counts per split per class
    Usage: from utils.datasets import *; dataset_stats('coco128.yaml', verbose=True)
    Arguments
        path:          Path to data.yaml
        autodownload:  Attempt to download dataset if not found locally
        verbose:       Print stats dictionary
    """
    with open(check_file(path)) as f:
        data = yaml.safe_load(f)  # data dict
    check_dataset(data, autodownload)  # download dataset if missing
    nc = data['nc']  # number of classes
    stats = {'nc': nc, 'names': data['names']}  # statistics dictionary
    for split in 'train', 'val', 'test':
        if split not in data:
            stats[split] = None  # i.e. no test set
            continue
        x = []
        dataset = LoadImagesAndLabels(data[split], augment=False, rect=True)  # load dataset
        if split == 'train':
            cache_path = Path(dataset.label_files[0]).parent.with_suffix('.cache')  # *.cache path
        for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
            x.append(np.bincount(label[:, 0].astype(int), minlength=nc))  # class counts per image
        x = np.array(x)  # shape(n, nc), e.g. (128, 80) for coco128
        stats[split] = {'instance_stats': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
                        'image_stats': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
                                        'per_class': (x > 0).sum(0).tolist()},
                        'labels': {str(Path(k).name): v.tolist() for k, v in zip(dataset.img_files, dataset.labels)}}

    # Save, print and return
    with open(cache_path.with_suffix('.json'), 'w') as f:
        json.dump(stats, f)  # save stats *.json
    if verbose:
        print(yaml.dump([stats], sort_keys=False, default_flow_style=False))
        # print(json.dumps(stats, indent=2, sort_keys=False))
    return stats