Nie możesz wybrać więcej niż 25 tematów. Tematy muszą się zaczynać od litery lub cyfry, mogą zawierać myślniki ('-') i mogą mieć do 35 znaków.

1141 lines
47KB

  1. # Dataset utils and dataloaders
  2. import glob
  3. import hashlib
  4. import json
  5. import logging
  6. import math
  7. import os
  8. import random
  9. import shutil
  10. import time
  11. from itertools import repeat
  12. from multiprocessing.pool import ThreadPool, Pool
  13. from pathlib import Path
  14. from threading import Thread
  15. import cv2
  16. import numpy as np
  17. import torch
  18. import torch.nn.functional as F
  19. import yaml
  20. from PIL import Image, ExifTags
  21. from torch.utils.data import Dataset
  22. from tqdm import tqdm
  23. from utils.general import check_requirements, check_file, check_dataset, xywh2xyxy, xywhn2xyxy, xyxy2xywhn, \
  24. xyn2xy, segment2box, segments2boxes, resample_segments, clean_str
  25. from utils.torch_utils import torch_distributed_zero_first
  26. # Parameters
  27. help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
  28. img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo'] # acceptable image suffixes
  29. vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv'] # acceptable video suffixes
  30. num_threads = min(8, os.cpu_count()) # number of multiprocessing threads
  31. logger = logging.getLogger(__name__)
  32. # Get orientation exif tag
  33. for orientation in ExifTags.TAGS.keys():
  34. if ExifTags.TAGS[orientation] == 'Orientation':
  35. break
  36. def get_hash(paths):
  37. # Returns a single hash value of a list of paths (files or dirs)
  38. size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
  39. h = hashlib.md5(str(size).encode()) # hash sizes
  40. h.update(''.join(paths).encode()) # hash paths
  41. return h.hexdigest() # return hash
  42. def exif_size(img):
  43. # Returns exif-corrected PIL size
  44. s = img.size # (width, height)
  45. try:
  46. rotation = dict(img._getexif().items())[orientation]
  47. if rotation == 6: # rotation 270
  48. s = (s[1], s[0])
  49. elif rotation == 8: # rotation 90
  50. s = (s[1], s[0])
  51. except:
  52. pass
  53. return s
  54. def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
  55. rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
  56. # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
  57. with torch_distributed_zero_first(rank):
  58. dataset = LoadImagesAndLabels(path, imgsz, batch_size,
  59. augment=augment, # augment images
  60. hyp=hyp, # augmentation hyperparameters
  61. rect=rect, # rectangular training
  62. cache_images=cache,
  63. single_cls=single_cls,
  64. stride=int(stride),
  65. pad=pad,
  66. image_weights=image_weights,
  67. prefix=prefix)
  68. batch_size = min(batch_size, len(dataset))
  69. nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, workers]) # number of workers
  70. sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
  71. loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
  72. # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
  73. dataloader = loader(dataset,
  74. batch_size=batch_size,
  75. num_workers=nw,
  76. sampler=sampler,
  77. pin_memory=True,
  78. collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
  79. return dataloader, dataset
  80. class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
  81. """ Dataloader that reuses workers
  82. Uses same syntax as vanilla DataLoader
  83. """
  84. def __init__(self, *args, **kwargs):
  85. super().__init__(*args, **kwargs)
  86. object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
  87. self.iterator = super().__iter__()
  88. def __len__(self):
  89. return len(self.batch_sampler.sampler)
  90. def __iter__(self):
  91. for i in range(len(self)):
  92. yield next(self.iterator)
  93. class _RepeatSampler(object):
  94. """ Sampler that repeats forever
  95. Args:
  96. sampler (Sampler)
  97. """
  98. def __init__(self, sampler):
  99. self.sampler = sampler
  100. def __iter__(self):
  101. while True:
  102. yield from iter(self.sampler)
class LoadImages:  # for inference
    """Iterate over image files and video frames for inference.

    `path` may be a glob pattern (contains '*', searched recursively), a
    directory (its direct '*.*' entries), or a single file. Files are split by
    lower-cased extension into images (img_formats) and videos (vid_formats);
    other files are silently ignored. Images are yielded first, then every
    frame of each video in turn.

    Each step returns ``(path, img, img0, self.cap)``: `img0` is the frame as
    read by OpenCV, `img` is the letterboxed copy converted to RGB, CHW,
    contiguous layout, and `self.cap` is the current cv2.VideoCapture (None if
    no videos were found).

    Raises:
        Exception: if `path` does not exist.
        AssertionError: if no supported image or video files are found.
    """

    def __init__(self, path, img_size=640, stride=32):
        p = str(Path(path).absolute())  # os-agnostic absolute path
        if '*' in p:
            files = sorted(glob.glob(p, recursive=True))  # glob
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
        elif os.path.isfile(p):
            files = [p]  # files
        else:
            raise Exception(f'ERROR: {p} does not exist')

        # Classify by the text after the last '.', lower-cased
        images = [x for x in files if x.split('.')[-1].lower() in img_formats]
        videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos  # images first, then videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        if any(videos):
            self.new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' \
                            f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}'

    def __iter__(self):
        # Restart from the first file
        self.count = 0
        return self

    def __next__(self):
        # self.count indexes the current file; it only advances past a video once that video is exhausted
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            if not ret_val:
                # Current video exhausted: release it and open the next file (videos are stored last,
                # so the next file is another video)
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()

            self.frame += 1
            print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: ', end='')

        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path
            print(f'image {self.count}/{self.nf} {path}: ', end='')

        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB and HWC to CHW
        img = np.ascontiguousarray(img)

        return path, img, img0, self.cap

    def new_video(self, path):
        # Open `path` for reading and reset the per-video frame counter
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nf  # number of files
  169. class LoadWebcam: # for inference
  170. def __init__(self, pipe='0', img_size=640, stride=32):
  171. self.img_size = img_size
  172. self.stride = stride
  173. if pipe.isnumeric():
  174. pipe = eval(pipe) # local camera
  175. # pipe = 'rtsp://192.168.1.64/1' # IP camera
  176. # pipe = 'rtsp://username:password@192.168.1.64/1' # IP camera with login
  177. # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg' # IP golf camera
  178. self.pipe = pipe
  179. self.cap = cv2.VideoCapture(pipe) # video capture object
  180. self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3) # set buffer size
  181. def __iter__(self):
  182. self.count = -1
  183. return self
  184. def __next__(self):
  185. self.count += 1
  186. if cv2.waitKey(1) == ord('q'): # q to quit
  187. self.cap.release()
  188. cv2.destroyAllWindows()
  189. raise StopIteration
  190. # Read frame
  191. if self.pipe == 0: # local camera
  192. ret_val, img0 = self.cap.read()
  193. img0 = cv2.flip(img0, 1) # flip left-right
  194. else: # IP camera
  195. n = 0
  196. while True:
  197. n += 1
  198. self.cap.grab()
  199. if n % 30 == 0: # skip frames
  200. ret_val, img0 = self.cap.retrieve()
  201. if ret_val:
  202. break
  203. # Print
  204. assert ret_val, f'Camera Error {self.pipe}'
  205. img_path = 'webcam.jpg'
  206. print(f'webcam {self.count}: ', end='')
  207. # Padded resize
  208. img = letterbox(img0, self.img_size, stride=self.stride)[0]
  209. # Convert
  210. img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB and HWC to CHW
  211. img = np.ascontiguousarray(img)
  212. return img_path, img, img0, None
  213. def __len__(self):
  214. return 0
  215. class LoadStreams: # multiple IP or RTSP cameras
  216. def __init__(self, sources='streams.txt', img_size=640, stride=32):
  217. self.mode = 'stream'
  218. self.img_size = img_size
  219. self.stride = stride
  220. if os.path.isfile(sources):
  221. with open(sources, 'r') as f:
  222. sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
  223. else:
  224. sources = [sources]
  225. n = len(sources)
  226. self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
  227. self.sources = [clean_str(x) for x in sources] # clean source names for later
  228. for i, s in enumerate(sources): # index, source
  229. # Start thread to read frames from video stream
  230. print(f'{i + 1}/{n}: {s}... ', end='')
  231. if 'youtube.com/' in s or 'youtu.be/' in s: # if source is YouTube video
  232. check_requirements(('pafy', 'youtube_dl'))
  233. import pafy
  234. s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
  235. s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
  236. cap = cv2.VideoCapture(s)
  237. assert cap.isOpened(), f'Failed to open {s}'
  238. w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  239. h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  240. self.fps[i] = max(cap.get(cv2.CAP_PROP_FPS) % 100, 0) or 30.0 # 30 FPS fallback
  241. self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
  242. _, self.imgs[i] = cap.read() # guarantee first frame
  243. self.threads[i] = Thread(target=self.update, args=([i, cap]), daemon=True)
  244. print(f" success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
  245. self.threads[i].start()
  246. print('') # newline
  247. # check for common shapes
  248. s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0) # shapes
  249. self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
  250. if not self.rect:
  251. print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')
  252. def update(self, i, cap):
  253. # Read stream `i` frames in daemon thread
  254. n, f = 0, self.frames[i]
  255. while cap.isOpened() and n < f:
  256. n += 1
  257. # _, self.imgs[index] = cap.read()
  258. cap.grab()
  259. if n % 4: # read every 4th frame
  260. success, im = cap.retrieve()
  261. self.imgs[i] = im if success else self.imgs[i] * 0
  262. time.sleep(1 / self.fps[i]) # wait time
  263. def __iter__(self):
  264. self.count = -1
  265. return self
  266. def __next__(self):
  267. self.count += 1
  268. if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
  269. cv2.destroyAllWindows()
  270. raise StopIteration
  271. # Letterbox
  272. img0 = self.imgs.copy()
  273. img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]
  274. # Stack
  275. img = np.stack(img, 0)
  276. # Convert
  277. img = img[:, :, :, ::-1].transpose(0, 3, 1, 2) # BGR to RGB and BHWC to BCHW
  278. img = np.ascontiguousarray(img)
  279. return self.sources, img, img0, None
  280. def __len__(self):
  281. return 0 # 1E12 frames = 32 streams at 30 FPS for 30 years
  282. def img2label_paths(img_paths):
  283. # Define label paths as a function of image paths
  284. sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep # /images/, /labels/ substrings
  285. return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
  286. class LoadImagesAndLabels(Dataset): # for training/testing
  287. def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
  288. cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
  289. self.img_size = img_size
  290. self.augment = augment
  291. self.hyp = hyp
  292. self.image_weights = image_weights
  293. self.rect = False if image_weights else rect
  294. self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
  295. self.mosaic_border = [-img_size // 2, -img_size // 2]
  296. self.stride = stride
  297. self.path = path
  298. try:
  299. f = [] # image files
  300. for p in path if isinstance(path, list) else [path]:
  301. p = Path(p) # os-agnostic
  302. if p.is_dir(): # dir
  303. f += glob.glob(str(p / '**' / '*.*'), recursive=True)
  304. # f = list(p.rglob('**/*.*')) # pathlib
  305. elif p.is_file(): # file
  306. with open(p, 'r') as t:
  307. t = t.read().strip().splitlines()
  308. parent = str(p.parent) + os.sep
  309. f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
  310. # f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
  311. else:
  312. raise Exception(f'{prefix}{p} does not exist')
  313. self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
  314. # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats]) # pathlib
  315. assert self.img_files, f'{prefix}No images found'
  316. except Exception as e:
  317. raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')
  318. # Check cache
  319. self.label_files = img2label_paths(self.img_files) # labels
  320. cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache') # cached labels
  321. if cache_path.is_file():
  322. cache, exists = torch.load(cache_path), True # load
  323. if cache.get('version') != 0.3 or cache.get('hash') != get_hash(self.label_files + self.img_files):
  324. cache, exists = self.cache_labels(cache_path, prefix), False # re-cache
  325. else:
  326. cache, exists = self.cache_labels(cache_path, prefix), False # cache
  327. # Display cache
  328. nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupted, total
  329. if exists:
  330. d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
  331. tqdm(None, desc=prefix + d, total=n, initial=n) # display cache results
  332. if cache['msgs']:
  333. logging.info('\n'.join(cache['msgs'])) # display warnings
  334. assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'
  335. # Read cache
  336. [cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
  337. labels, shapes, self.segments = zip(*cache.values())
  338. self.labels = list(labels)
  339. self.shapes = np.array(shapes, dtype=np.float64)
  340. self.img_files = list(cache.keys()) # update
  341. self.label_files = img2label_paths(cache.keys()) # update
  342. if single_cls:
  343. for x in self.labels:
  344. x[:, 0] = 0
  345. n = len(shapes) # number of images
  346. bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
  347. nb = bi[-1] + 1 # number of batches
  348. self.batch = bi # batch index of image
  349. self.n = n
  350. self.indices = range(n)
  351. # Rectangular Training
  352. if self.rect:
  353. # Sort by aspect ratio
  354. s = self.shapes # wh
  355. ar = s[:, 1] / s[:, 0] # aspect ratio
  356. irect = ar.argsort()
  357. self.img_files = [self.img_files[i] for i in irect]
  358. self.label_files = [self.label_files[i] for i in irect]
  359. self.labels = [self.labels[i] for i in irect]
  360. self.shapes = s[irect] # wh
  361. ar = ar[irect]
  362. # Set training image shapes
  363. shapes = [[1, 1]] * nb
  364. for i in range(nb):
  365. ari = ar[bi == i]
  366. mini, maxi = ari.min(), ari.max()
  367. if maxi < 1:
  368. shapes[i] = [maxi, 1]
  369. elif mini > 1:
  370. shapes[i] = [1, 1 / mini]
  371. self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
  372. # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
  373. self.imgs = [None] * n
  374. if cache_images:
  375. gb = 0 # Gigabytes of cached images
  376. self.img_hw0, self.img_hw = [None] * n, [None] * n
  377. results = ThreadPool(num_threads).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
  378. pbar = tqdm(enumerate(results), total=n)
  379. for i, x in pbar:
  380. self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # img, hw_original, hw_resized = load_image(self, i)
  381. gb += self.imgs[i].nbytes
  382. pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
  383. pbar.close()
  384. def cache_labels(self, path=Path('./labels.cache'), prefix=''):
  385. # Cache dataset labels, check images and read shapes
  386. x = {} # dict
  387. nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
  388. desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
  389. with Pool(num_threads) as pool:
  390. pbar = tqdm(pool.imap_unordered(verify_image_label, zip(self.img_files, self.label_files, repeat(prefix))),
  391. desc=desc, total=len(self.img_files))
  392. for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
  393. nm += nm_f
  394. nf += nf_f
  395. ne += ne_f
  396. nc += nc_f
  397. if im_file:
  398. x[im_file] = [l, shape, segments]
  399. if msg:
  400. msgs.append(msg)
  401. pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
  402. pbar.close()
  403. if msgs:
  404. logging.info('\n'.join(msgs))
  405. if nf == 0:
  406. logging.info(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
  407. x['hash'] = get_hash(self.label_files + self.img_files)
  408. x['results'] = nf, nm, ne, nc, len(self.img_files)
  409. x['msgs'] = msgs # warnings
  410. x['version'] = 0.3 # cache version
  411. try:
  412. torch.save(x, path) # save cache for next time
  413. logging.info(f'{prefix}New cache created: {path}')
  414. except Exception as e:
  415. logging.info(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}') # path not writeable
  416. return x
  417. def __len__(self):
  418. return len(self.img_files)
  419. # def __iter__(self):
  420. # self.count = -1
  421. # print('ran dataset iter')
  422. # #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
  423. # return self
  424. def __getitem__(self, index):
  425. index = self.indices[index] # linear, shuffled, or image_weights
  426. hyp = self.hyp
  427. mosaic = self.mosaic and random.random() < hyp['mosaic']
  428. if mosaic:
  429. # Load mosaic
  430. img, labels = load_mosaic(self, index)
  431. shapes = None
  432. # MixUp https://arxiv.org/pdf/1710.09412.pdf
  433. if random.random() < hyp['mixup']:
  434. img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1))
  435. r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
  436. img = (img * r + img2 * (1 - r)).astype(np.uint8)
  437. labels = np.concatenate((labels, labels2), 0)
  438. else:
  439. # Load image
  440. img, (h0, w0), (h, w) = load_image(self, index)
  441. # Letterbox
  442. shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
  443. img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
  444. shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
  445. labels = self.labels[index].copy()
  446. if labels.size: # normalized xywh to pixel xyxy format
  447. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
  448. if self.augment:
  449. # Augment imagespace
  450. if not mosaic:
  451. img, labels = random_perspective(img, labels,
  452. degrees=hyp['degrees'],
  453. translate=hyp['translate'],
  454. scale=hyp['scale'],
  455. shear=hyp['shear'],
  456. perspective=hyp['perspective'])
  457. # Augment colorspace
  458. augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
  459. # Apply cutouts
  460. # if random.random() < 0.9:
  461. # labels = cutout(img, labels)
  462. nL = len(labels) # number of labels
  463. if nL:
  464. labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0]) # xyxy to xywh normalized
  465. if self.augment:
  466. # flip up-down
  467. if random.random() < hyp['flipud']:
  468. img = np.flipud(img)
  469. if nL:
  470. labels[:, 2] = 1 - labels[:, 2]
  471. # flip left-right
  472. if random.random() < hyp['fliplr']:
  473. img = np.fliplr(img)
  474. if nL:
  475. labels[:, 1] = 1 - labels[:, 1]
  476. labels_out = torch.zeros((nL, 6))
  477. if nL:
  478. labels_out[:, 1:] = torch.from_numpy(labels)
  479. # Convert
  480. img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3 x img_height x img_width
  481. img = np.ascontiguousarray(img)
  482. return torch.from_numpy(img), labels_out, self.img_files[index], shapes
  483. @staticmethod
  484. def collate_fn(batch):
  485. img, label, path, shapes = zip(*batch) # transposed
  486. for i, l in enumerate(label):
  487. l[:, 0] = i # add target image index for build_targets()
  488. return torch.stack(img, 0), torch.cat(label, 0), path, shapes
  489. @staticmethod
  490. def collate_fn4(batch):
  491. img, label, path, shapes = zip(*batch) # transposed
  492. n = len(shapes) // 4
  493. img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
  494. ho = torch.tensor([[0., 0, 0, 1, 0, 0]])
  495. wo = torch.tensor([[0., 0, 1, 0, 0, 0]])
  496. s = torch.tensor([[1, 1, .5, .5, .5, .5]]) # scale
  497. for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
  498. i *= 4
  499. if random.random() < 0.5:
  500. im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[
  501. 0].type(img[i].type())
  502. l = label[i]
  503. else:
  504. im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
  505. l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
  506. img4.append(im)
  507. label4.append(l)
  508. for i, l in enumerate(label4):
  509. l[:, 0] = i # add target image index for build_targets()
  510. return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4
  511. # Ancillary functions --------------------------------------------------------------------------------------------------
  512. def load_image(self, index):
  513. # loads 1 image from dataset, returns img, original hw, resized hw
  514. img = self.imgs[index]
  515. if img is None: # not cached
  516. path = self.img_files[index]
  517. img = cv2.imread(path) # BGR
  518. assert img is not None, 'Image Not Found ' + path
  519. h0, w0 = img.shape[:2] # orig hw
  520. r = self.img_size / max(h0, w0) # ratio
  521. if r != 1: # if sizes are not equal
  522. img = cv2.resize(img, (int(w0 * r), int(h0 * r)),
  523. interpolation=cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR)
  524. return img, (h0, w0), img.shape[:2] # img, hw_original, hw_resized
  525. else:
  526. return self.imgs[index], self.img_hw0[index], self.img_hw[index] # img, hw_original, hw_resized
  527. def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
  528. if hgain or sgain or vgain:
  529. r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains
  530. hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
  531. dtype = img.dtype # uint8
  532. x = np.arange(0, 256, dtype=r.dtype)
  533. lut_hue = ((x * r[0]) % 180).astype(dtype)
  534. lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
  535. lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
  536. img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
  537. cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed
  538. def hist_equalize(img, clahe=True, bgr=False):
  539. # Equalize histogram on BGR image 'img' with img.shape(n,m,3) and range 0-255
  540. yuv = cv2.cvtColor(img, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
  541. if clahe:
  542. c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
  543. yuv[:, :, 0] = c.apply(yuv[:, :, 0])
  544. else:
  545. yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0]) # equalize Y channel histogram
  546. return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB) # convert YUV image to RGB
