Parking slot corner detection code — dataset utils and dataloaders
# Dataset utils and dataloaders

import glob
import logging
import math
import os
import random
import shutil
import time
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from threading import Thread

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ExifTags
from torch.utils.data import Dataset
from tqdm import tqdm

from utils.general import check_requirements, xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, segment2box, segments2boxes, \
    resample_segments, clean_str
from utils.torch_utils import torch_distributed_zero_first

# Parameters
help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']  # acceptable image suffixes
vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']  # acceptable video suffixes
logger = logging.getLogger(__name__)

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break

def get_hash(files):
    # Returns a single hash value of a list of files
    return sum(os.path.getsize(f) for f in files if os.path.isfile(f))


def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    try:
        rotation = dict(img._getexif().items())[orientation]
        if rotation == 6:  # rotation 270
            s = (s[1], s[0])
        elif rotation == 8:  # rotation 90
            s = (s[1], s[0])
    except Exception:  # no EXIF data or no Orientation tag
        pass
    return s
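
# Illustrative sketch of what exif_size corrects (hypothetical portrait photo
# whose EXIF Orientation tag is 6, i.e. stored rotated):
#   im = Image.open('portrait.jpg')  # im.size -> (4032, 3024)
#   exif_size(im)                    # -> (3024, 4032), width/height swapped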

def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
                      rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
    # Make sure only the first process in DDP scans the dataset first, so the other processes can use its cache
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,  # augmentation hyperparameters
                                      rect=rect,  # rectangular training
                                      cache_images=cache,
                                      single_cls=opt.single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)

    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
    loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
    # Use torch.utils.data.DataLoader() if dataset properties will update during training, else InfiniteDataLoader()
    dataloader = loader(dataset,
                        batch_size=batch_size,
                        num_workers=nw,
                        sampler=sampler,
                        pin_memory=True,
                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
    return dataloader, dataset
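
# Minimal usage sketch for create_dataloader (illustrative values; `opt` only
# needs a `single_cls` attribute here, and `hyp` is the augmentation dict
# normally loaded from a hyp*.yaml file):
#   from argparse import Namespace
#   opt = Namespace(single_cls=False)
#   dataloader, dataset = create_dataloader('data/train', imgsz=640, batch_size=16,
#                                           stride=32, opt=opt, hyp=hyp, augment=True)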

class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
    """ Dataloader that reuses workers

    Uses same syntax as vanilla DataLoader
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """ Sampler that repeats forever

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)

class LoadImages:  # for inference
    def __init__(self, path, img_size=640, stride=32):
        p = str(Path(path).absolute())  # os-agnostic absolute path
        if '*' in p:
            files = sorted(glob.glob(p, recursive=True))  # glob
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
        elif os.path.isfile(p):
            files = [p]  # files
        else:
            raise Exception(f'ERROR: {p} does not exist')

        images = [x for x in files if x.split('.')[-1].lower() in img_formats]
        videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        if any(videos):
            self.new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' \
                            f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}'

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            if not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()

            self.frame += 1
            print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.nframes}) {path}: ', end='')

        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path
            print(f'image {self.count}/{self.nf} {path}: ', end='')

        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        img = np.ascontiguousarray(img)

        return path, img, img0, self.cap

    def new_video(self, path):
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nf  # number of files

class LoadWebcam:  # for inference
    def __init__(self, pipe='0', img_size=640, stride=32):
        self.img_size = img_size
        self.stride = stride

        if pipe.isnumeric():
            pipe = int(pipe)  # local camera index, e.g. '0' -> 0
        # pipe = 'rtsp://192.168.1.64/1'  # IP camera
        # pipe = 'rtsp://username:password@192.168.1.64/1'  # IP camera with login
        # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg'  # IP golf camera

        self.pipe = pipe
        self.cap = cv2.VideoCapture(pipe)  # video capture object
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)  # set buffer size

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if cv2.waitKey(1) == ord('q'):  # q to quit
            self.cap.release()
            cv2.destroyAllWindows()
            raise StopIteration

        # Read frame
        if self.pipe == 0:  # local camera
            ret_val, img0 = self.cap.read()
            img0 = cv2.flip(img0, 1)  # flip left-right
        else:  # IP camera
            n = 0
            while True:
                n += 1
                self.cap.grab()
                if n % 30 == 0:  # keep every 30th frame
                    ret_val, img0 = self.cap.retrieve()
                    if ret_val:
                        break

        # Print
        assert ret_val, f'Camera Error {self.pipe}'
        img_path = 'webcam.jpg'
        print(f'webcam {self.count}: ', end='')

        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        img = np.ascontiguousarray(img)

        return img_path, img, img0, None

    def __len__(self):
        return 0

class LoadStreams:  # IP or RTSP camera stream
    def __init__(self, sources='streams.txt', img_size=640, stride=32):
        self.mode = 'stream'
        self.img_size = img_size
        self.stride = stride

        if os.path.isfile(sources):
            with open(sources, 'r') as f:
                sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
        else:
            sources = [sources]

        n = len(sources)
        self.imgs = [None] * n
        self.sources = [clean_str(x) for x in sources]  # clean source names for later
        for i, s in enumerate(sources):  # index, source
            assert i == 0  # this modified loader supports a single stream only, read synchronously
            print(f'{i + 1}/{n}: {s}... ', end='')
            if 'youtube.com/' in s or 'youtu.be/' in s:  # if source is YouTube video
                check_requirements(('pafy', 'youtube_dl'))
                import pafy
                s = pafy.new(s).getbest(preftype="mp4").url  # YouTube URL
            s = int(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
            cap = cv2.VideoCapture(s)
            assert cap.isOpened(), f'Failed to open {s}'
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            self.fps = cap.get(cv2.CAP_PROP_FPS) % 100
            self.cap = cap
            _, self.imgs[i] = cap.read()  # guarantee first frame
            print(f' success ({w}x{h} at {self.fps:.2f} FPS).')
        print('')  # newline

        # check for common shapes
        s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0)  # shapes
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        if not self.rect:
            print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')

    def update(self, index, cap):
        # Read the next stream frame synchronously (upstream's daemon-thread
        # frame-skipping loop has been replaced by a single blocking read)
        if cap.isOpened():
            _, self.imgs[index] = cap.read()
        return self.imgs

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        img0 = self.update(0, self.cap).copy()
        if not isinstance(img0[0], np.ndarray):  # stream error or video ended
            return False, None, None, None

        # Letterbox
        img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]

        # Stack
        img = np.stack(img, 0)

        # Convert
        img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, BHWC to BCHW
        img = np.ascontiguousarray(img)

        return self.sources, img, img0, None

    def __len__(self):
        return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years

def img2label_paths(img_paths):
    # Define label paths as a function of image paths
    sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
    return ['txt'.join(x.replace(sa, sb, 1).rsplit(x.split('.')[-1], 1)) for x in img_paths]
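
# Expected directory convention (illustrative paths): labels mirror the images
# tree, with '/images/' swapped for '/labels/' and the suffix replaced by .txt:
#   img2label_paths(['datasets/parking/images/train/0001.jpg'])
#   # -> ['datasets/parking/labels/train/0001.txt']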

class LoadImagesAndLabels(Dataset):  # for training/testing
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect  # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride
        self.path = path

        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
                    # f = list(p.rglob('**/*.*'))  # pathlib
                elif p.is_file():  # file
                    with open(p, 'r') as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace('./', parent) if x.startswith('./') else x for x in t]  # local to global path
                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                else:
                    raise Exception(f'{prefix}{p} does not exist')
            self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats])  # pathlib
            assert self.img_files, f'{prefix}No images found'
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')

        # Check cache
        self.label_files = img2label_paths(self.img_files)  # labels
        cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')  # cached labels
        if cache_path.is_file():
            cache, exists = torch.load(cache_path), True  # load
            if cache['hash'] != get_hash(self.label_files + self.img_files) or 'version' not in cache:  # changed
                cache, exists = self.cache_labels(cache_path, prefix), False  # re-cache
        else:
            cache, exists = self.cache_labels(cache_path, prefix), False  # cache

        # Display cache
        nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupted, total
        if exists:
            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
            tqdm(None, desc=prefix + d, total=n, initial=n)  # display cache results
        assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'

        # Read cache
        cache.pop('hash')  # remove hash
        cache.pop('version')  # remove version
        labels, shapes, self.segments = zip(*cache.values())
        self.labels = list(labels)
        self.shapes = np.array(shapes, dtype=np.float64)
        self.img_files = list(cache.keys())  # update
        self.label_files = img2label_paths(cache.keys())  # update
        if single_cls:
            for x in self.labels:
                x[:, 0] = 0

        n = len(shapes)  # number of images
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches
        self.batch = bi  # batch index of image
        self.n = n
        self.indices = range(n)

        # Rectangular Training
        if self.rect:
            # Sort by aspect ratio
            s = self.shapes  # wh
            ar = s[:, 1] / s[:, 0]  # aspect ratio
            irect = ar.argsort()
            self.img_files = [self.img_files[i] for i in irect]
            self.label_files = [self.label_files[i] for i in irect]
            self.labels = [self.labels[i] for i in irect]
            self.shapes = s[irect]  # wh
            ar = ar[irect]

            # Set training image shapes
            shapes = [[1, 1]] * nb
            for i in range(nb):
                ari = ar[bi == i]
                mini, maxi = ari.min(), ari.max()
                if maxi < 1:
                    shapes[i] = [maxi, 1]
                elif mini > 1:
                    shapes[i] = [1, 1 / mini]

            self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride

        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        self.imgs = [None] * n
        if cache_images:
            gb = 0  # Gigabytes of cached images
            self.img_hw0, self.img_hw = [None] * n, [None] * n
            results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))  # 8 threads
            pbar = tqdm(enumerate(results), total=n)
            for i, x in pbar:
                self.imgs[i], self.img_hw0[i], self.img_hw[i] = x  # img, hw_original, hw_resized = load_image(self, i)
                gb += self.imgs[i].nbytes
                pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
            pbar.close()

    def cache_labels(self, path=Path('./labels.cache'), prefix=''):
        # Cache dataset labels, check images and read shapes
        x = {}  # dict
        nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, corrupted
        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
        for i, (im_file, lb_file) in enumerate(pbar):
            try:
                # verify images
                im = Image.open(im_file)
                im.verify()  # PIL verify
                shape = exif_size(im)  # image size
                segments = []  # instance segments
                assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
                assert im.format.lower() in img_formats, f'invalid image format {im.format}'

                # verify labels
                if os.path.isfile(lb_file):
                    nf += 1  # label found
                    with open(lb_file, 'r') as f:
                        l = [x.split() for x in f.read().strip().splitlines()]
                        if any([len(x) > 8 for x in l]):  # is segment
                            classes = np.array([x[0] for x in l], dtype=np.float32)
                            segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l]  # (cls, xy1...)
                            l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                        l = np.array(l, dtype=np.float32)
                    if len(l):
                        assert l.shape[1] == 5, 'labels require 5 columns each'
                        assert (l >= 0).all(), 'negative labels'
                        assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
                        assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
                    else:
                        ne += 1  # label empty
                        l = np.zeros((0, 5), dtype=np.float32)
                else:
                    nm += 1  # label missing
                    l = np.zeros((0, 5), dtype=np.float32)
                x[im_file] = [l, shape, segments]
            except Exception as e:
                nc += 1
                print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')

            pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
                        f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
        pbar.close()

        if nf == 0:
            print(f'{prefix}WARNING: No labels found in {path}. See {help_url}')

        x['hash'] = get_hash(self.label_files + self.img_files)
        x['results'] = nf, nm, ne, nc, i + 1
        x['version'] = 0.1  # cache version
        torch.save(x, path)  # save for next time
        logging.info(f'{prefix}New cache created: {path}')
        return x

    def __len__(self):
        return len(self.img_files)

    # def __iter__(self):
    #     self.count = -1
    #     print('ran dataset iter')
    #     # self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
    #     return self

    def __getitem__(self, index):
        index = self.indices[index]  # linear, shuffled, or image_weights

        hyp = self.hyp
        mosaic = self.mosaic and random.random() < hyp['mosaic']
        if mosaic:
            # Load mosaic
            img, labels = load_mosaic(self, index)
            shapes = None

            # MixUp https://arxiv.org/pdf/1710.09412.pdf
            if random.random() < hyp['mixup']:
                img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1))
                r = np.random.beta(8.0, 8.0)  # mixup ratio, alpha=beta=8.0
                img = (img * r + img2 * (1 - r)).astype(np.uint8)
                labels = np.concatenate((labels, labels2), 0)

        else:
            # Load image
            img, (h0, w0), (h, w) = load_image(self, index)

            # Letterbox
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            if labels.size:  # normalized xywh to pixel xyxy format
                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])

        if self.augment:
            # Augment imagespace
            if not mosaic:
                img, labels = random_perspective(img, labels,
                                                 degrees=hyp['degrees'],
                                                 translate=hyp['translate'],
                                                 scale=hyp['scale'],
                                                 shear=hyp['shear'],
                                                 perspective=hyp['perspective'])

            # Augment colorspace
            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])

            # Apply cutouts
            # if random.random() < 0.9:
            #     labels = cutout(img, labels)

        nL = len(labels)  # number of labels
        if nL:
            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])  # convert xyxy to xywh
            labels[:, [2, 4]] /= img.shape[0]  # normalized height 0-1
            labels[:, [1, 3]] /= img.shape[1]  # normalized width 0-1

        if self.augment:
            # flip up-down
            if random.random() < hyp['flipud']:
                img = np.flipud(img)
                if nL:
                    labels[:, 2] = 1 - labels[:, 2]

            # flip left-right
            if random.random() < hyp['fliplr']:
                img = np.fliplr(img)
                if nL:
                    labels[:, 1] = 1 - labels[:, 1]

        labels_out = torch.zeros((nL, 6))
        if nL:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.img_files[index], shapes

    @staticmethod
    def collate_fn(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes

    @staticmethod
    def collate_fn4(batch):
        img, label, path, shapes = zip(*batch)  # transposed
        n = len(shapes) // 4
        img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]

        ho = torch.tensor([[0., 0, 0, 1, 0, 0]])
        wo = torch.tensor([[0., 0, 1, 0, 0, 0]])
        s = torch.tensor([[1, 1, .5, .5, .5, .5]])  # scale
        for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
            i *= 4
            if random.random() < 0.5:
                im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[
                    0].type(img[i].type())
                l = label[i]
            else:
                im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
                l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
            img4.append(im)
            label4.append(l)

        for i, l in enumerate(label4):
            l[:, 0] = i  # add target image index for build_targets()

        return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4

# Ancillary functions --------------------------------------------------------------------------------------------------
def load_image(self, index):
    # loads 1 image from dataset, returns img, original hw, resized hw
    img = self.imgs[index]
    if img is None:  # not cached
        path = self.img_files[index]
        img = cv2.imread(path)  # BGR
        assert img is not None, 'Image Not Found ' + path
        h0, w0 = img.shape[:2]  # orig hw
        r = self.img_size / max(h0, w0)  # resize image to img_size
        if r != 1:  # always resize down, only resize up if training with augmentation
            interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
        return img, (h0, w0), img.shape[:2]  # img, hw_original, hw_resized
    else:
        return self.imgs[index], self.img_hw0[index], self.img_hw[index]  # img, hw_original, hw_resized

def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
    r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
    hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
    dtype = img.dtype  # uint8

    x = np.arange(0, 256, dtype=np.int16)
    lut_hue = ((x * r[0]) % 180).astype(dtype)
    lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
    lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
    cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed


def hist_equalize(img, clahe=True, bgr=False):
    # Equalize histogram on BGR image 'img' with img.shape(n,m,3) and range 0-255
    yuv = cv2.cvtColor(img, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
    if clahe:
        c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        yuv[:, :, 0] = c.apply(yuv[:, :, 0])
    else:
        yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0])  # equalize Y channel histogram
    return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB)  # convert YUV image to RGB

def load_mosaic(self, index):
    # loads images in a 4-mosaic
    labels4, segments4 = [], []
    s = self.img_size
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
    indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
        labels4.append(labels)
        segments4.extend(segments)

    # Concat/clip labels
    labels4 = np.concatenate(labels4, 0)
    for x in (labels4[:, 1:], *segments4):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img4, labels4 = replicate(img4, labels4)  # replicate

    # Augment
    img4, labels4 = random_perspective(img4, labels4, segments4,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img4, labels4

def load_mosaic9(self, index):
    # loads images in a 9-mosaic
    labels9, segments9 = [], []
    s = self.img_size
    indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
    for i, index in enumerate(indices):
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img9
        if i == 0:  # center
            img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 9 tiles
            h0, w0 = h, w
            c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
        elif i == 1:  # top
            c = s, s - h, s + w, s
        elif i == 2:  # top right
            c = s + wp, s - h, s + wp + w, s
        elif i == 3:  # right
            c = s + w0, s, s + w0 + w, s + h
        elif i == 4:  # bottom right
            c = s + w0, s + hp, s + w0 + w, s + hp + h
        elif i == 5:  # bottom
            c = s + w0 - w, s + h0, s + w0, s + h0 + h
        elif i == 6:  # bottom left
            c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
        elif i == 7:  # left
            c = s - w, s + h0 - h, s, s + h0
        elif i == 8:  # top left
            c = s - w, s + h0 - hp - h, s, s + h0 - hp

        padx, pady = c[:2]
        x1, y1, x2, y2 = [max(x, 0) for x in c]  # allocate coords

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
        labels9.append(labels)
        segments9.extend(segments)

        # Image
        img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
        hp, wp = h, w  # height, width previous

    # Offset
    yc, xc = [int(random.uniform(0, s)) for _ in self.mosaic_border]  # mosaic center x, y
    img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

    # Concat/clip labels
    labels9 = np.concatenate(labels9, 0)
    labels9[:, [1, 3]] -= xc
    labels9[:, [2, 4]] -= yc
    c = np.array([xc, yc])  # centers
    segments9 = [x - c for x in segments9]

    for x in (labels9[:, 1:], *segments9):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
    # img9, labels9 = replicate(img9, labels9)  # replicate

    # Augment
    img9, labels9 = random_perspective(img9, labels9, segments9,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img9, labels9

def replicate(img, labels):
    # Replicate labels
    h, w = img.shape[:2]
    boxes = labels[:, 1:].astype(int)
    x1, y1, x2, y2 = boxes.T
    s = ((x2 - x1) + (y2 - y1)) / 2  # side length (pixels)
    for i in s.argsort()[:round(s.size * 0.5)]:  # smallest indices
        x1b, y1b, x2b, y2b = boxes[i]
        bh, bw = y2b - y1b, x2b - x1b
        yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw))  # offset x, y
        x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
        img[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img[ymin:ymax, xmin:xmax]
        labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)
    return img, labels

def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
    # Resize and pad image while meeting stride-multiple constraints
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
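
# Shape example for letterbox (illustrative): a 1080x1920 BGR frame resized to
# fit 640 with auto=True becomes 360x640, then is padded up to the nearest
# stride multiple (384) rather than to a full 640x640 square:
#   img, ratio, (dw, dh) = letterbox(frame, 640)  # img.shape -> (384, 640, 3)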

def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0,
                       border=(0, 0)):
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]

    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Center
    C = np.eye(3)
    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

    # Perspective
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined rotation matrix
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    # Visualize
    # import matplotlib.pyplot as plt
    # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
    # ax[0].imshow(img[:, :, ::-1])  # base
    # ax[1].imshow(img2[:, :, ::-1])  # warped

    # Transform label coordinates
    n = len(targets)
    if n:
        use_segments = any(x.any() for x in segments)
        new = np.zeros((n, 4))
        if use_segments:  # warp segments
            segments = resample_segments(segments)  # upsample
            for i, segment in enumerate(segments):
                xy = np.ones((len(segment), 3))
                xy[:, :2] = segment
                xy = xy @ M.T  # transform
                xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]  # perspective rescale or affine

                # clip
                new[i] = segment2box(xy, width, height)

        else:  # warp boxes
            xy = np.ones((n * 4, 3))
            xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
            xy = xy @ M.T  # transform
            xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine

            # create new boxes
            x = xy[:, [0, 2, 4, 6]]
            y = xy[:, [1, 3, 5, 7]]
            new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

            # clip
            new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
            new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)

        # filter candidates
        i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
        targets = targets[i]
        targets[:, 1:5] = new[i]

    return img, targets

def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16):  # box1(4,n), box2(4,n)
    # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
    return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)  # candidates

def cutout(image, labels):
    # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
    h, w = image.shape[:2]

    def bbox_ioa(box1, box2):
        # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
        box2 = box2.transpose()

        # Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]

        # Intersection area
        inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
                     (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)

        # box2 area
        box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16

        # Intersection over box2 area
        return inter_area / box2_area

    # create random masks
    scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16  # image size fraction
    for s in scales:
        mask_h = random.randint(1, int(h * s))
        mask_w = random.randint(1, int(w * s))

        # box
        xmin = max(0, random.randint(0, w) - mask_w // 2)
        ymin = max(0, random.randint(0, h) - mask_h // 2)
        xmax = min(w, xmin + mask_w)
        ymax = min(h, ymin + mask_h)

        # apply random color mask
        image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]

        # return unobscured labels
        if len(labels) and s > 0.03:
            box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
            ioa = bbox_ioa(box, labels[:, 1:5])  # intersection over area
            labels = labels[ioa < 0.60]  # remove >60% obscured labels

    return labels

def create_folder(path='./new'):
    # Create folder
    if os.path.exists(path):
        shutil.rmtree(path)  # delete output folder
    os.makedirs(path)  # make new output folder


def flatten_recursive(path='../coco128'):
    # Flatten a recursive directory by bringing all files to top level
    new_path = Path(path + '_flat')
    create_folder(new_path)
    for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
        shutil.copyfile(file, new_path / Path(file).name)

def extract_boxes(path='../coco128/'):  # from utils.datasets import *; extract_boxes('../coco128')
    # Convert detection dataset into classification dataset, with one directory per class
    path = Path(path)  # images dir
    shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None  # remove existing
    files = list(path.rglob('*.*'))
    n = len(files)  # number of files
    for im_file in tqdm(files, total=n):
        if im_file.suffix[1:] in img_formats:
            # image
            im = cv2.imread(str(im_file))[..., ::-1]  # BGR to RGB
            h, w = im.shape[:2]

            # labels
            lb_file = Path(img2label_paths([str(im_file)])[0])
            if Path(lb_file).exists():
                with open(lb_file, 'r') as f:
                    lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32)  # labels

                for j, x in enumerate(lb):
                    c = int(x[0])  # class
                    f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg'  # new filename
                    if not f.parent.is_dir():
                        f.parent.mkdir(parents=True)

                    b = x[1:] * [w, h, w, h]  # box
                    # b[2:] = b[2:].max()  # rectangle to square
                    b[2:] = b[2:] * 1.2 + 3  # pad
                    b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)

                    b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
                    b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
                    assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'

def autosplit(path='../coco128', weights=(0.9, 0.1, 0.0), annotated_only=False):
    """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
    Usage: from utils.datasets import *; autosplit('../coco128')
    Arguments
        path:           Path to images directory
        weights:        Train, val, test weights (list)
        annotated_only: Only use images with an annotated txt file
    """
    path = Path(path)  # images dir
    files = sum([list(path.rglob(f"*.{img_ext}")) for img_ext in img_formats], [])  # image files only
    n = len(files)  # number of files
    indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split

    txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']  # 3 txt files
    [(path / x).unlink() for x in txt if (path / x).exists()]  # remove existing

    print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
    for i, img in tqdm(zip(indices, files), total=n):
        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
            with open(path / txt[i], 'a') as f:
                f.write(str(img) + '\n')  # add image to txt file
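
# Usage sketch for autosplit (illustrative path), run from a Python shell:
#   from utils.datasets import autosplit
#   autosplit('../coco128', weights=(0.9, 0.1, 0.0))  # writes autosplit_*.txt next to the images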