def load_mosaic(self, index):
    """Build a 4-image mosaic around image `index`, then augment it with random_perspective.

    Image `index` plus three images drawn (with replacement) from self.indices
    become the top-left, top-right, bottom-left and bottom-right tiles of a
    (2*img_size, 2*img_size) canvas filled with 114. The tiles meet at a random
    centre (xc, yc). Labels and segments are converted from normalized xywh to
    pixel coordinates on that canvas, clipped to [0, 2*img_size], and then
    transformed together with the image by random_perspective using
    border=self.mosaic_border.

    Returns:
        (img4, labels4): the augmented mosaic image and its labels (class, x1, y1, x2, y2).
    """
    # loads images in a 4-mosaic
    labels4, segments4 = [], []
    s = self.img_size
    # Centre drawn from [-b, 2*s + b] per axis; with the default mosaic_border of
    # [-img_size // 2, -img_size // 2] that is roughly [s/2, 3s/2].
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center y, x
    indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4: (x1a..y2a) is the destination window on the canvas,
        # (x1b..y2b) the matching crop of the tile so both regions have equal size
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        # Offset of the tile's origin on the canvas (may be negative when the tile is cropped)
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
        labels4.append(labels)
        segments4.extend(segments)

    # Concat/clip labels
    labels4 = np.concatenate(labels4, 0)
    for x in (labels4[:, 1:], *segments4):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img4, labels4 = replicate(img4, labels4)  # replicate

    # Augment
    img4, labels4 = random_perspective(img4, labels4, segments4,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img4, labels4
def load_mosaic9(self, index):
    """Build a 9-image mosaic around image `index`, then augment it with random_perspective.

    Image `index` is placed at (img_size, img_size) on a (3*img_size, 3*img_size)
    canvas filled with 114; eight more images drawn (with replacement) from
    self.indices are placed clockwise around it, starting from the top. A random
    (2*img_size, 2*img_size) window is then cropped, labels/segments are shifted
    into that window and clipped to [0, 2*img_size], and the result is passed
    through random_perspective with border=self.mosaic_border.

    Returns:
        (img9, labels9): the augmented mosaic image and its labels (class, x1, y1, x2, y2).
    """
    # loads images in a 9-mosaic
    labels9, segments9 = [], []
    s = self.img_size
    indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img9; c = (xmin, ymin, xmax, ymax) of this tile on the canvas.
        # h0/w0 are the centre tile's size, hp/wp the previous tile's size (set at the end of each iteration).
        if i == 0:  # center
            img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 9 tiles
            h0, w0 = h, w
            c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
        elif i == 1:  # top
            c = s, s - h, s + w, s
        elif i == 2:  # top right
            c = s + wp, s - h, s + wp + w, s
        elif i == 3:  # right
            c = s + w0, s, s + w0 + w, s + h
        elif i == 4:  # bottom right
            c = s + w0, s + hp, s + w0 + w, s + hp + h
        elif i == 5:  # bottom
            c = s + w0 - w, s + h0, s + w0, s + h0 + h
        elif i == 6:  # bottom left
            c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
        elif i == 7:  # left
            c = s - w, s + h0 - h, s, s + h0
        elif i == 8:  # top left
            c = s - w, s + h0 - hp - h, s, s + h0 - hp

        padx, pady = c[:2]
        x1, y1, x2, y2 = [max(x, 0) for x in c]  # allocate coords (clamped to the canvas' top/left edge)

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
        labels9.append(labels)
        segments9.extend(segments)

        # Image
        img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
        hp, wp = h, w  # height, width previous

    # Offset: crop a random 2s x 2s window from the 3s x 3s canvas
    yc, xc = [int(random.uniform(0, s)) for _ in self.mosaic_border]  # mosaic center y, x
    img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

    # Concat/clip labels
    labels9 = np.concatenate(labels9, 0)
    labels9[:, [1, 3]] -= xc
    labels9[:, [2, 4]] -= yc
    c = np.array([xc, yc])  # centers
    segments9 = [x - c for x in segments9]

    for x in (labels9[:, 1:], *segments9):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img9, labels9 = replicate(img9, labels9)  # replicate

    # Augment
    img9, labels9 = random_perspective(img9, labels9, segments9,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img9, labels9
  656. def replicate(img, labels):
  657. # Replicate labels
  658. h, w = img.shape[:2]
  659. boxes = labels[:, 1:].astype(int)
  660. x1, y1, x2, y2 = boxes.T
  661. s = ((x2 - x1) + (y2 - y1)) / 2 # side length (pixels)
  662. for i in s.argsort()[:round(s.size * 0.5)]: # smallest indices
  663. x1b, y1b, x2b, y2b = boxes[i]
  664. bh, bw = y2b - y1b, x2b - x1b
  665. yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw)) # offset x, y
  666. x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
  667. img[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
  668. labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)
  669. return img, labels
  670. def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
  671. # Resize and pad image while meeting stride-multiple constraints
  672. shape = img.shape[:2] # current shape [height, width]
  673. if isinstance(new_shape, int):
  674. new_shape = (new_shape, new_shape)
  675. # Scale ratio (new / old)
  676. r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
  677. if not scaleup: # only scale down, do not scale up (for better test mAP)
  678. r = min(r, 1.0)
  679. # Compute padding
  680. ratio = r, r # width, height ratios
  681. new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
  682. dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
  683. if auto: # minimum rectangle
  684. dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
  685. elif scaleFill: # stretch
  686. dw, dh = 0.0, 0.0
  687. new_unpad = (new_shape[1], new_shape[0])
  688. ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
  689. dw /= 2 # divide padding into 2 sides
  690. dh /= 2
  691. if shape[::-1] != new_unpad: # resize
  692. img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
  693. top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
  694. left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
  695. img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
  696. return img, ratio, (dw, dh)
  697. def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0,
  698. border=(0, 0)):
  699. # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
  700. # targets = [cls, xyxy]
  701. height = img.shape[0] + border[0] * 2 # shape(h,w,c)
  702. width = img.shape[1] + border[1] * 2
  703. # Center
  704. C = np.eye(3)
  705. C[0, 2] = -img.shape[1] / 2 # x translation (pixels)
  706. C[1, 2] = -img.shape[0] / 2 # y translation (pixels)
  707. # Perspective
  708. P = np.eye(3)
  709. P[2, 0] = random.uniform(-perspective, perspective) # x perspective (about y)
  710. P[2, 1] = random.uniform(-perspective, perspective) # y perspective (about x)
  711. # Rotation and Scale
  712. R = np.eye(3)
  713. a = random.uniform(-degrees, degrees)
  714. # a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
  715. s = random.uniform(1 - scale, 1 + scale)
  716. # s = 2 ** random.uniform(-scale, scale)
  717. R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
  718. # Shear
  719. S = np.eye(3)
  720. S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg)
  721. S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg)
  722. # Translation
  723. T = np.eye(3)
  724. T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width # x translation (pixels)
  725. T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height # y translation (pixels)
  726. # Combined rotation matrix
  727. M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT
  728. if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed
  729. if perspective:
  730. img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
  731. else: # affine
  732. img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))
  733. # Visualize
  734. # import matplotlib.pyplot as plt
  735. # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
  736. # ax[0].imshow(img[:, :, ::-1]) # base
  737. # ax[1].imshow(img2[:, :, ::-1]) # warped
  738. # Transform label coordinates
  739. n = len(targets)
  740. if n:
  741. use_segments = any(x.any() for x in segments)
  742. new = np.zeros((n, 4))
  743. if use_segments: # warp segments
  744. segments = resample_segments(segments) # upsample
  745. for i, segment in enumerate(segments):
  746. xy = np.ones((len(segment), 3))
  747. xy[:, :2] = segment
  748. xy = xy @ M.T # transform
  749. xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2] # perspective rescale or affine
  750. # clip
  751. new[i] = segment2box(xy, width, height)
  752. else: # warp boxes
  753. xy = np.ones((n * 4, 3))
  754. xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
  755. xy = xy @ M.T # transform
  756. xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine
  757. # create new boxes
  758. x = xy[:, [0, 2, 4, 6]]
  759. y = xy[:, [1, 3, 5, 7]]
  760. new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
  761. # clip
  762. new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
  763. new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)
  764. # filter candidates
  765. i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
  766. targets = targets[i]
  767. targets[:, 1:5] = new[i]
  768. return img, targets
  769. def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
  770. # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
  771. w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
  772. w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
  773. ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
  774. return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
  775. def cutout(image, labels):
  776. # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
  777. h, w = image.shape[:2]
  778. def bbox_ioa(box1, box2):
  779. # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
  780. box2 = box2.transpose()
  781. # Get the coordinates of bounding boxes
  782. b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
  783. b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
  784. # Intersection area
  785. inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
  786. (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)
  787. # box2 area
  788. box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16
  789. # Intersection over box2 area
  790. return inter_area / box2_area
  791. # create random masks
  792. scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction
  793. for s in scales:
  794. mask_h = random.randint(1, int(h * s))
  795. mask_w = random.randint(1, int(w * s))
  796. # box
  797. xmin = max(0, random.randint(0, w) - mask_w // 2)
  798. ymin = max(0, random.randint(0, h) - mask_h // 2)
  799. xmax = min(w, xmin + mask_w)
  800. ymax = min(h, ymin + mask_h)
  801. # apply random color mask
  802. image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]
  803. # return unobscured labels
  804. if len(labels) and s > 0.03:
  805. box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
  806. ioa = bbox_ioa(box, labels[:, 1:5]) # intersection over area
  807. labels = labels[ioa < 0.60] # remove >60% obscured labels
  808. return labels
  809. def create_folder(path='./new'):
  810. # Create folder
  811. if os.path.exists(path):
  812. shutil.rmtree(path) # delete output folder
  813. os.makedirs(path) # make new output folder
  814. def flatten_recursive(path='../coco128'):
  815. # Flatten a recursive directory by bringing all files to top level
  816. new_path = Path(path + '_flat')
  817. create_folder(new_path)
  818. for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
  819. shutil.copyfile(file, new_path / Path(file).name)
  820. def extract_boxes(path='../coco128/'): # from utils.datasets import *; extract_boxes('../coco128')
  821. # Convert detection dataset into classification dataset, with one directory per class
  822. path = Path(path) # images dir
  823. shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None # remove existing
  824. files = list(path.rglob('*.*'))
  825. n = len(files) # number of files
  826. for im_file in tqdm(files, total=n):
  827. if im_file.suffix[1:] in img_formats:
  828. # image
  829. im = cv2.imread(str(im_file))[..., ::-1] # BGR to RGB
  830. h, w = im.shape[:2]
  831. # labels
  832. lb_file = Path(img2label_paths([str(im_file)])[0])
  833. if Path(lb_file).exists():
  834. with open(lb_file, 'r') as f:
  835. lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32) # labels
  836. for j, x in enumerate(lb):
  837. c = int(x[0]) # class
  838. f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg' # new filename
  839. if not f.parent.is_dir():
  840. f.parent.mkdir(parents=True)
  841. b = x[1:] * [w, h, w, h] # box
  842. # b[2:] = b[2:].max() # rectangle to square
  843. b[2:] = b[2:] * 1.2 + 3 # pad
  844. b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(np.int)
  845. b[[0, 2]] = np.clip(b[[0, 2]], 0, w) # clip boxes outside of image
  846. b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
  847. assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'
  848. def autosplit(path='../coco128', weights=(0.9, 0.1, 0.0), annotated_only=False):
  849. """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
  850. Usage: from utils.datasets import *; autosplit('../coco128')
  851. Arguments
  852. path: Path to images directory
  853. weights: Train, val, test weights (list)
  854. annotated_only: Only use images with an annotated txt file
  855. """
  856. path = Path(path) # images dir
  857. files = sum([list(path.rglob(f"*.{img_ext}")) for img_ext in img_formats], []) # image files only
  858. n = len(files) # number of files
  859. indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
  860. txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
  861. [(path / x).unlink() for x in txt if (path / x).exists()] # remove existing
  862. print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
  863. for i, img in tqdm(zip(indices, files), total=n):
  864. if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
  865. with open(path / txt[i], 'a') as f:
  866. f.write(str(img) + '\n') # add image to txt file
  867. def verify_image_label(args):
  868. # Verify one image-label pair
  869. im_file, lb_file, prefix = args
  870. nm, nf, ne, nc = 0, 0, 0, 0 # number missing, found, empty, corrupt
  871. try:
  872. # verify images
  873. im = Image.open(im_file)
  874. im.verify() # PIL verify
  875. shape = exif_size(im) # image size
  876. assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
  877. assert im.format.lower() in img_formats, f'invalid image format {im.format}'
  878. if im.format.lower() in ('jpg', 'jpeg'):
  879. with open(im_file, 'rb') as f:
  880. f.seek(-2, 2)
  881. assert f.read() == b'\xff\xd9', 'corrupted JPEG'
  882. # verify labels
  883. segments = [] # instance segments
  884. if os.path.isfile(lb_file):
  885. nf = 1 # label found
  886. with open(lb_file, 'r') as f:
  887. l = [x.split() for x in f.read().strip().splitlines() if len(x)]
  888. if any([len(x) > 8 for x in l]): # is segment
  889. classes = np.array([x[0] for x in l], dtype=np.float32)
  890. segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l] # (cls, xy1...)
  891. l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
  892. l = np.array(l, dtype=np.float32)
  893. if len(l):
  894. assert l.shape[1] == 5, 'labels require 5 columns each'
  895. assert (l >= 0).all(), 'negative labels'
  896. assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
  897. assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
  898. else:
  899. ne = 1 # label empty
  900. l = np.zeros((0, 5), dtype=np.float32)
  901. else:
  902. nm = 1 # label missing
  903. l = np.zeros((0, 5), dtype=np.float32)
  904. return im_file, l, shape, segments, nm, nf, ne, nc, ''
  905. except Exception as e:
  906. nc = 1
  907. msg = f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}'
  908. return [None, None, None, None, nm, nf, ne, nc, msg]
  909. def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False):
  910. """ Return dataset statistics dictionary with images and instances counts per split per class
  911. Usage: from utils.datasets import *; dataset_stats('coco128.yaml', verbose=True)
  912. Arguments
  913. path: Path to data.yaml
  914. autodownload: Attempt to download dataset if not found locally
  915. verbose: Print stats dictionary
  916. """
  917. def round_labels(labels):
  918. # Update labels to integer class and 6 decimal place floats
  919. return [[int(c), *[round(x, 6) for x in points]] for c, *points in labels]
  920. with open(check_file(path)) as f:
  921. data = yaml.safe_load(f) # data dict
  922. check_dataset(data, autodownload) # download dataset if missing
  923. nc = data['nc'] # number of classes
  924. stats = {'nc': nc, 'names': data['names']} # statistics dictionary
  925. for split in 'train', 'val', 'test':
  926. if split not in data:
  927. stats[split] = None # i.e. no test set
  928. continue
  929. x = []
  930. dataset = LoadImagesAndLabels(data[split], augment=False, rect=True) # load dataset
  931. if split == 'train':
  932. cache_path = Path(dataset.label_files[0]).parent.with_suffix('.cache') # *.cache path
  933. for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
  934. x.append(np.bincount(label[:, 0].astype(int), minlength=nc))
  935. x = np.array(x) # shape(128x80)
  936. stats[split] = {'instance_stats': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
  937. 'image_stats': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
  938. 'per_class': (x > 0).sum(0).tolist()},
  939. 'labels': [{str(Path(k).name): round_labels(v.tolist())} for k, v in
  940. zip(dataset.img_files, dataset.labels)]}
  941. # Save, print and return
  942. with open(cache_path.with_suffix('.json'), 'w') as f:
  943. json.dump(stats, f) # save stats *.json
  944. if verbose:
  945. print(json.dumps(stats, indent=2, sort_keys=False))
  946. # print(yaml.dump([stats], sort_keys=False, default_flow_style=False))
  947. return stats