You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1072 line
44KB

  1. # Dataset utils and dataloaders
  2. import glob
  3. import hashlib
  4. import logging
  5. import math
  6. import os
  7. import random
  8. import shutil
  9. import time
  10. from itertools import repeat
  11. from multiprocessing.pool import ThreadPool
  12. from pathlib import Path
  13. from threading import Thread
  14. import cv2
  15. import numpy as np
  16. import torch
  17. import torch.nn.functional as F
  18. from PIL import Image, ExifTags
  19. from torch.utils.data import Dataset
  20. from tqdm import tqdm
  21. from utils.general import check_requirements, xyxy2xywh, xywh2xyxy, xywhn2xyxy, xyn2xy, segment2box, segments2boxes, \
  22. resample_segments, clean_str
  23. from utils.torch_utils import torch_distributed_zero_first
# Parameters
help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
# File extensions are compared lower-cased against these lists (see LoadImages / LoadImagesAndLabels).
img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo']  # acceptable image suffixes
vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv']  # acceptable video suffixes
logger = logging.getLogger(__name__)

# Get orientation exif tag
# The loop variable intentionally leaks: afterwards the module-level name `orientation` holds the numeric
# EXIF tag id whose name is 'Orientation', and exif_size() reads it as a global.
# NOTE(review): if no tag were named 'Orientation', `orientation` would silently be the last key iterated.
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break
  33. def get_hash(paths):
  34. # Returns a single hash value of a list of paths (files or dirs)
  35. size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
  36. h = hashlib.md5(str(size).encode()) # hash sizes
  37. h.update(''.join(paths).encode()) # hash paths
  38. return h.hexdigest() # return hash
  39. def exif_size(img):
  40. # Returns exif-corrected PIL size
  41. s = img.size # (width, height)
  42. try:
  43. rotation = dict(img._getexif().items())[orientation]
  44. if rotation == 6: # rotation 270
  45. s = (s[1], s[0])
  46. elif rotation == 8: # rotation 90
  47. s = (s[1], s[0])
  48. except:
  49. pass
  50. return s
  51. def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
  52. rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
  53. # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
  54. with torch_distributed_zero_first(rank):
  55. dataset = LoadImagesAndLabels(path, imgsz, batch_size,
  56. augment=augment, # augment images
  57. hyp=hyp, # augmentation hyperparameters
  58. rect=rect, # rectangular training
  59. cache_images=cache,
  60. single_cls=opt.single_cls,
  61. stride=int(stride),
  62. pad=pad,
  63. image_weights=image_weights,
  64. prefix=prefix)
  65. batch_size = min(batch_size, len(dataset))
  66. nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
  67. sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
  68. loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
  69. # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
  70. dataloader = loader(dataset,
  71. batch_size=batch_size,
  72. num_workers=nw,
  73. sampler=sampler,
  74. pin_memory=True,
  75. collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
  76. return dataloader, dataset
  77. class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
  78. """ Dataloader that reuses workers
  79. Uses same syntax as vanilla DataLoader
  80. """
  81. def __init__(self, *args, **kwargs):
  82. super().__init__(*args, **kwargs)
  83. object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
  84. self.iterator = super().__iter__()
  85. def __len__(self):
  86. return len(self.batch_sampler.sampler)
  87. def __iter__(self):
  88. for i in range(len(self)):
  89. yield next(self.iterator)
  90. class _RepeatSampler(object):
  91. """ Sampler that repeats forever
  92. Args:
  93. sampler (Sampler)
  94. """
  95. def __init__(self, sampler):
  96. self.sampler = sampler
  97. def __iter__(self):
  98. while True:
  99. yield from iter(self.sampler)
class LoadImages:  # for inference
    """Iterate over image files and then video frames for inference.

    ``path`` may be a glob pattern (contains ``*``), a directory (its ``*.*`` entries are used) or a
    single file. Files are split by lower-cased extension into images (``img_formats``) and videos
    (``vid_formats``); all images are yielded first, then every frame of each video in order.

    Each iteration returns ``(path, img, img0, cap)``:
        path: the current file path
        img:  ``img0`` letterboxed with ``img_size``/``stride``, converted BGR->RGB and HWC->CHW, contiguous
        img0: the frame as read by OpenCV (BGR)
        cap:  ``self.cap`` -- None when no videos were found; otherwise the active ``cv2.VideoCapture``
              (the first video's capture is also what is returned while images are being yielded)

    Raises:
        Exception: if ``path`` does not exist.
        AssertionError: if no supported files were found, or an image cannot be read.
    """

    def __init__(self, path, img_size=640, stride=32):
        p = str(Path(path).absolute())  # os-agnostic absolute path
        if '*' in p:
            files = sorted(glob.glob(p, recursive=True))  # glob
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
        elif os.path.isfile(p):
            files = [p]  # files
        else:
            raise Exception(f'ERROR: {p} does not exist')

        images = [x for x in files if x.split('.')[-1].lower() in img_formats]
        videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos  # images first, then videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv  # parallel to self.files: True for video entries
        self.mode = 'image'
        if any(videos):
            self.new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' \
                            f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}'

    def __iter__(self):
        self.count = 0  # index into self.files
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            if not ret_val:
                # Current video exhausted: move on to the next file, which is also a video because
                # videos follow all images in self.files.
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()
                    # NOTE(review): if this first read of the next video fails too, img0 is None and
                    # letterbox() below will fail.

            self.frame += 1
            print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: ', end='')

        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path
            print(f'image {self.count}/{self.nf} {path}: ', end='')

        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        return path, img, img0, self.cap

    def new_video(self, path):
        # Open `path` as the active capture and reset the per-video frame counter.
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))  # frame count reported by OpenCV

    def __len__(self):
        return self.nf  # number of files
  166. class LoadWebcam: # for inference
  167. def __init__(self, pipe='0', img_size=640, stride=32):
  168. self.img_size = img_size
  169. self.stride = stride
  170. if pipe.isnumeric():
  171. pipe = eval(pipe) # local camera
  172. # pipe = 'rtsp://192.168.1.64/1' # IP camera
  173. # pipe = 'rtsp://username:password@192.168.1.64/1' # IP camera with login
  174. # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg' # IP golf camera
  175. self.pipe = pipe
  176. self.cap = cv2.VideoCapture(pipe) # video capture object
  177. self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3) # set buffer size
  178. def __iter__(self):
  179. self.count = -1
  180. return self
  181. def __next__(self):
  182. self.count += 1
  183. if cv2.waitKey(1) == ord('q'): # q to quit
  184. self.cap.release()
  185. cv2.destroyAllWindows()
  186. raise StopIteration
  187. # Read frame
  188. if self.pipe == 0: # local camera
  189. ret_val, img0 = self.cap.read()
  190. img0 = cv2.flip(img0, 1) # flip left-right
  191. else: # IP camera
  192. n = 0
  193. while True:
  194. n += 1
  195. self.cap.grab()
  196. if n % 30 == 0: # skip frames
  197. ret_val, img0 = self.cap.retrieve()
  198. if ret_val:
  199. break
  200. # Print
  201. assert ret_val, f'Camera Error {self.pipe}'
  202. img_path = 'webcam.jpg'
  203. print(f'webcam {self.count}: ', end='')
  204. # Padded resize
  205. img = letterbox(img0, self.img_size, stride=self.stride)[0]
  206. # Convert
  207. img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  208. img = np.ascontiguousarray(img)
  209. return img_path, img, img0, None
  210. def __len__(self):
  211. return 0
  212. class LoadStreams: # multiple IP or RTSP cameras
  213. def __init__(self, sources='streams.txt', img_size=640, stride=32):
  214. self.mode = 'stream'
  215. self.img_size = img_size
  216. self.stride = stride
  217. if os.path.isfile(sources):
  218. with open(sources, 'r') as f:
  219. sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
  220. else:
  221. sources = [sources]
  222. n = len(sources)
  223. self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
  224. self.sources = [clean_str(x) for x in sources] # clean source names for later
  225. for i, s in enumerate(sources): # index, source
  226. # Start thread to read frames from video stream
  227. print(f'{i + 1}/{n}: {s}... ', end='')
  228. if 'youtube.com/' in s or 'youtu.be/' in s: # if source is YouTube video
  229. check_requirements(('pafy', 'youtube_dl'))
  230. import pafy
  231. s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
  232. s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
  233. cap = cv2.VideoCapture(s)
  234. assert cap.isOpened(), f'Failed to open {s}'
  235. w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  236. h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  237. self.fps[i] = max(cap.get(cv2.CAP_PROP_FPS) % 100, 0) or 30.0 # 30 FPS fallback
  238. self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
  239. _, self.imgs[i] = cap.read() # guarantee first frame
  240. self.threads[i] = Thread(target=self.update, args=([i, cap]), daemon=True)
  241. print(f" success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
  242. self.threads[i].start()
  243. print('') # newline
  244. # check for common shapes
  245. s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0) # shapes
  246. self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
  247. if not self.rect:
  248. print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')
  249. def update(self, i, cap):
  250. # Read stream `i` frames in daemon thread
  251. n, f = 0, self.frames[i]
  252. while cap.isOpened() and n < f:
  253. n += 1
  254. # _, self.imgs[index] = cap.read()
  255. cap.grab()
  256. if n % 4: # read every 4th frame
  257. success, im = cap.retrieve()
  258. self.imgs[i] = im if success else self.imgs[i] * 0
  259. time.sleep(1 / self.fps[i]) # wait time
  260. def __iter__(self):
  261. self.count = -1
  262. return self
  263. def __next__(self):
  264. self.count += 1
  265. if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
  266. cv2.destroyAllWindows()
  267. raise StopIteration
  268. # Letterbox
  269. img0 = self.imgs.copy()
  270. img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]
  271. # Stack
  272. img = np.stack(img, 0)
  273. # Convert
  274. img = img[:, :, :, ::-1].transpose(0, 3, 1, 2) # BGR to RGB, to bsx3x416x416
  275. img = np.ascontiguousarray(img)
  276. return self.sources, img, img0, None
  277. def __len__(self):
  278. return 0 # 1E12 frames = 32 streams at 30 FPS for 30 years
  279. def img2label_paths(img_paths):
  280. # Define label paths as a function of image paths
  281. sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep # /images/, /labels/ substrings
  282. return ['txt'.join(x.replace(sa, sb, 1).rsplit(x.split('.')[-1], 1)) for x in img_paths]
  283. class LoadImagesAndLabels(Dataset): # for training/testing
  284. def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
  285. cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
  286. self.img_size = img_size
  287. self.augment = augment
  288. self.hyp = hyp
  289. self.image_weights = image_weights
  290. self.rect = False if image_weights else rect
  291. self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
  292. self.mosaic_border = [-img_size // 2, -img_size // 2]
  293. self.stride = stride
  294. self.path = path
  295. try:
  296. f = [] # image files
  297. for p in path if isinstance(path, list) else [path]:
  298. p = Path(p) # os-agnostic
  299. if p.is_dir(): # dir
  300. f += glob.glob(str(p / '**' / '*.*'), recursive=True)
  301. # f = list(p.rglob('**/*.*')) # pathlib
  302. elif p.is_file(): # file
  303. with open(p, 'r') as t:
  304. t = t.read().strip().splitlines()
  305. parent = str(p.parent) + os.sep
  306. f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path
  307. # f += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib)
  308. else:
  309. raise Exception(f'{prefix}{p} does not exist')
  310. self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
  311. # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in img_formats]) # pathlib
  312. assert self.img_files, f'{prefix}No images found'
  313. except Exception as e:
  314. raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}')
  315. # Check cache
  316. self.label_files = img2label_paths(self.img_files) # labels
  317. cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache') # cached labels
  318. if cache_path.is_file():
  319. cache, exists = torch.load(cache_path), True # load
  320. if cache['hash'] != get_hash(self.label_files + self.img_files): # changed
  321. cache, exists = self.cache_labels(cache_path, prefix), False # re-cache
  322. else:
  323. cache, exists = self.cache_labels(cache_path, prefix), False # cache
  324. # Display cache
  325. nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupted, total
  326. if exists:
  327. d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted"
  328. tqdm(None, desc=prefix + d, total=n, initial=n) # display cache results
  329. assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}'
  330. # Read cache
  331. cache.pop('hash') # remove hash
  332. cache.pop('version') # remove version
  333. labels, shapes, self.segments = zip(*cache.values())
  334. self.labels = list(labels)
  335. self.shapes = np.array(shapes, dtype=np.float64)
  336. self.img_files = list(cache.keys()) # update
  337. self.label_files = img2label_paths(cache.keys()) # update
  338. if single_cls:
  339. for x in self.labels:
  340. x[:, 0] = 0
  341. n = len(shapes) # number of images
  342. bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
  343. nb = bi[-1] + 1 # number of batches
  344. self.batch = bi # batch index of image
  345. self.n = n
  346. self.indices = range(n)
  347. # Rectangular Training
  348. if self.rect:
  349. # Sort by aspect ratio
  350. s = self.shapes # wh
  351. ar = s[:, 1] / s[:, 0] # aspect ratio
  352. irect = ar.argsort()
  353. self.img_files = [self.img_files[i] for i in irect]
  354. self.label_files = [self.label_files[i] for i in irect]
  355. self.labels = [self.labels[i] for i in irect]
  356. self.shapes = s[irect] # wh
  357. ar = ar[irect]
  358. # Set training image shapes
  359. shapes = [[1, 1]] * nb
  360. for i in range(nb):
  361. ari = ar[bi == i]
  362. mini, maxi = ari.min(), ari.max()
  363. if maxi < 1:
  364. shapes[i] = [maxi, 1]
  365. elif mini > 1:
  366. shapes[i] = [1, 1 / mini]
  367. self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
  368. # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
  369. self.imgs = [None] * n
  370. if cache_images:
  371. gb = 0 # Gigabytes of cached images
  372. self.img_hw0, self.img_hw = [None] * n, [None] * n
  373. results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n))) # 8 threads
  374. pbar = tqdm(enumerate(results), total=n)
  375. for i, x in pbar:
  376. self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # img, hw_original, hw_resized = load_image(self, i)
  377. gb += self.imgs[i].nbytes
  378. pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
  379. pbar.close()
  380. def cache_labels(self, path=Path('./labels.cache'), prefix=''):
  381. # Cache dataset labels, check images and read shapes
  382. x = {} # dict
  383. nm, nf, ne, nc = 0, 0, 0, 0 # number missing, found, empty, duplicate
  384. pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
  385. for i, (im_file, lb_file) in enumerate(pbar):
  386. try:
  387. # verify images
  388. im = Image.open(im_file)
  389. im.verify() # PIL verify
  390. shape = exif_size(im) # image size
  391. segments = [] # instance segments
  392. assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
  393. assert im.format.lower() in img_formats, f'invalid image format {im.format}'
  394. # verify labels
  395. if os.path.isfile(lb_file):
  396. nf += 1 # label found
  397. with open(lb_file, 'r') as f:
  398. l = [x.split() for x in f.read().strip().splitlines() if len(x)]
  399. if any([len(x) > 8 for x in l]): # is segment
  400. classes = np.array([x[0] for x in l], dtype=np.float32)
  401. segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l] # (cls, xy1...)
  402. l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
  403. l = np.array(l, dtype=np.float32)
  404. if len(l):
  405. assert l.shape[1] == 5, 'labels require 5 columns each'
  406. assert (l >= 0).all(), 'negative labels'
  407. assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
  408. assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
  409. else:
  410. ne += 1 # label empty
  411. l = np.zeros((0, 5), dtype=np.float32)
  412. else:
  413. nm += 1 # label missing
  414. l = np.zeros((0, 5), dtype=np.float32)
  415. x[im_file] = [l, shape, segments]
  416. except Exception as e:
  417. nc += 1
  418. logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
  419. pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
  420. f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
  421. pbar.close()
  422. if nf == 0:
  423. logging.info(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
  424. x['hash'] = get_hash(self.label_files + self.img_files)
  425. x['results'] = nf, nm, ne, nc, i + 1
  426. x['version'] = 0.2 # cache version
  427. try:
  428. torch.save(x, path) # save cache for next time
  429. logging.info(f'{prefix}New cache created: {path}')
  430. except Exception as e:
  431. logging.info(f'{prefix}WARNING: Cache directory {path.parent} is not writeable: {e}') # path not writeable
  432. return x
  433. def __len__(self):
  434. return len(self.img_files)
  435. # def __iter__(self):
  436. # self.count = -1
  437. # print('ran dataset iter')
  438. # #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
  439. # return self
  440. def __getitem__(self, index):
  441. index = self.indices[index] # linear, shuffled, or image_weights
  442. hyp = self.hyp
  443. mosaic = self.mosaic and random.random() < hyp['mosaic']
  444. if mosaic:
  445. # Load mosaic
  446. img, labels = load_mosaic(self, index)
  447. shapes = None
  448. # MixUp https://arxiv.org/pdf/1710.09412.pdf
  449. if random.random() < hyp['mixup']:
  450. img2, labels2 = load_mosaic(self, random.randint(0, self.n - 1))
  451. r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
  452. img = (img * r + img2 * (1 - r)).astype(np.uint8)
  453. labels = np.concatenate((labels, labels2), 0)
  454. else:
  455. # Load image
  456. img, (h0, w0), (h, w) = load_image(self, index)
  457. # Letterbox
  458. shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size # final letterboxed shape
  459. img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
  460. shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
  461. labels = self.labels[index].copy()
  462. if labels.size: # normalized xywh to pixel xyxy format
  463. labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
  464. if self.augment:
  465. # Augment imagespace
  466. if not mosaic:
  467. img, labels = random_perspective(img, labels,
  468. degrees=hyp['degrees'],
  469. translate=hyp['translate'],
  470. scale=hyp['scale'],
  471. shear=hyp['shear'],
  472. perspective=hyp['perspective'])
  473. # Augment colorspace
  474. augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
  475. # Apply cutouts
  476. # if random.random() < 0.9:
  477. # labels = cutout(img, labels)
  478. nL = len(labels) # number of labels
  479. if nL:
  480. labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) # convert xyxy to xywh
  481. labels[:, [2, 4]] /= img.shape[0] # normalized height 0-1
  482. labels[:, [1, 3]] /= img.shape[1] # normalized width 0-1
  483. if self.augment:
  484. # flip up-down
  485. if random.random() < hyp['flipud']:
  486. img = np.flipud(img)
  487. if nL:
  488. labels[:, 2] = 1 - labels[:, 2]
  489. # flip left-right
  490. if random.random() < hyp['fliplr']:
  491. img = np.fliplr(img)
  492. if nL:
  493. labels[:, 1] = 1 - labels[:, 1]
  494. labels_out = torch.zeros((nL, 6))
  495. if nL:
  496. labels_out[:, 1:] = torch.from_numpy(labels)
  497. # Convert
  498. img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  499. img = np.ascontiguousarray(img)
  500. return torch.from_numpy(img), labels_out, self.img_files[index], shapes
  501. @staticmethod
  502. def collate_fn(batch):
  503. img, label, path, shapes = zip(*batch) # transposed
  504. for i, l in enumerate(label):
  505. l[:, 0] = i # add target image index for build_targets()
  506. return torch.stack(img, 0), torch.cat(label, 0), path, shapes
  507. @staticmethod
  508. def collate_fn4(batch):
  509. img, label, path, shapes = zip(*batch) # transposed
  510. n = len(shapes) // 4
  511. img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
  512. ho = torch.tensor([[0., 0, 0, 1, 0, 0]])
  513. wo = torch.tensor([[0., 0, 1, 0, 0, 0]])
  514. s = torch.tensor([[1, 1, .5, .5, .5, .5]]) # scale
  515. for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
  516. i *= 4
  517. if random.random() < 0.5:
  518. im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[
  519. 0].type(img[i].type())
  520. l = label[i]
  521. else:
  522. im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
  523. l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
  524. img4.append(im)
  525. label4.append(l)
  526. for i, l in enumerate(label4):
  527. l[:, 0] = i # add target image index for build_targets()
  528. return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4
  529. # Ancillary functions --------------------------------------------------------------------------------------------------
  530. def load_image(self, index):
  531. # loads 1 image from dataset, returns img, original hw, resized hw
  532. img = self.imgs[index]
  533. if img is None: # not cached
  534. path = self.img_files[index]
  535. img = cv2.imread(path) # BGR
  536. assert img is not None, 'Image Not Found ' + path
  537. h0, w0 = img.shape[:2] # orig hw
  538. r = self.img_size / max(h0, w0) # ratio
  539. if r != 1: # if sizes are not equal
  540. img = cv2.resize(img, (int(w0 * r), int(h0 * r)),
  541. interpolation=cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR)
  542. return img, (h0, w0), img.shape[:2] # img, hw_original, hw_resized
  543. else:
  544. return self.imgs[index], self.img_hw0[index], self.img_hw[index] # img, hw_original, hw_resized
  545. def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
  546. r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains
  547. hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
  548. dtype = img.dtype # uint8
  549. x = np.arange(0, 256, dtype=r.dtype)
  550. lut_hue = ((x * r[0]) % 180).astype(dtype)
  551. lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
  552. lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
  553. img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
  554. cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed
  555. def hist_equalize(img, clahe=True, bgr=False):
  556. # Equalize histogram on BGR image 'img' with img.shape(n,m,3) and range 0-255
  557. yuv = cv2.cvtColor(img, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
  558. if clahe:
  559. c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
  560. yuv[:, :, 0] = c.apply(yuv[:, :, 0])
  561. else:
  562. yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0]) # equalize Y channel histogram
  563. return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB) # convert YUV image to RGB
def load_mosaic(self, index):
    """Build a 2x2 mosaic training sample around dataset item *index*.

    Item *index* plus 3 items drawn with replacement from ``self.indices`` are placed into the four
    quadrants (top-left, top-right, bottom-left, bottom-right) around a random centre of a
    ``2*img_size`` square canvas pre-filled with 114. Their labels are converted from normalized xywh
    to pixel xyxy in canvas coordinates and clipped to the canvas, then ``random_perspective`` is
    applied with ``self.mosaic_border`` as the border to remove.

    Returns:
        ``(img4, labels4)`` as returned by ``random_perspective``.
    """
    # loads images in a 4-mosaic

    labels4, segments4 = [], []
    s = self.img_size
    # Centre drawn from uniform(-b, 2*s + b) for each border value b in self.mosaic_border
    yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border]  # mosaic center x, y
    indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
    for i, index in enumerate(indices):  # note: rebinds `index` to the current tile's item
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img4
        # (x1a..y2a) is the destination rectangle on the canvas, (x1b..y2b) the matching crop of the tile
        if i == 0:  # top left
            img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
            x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
        elif i == 1:  # top right
            x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
            x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
        elif i == 2:  # bottom left
            x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
        elif i == 3:  # bottom right
            x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
            x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

        img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
        # Offset from tile coordinates to canvas coordinates
        padw = x1a - x1b
        padh = y1a - y1b

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
        labels4.append(labels)
        segments4.extend(segments)

    # Concat/clip labels
    labels4 = np.concatenate(labels4, 0)
    for x in (labels4[:, 1:], *segments4):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()  (in place on each array/view)
    # img4, labels4 = replicate(img4, labels4)  # replicate

    # Augment
    img4, labels4 = random_perspective(img4, labels4, segments4,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img4, labels4
def load_mosaic9(self, index):
    """Build a 3x3 mosaic training sample around dataset item *index*.

    Item *index* is placed at the centre of a ``3*img_size`` canvas pre-filled with 114 and 8 items
    drawn with replacement from ``self.indices`` are laid out clockwise around it (top, top right,
    right, bottom right, bottom, bottom left, left, top left). A random ``2*img_size`` window is then
    cropped, labels/segments are shifted into that window and clipped, and ``random_perspective`` is
    applied with ``self.mosaic_border`` as the border to remove.

    Returns:
        ``(img9, labels9)`` as returned by ``random_perspective``.
    """
    # loads images in a 9-mosaic

    labels9, segments9 = [], []
    s = self.img_size
    indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
    for i, index in enumerate(indices):  # note: rebinds `index` to the current tile's item
        # Load image
        img, _, (h, w) = load_image(self, index)

        # place img in img9
        # c = (xmin, ymin, xmax, ymax) of this tile on the canvas. h0/w0 are the centre tile's size and
        # hp/wp the previous tile's size (assigned at the end of each iteration, so set before i == 2).
        if i == 0:  # center
            img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 9 tiles
            h0, w0 = h, w
            c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
        elif i == 1:  # top
            c = s, s - h, s + w, s
        elif i == 2:  # top right
            c = s + wp, s - h, s + wp + w, s
        elif i == 3:  # right
            c = s + w0, s, s + w0 + w, s + h
        elif i == 4:  # bottom right
            c = s + w0, s + hp, s + w0 + w, s + hp + h
        elif i == 5:  # bottom
            c = s + w0 - w, s + h0, s + w0, s + h0 + h
        elif i == 6:  # bottom left
            c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
        elif i == 7:  # left
            c = s - w, s + h0 - h, s, s + h0
        elif i == 8:  # top left
            c = s - w, s + h0 - hp - h, s, s + h0 - hp

        padx, pady = c[:2]  # tile-to-canvas offset (may be negative)
        x1, y1, x2, y2 = [max(x, 0) for x in c]  # allocate coords

        # Labels
        labels, segments = self.labels[index].copy(), self.segments[index].copy()
        if labels.size:
            labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
            segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
        labels9.append(labels)
        segments9.extend(segments)

        # Image
        img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
        hp, wp = h, w  # height, width previous

    # Offset: crop a random 2s x 2s window out of the 3s x 3s canvas
    yc, xc = [int(random.uniform(0, s)) for _ in self.mosaic_border]  # mosaic center x, y
    img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]

    # Concat/clip labels
    labels9 = np.concatenate(labels9, 0)
    labels9[:, [1, 3]] -= xc
    labels9[:, [2, 4]] -= yc
    c = np.array([xc, yc])  # centers
    segments9 = [x - c for x in segments9]

    for x in (labels9[:, 1:], *segments9):
        np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()  (in place on each array/view)
    # img9, labels9 = replicate(img9, labels9)  # replicate

    # Augment
    img9, labels9 = random_perspective(img9, labels9, segments9,
                                       degrees=self.hyp['degrees'],
                                       translate=self.hyp['translate'],
                                       scale=self.hyp['scale'],
                                       shear=self.hyp['shear'],
                                       perspective=self.hyp['perspective'],
                                       border=self.mosaic_border)  # border to remove

    return img9, labels9
  673. def replicate(img, labels):
  674. # Replicate labels
  675. h, w = img.shape[:2]
  676. boxes = labels[:, 1:].astype(int)
  677. x1, y1, x2, y2 = boxes.T
  678. s = ((x2 - x1) + (y2 - y1)) / 2 # side length (pixels)
  679. for i in s.argsort()[:round(s.size * 0.5)]: # smallest indices
  680. x1b, y1b, x2b, y2b = boxes[i]
  681. bh, bw = y2b - y1b, x2b - x1b
  682. yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw)) # offset x, y
  683. x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
  684. img[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
  685. labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)
  686. return img, labels
  687. def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
  688. # Resize and pad image while meeting stride-multiple constraints
  689. shape = img.shape[:2] # current shape [height, width]
  690. if isinstance(new_shape, int):
  691. new_shape = (new_shape, new_shape)
  692. # Scale ratio (new / old)
  693. r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
  694. if not scaleup: # only scale down, do not scale up (for better test mAP)
  695. r = min(r, 1.0)
  696. # Compute padding
  697. ratio = r, r # width, height ratios
  698. new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
  699. dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
  700. if auto: # minimum rectangle
  701. dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
  702. elif scaleFill: # stretch
  703. dw, dh = 0.0, 0.0
  704. new_unpad = (new_shape[1], new_shape[0])
  705. ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
  706. dw /= 2 # divide padding into 2 sides
  707. dh /= 2
  708. if shape[::-1] != new_unpad: # resize
  709. img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
  710. top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
  711. left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
  712. img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
  713. return img, ratio, (dw, dh)
def random_perspective(img, targets=(), segments=(), degrees=10, translate=.1, scale=.1, shear=10, perspective=0.0,
                       border=(0, 0)):
    """Apply a random perspective/affine warp to ``img`` and transform its labels to match.

    The warp is M = T @ S @ R @ P @ C:
      C  moves the image centre to the origin,
      P  random perspective terms in [-perspective, perspective],
      R  random rotation in [-degrees, degrees] with scale in [1 - scale, 1 + scale],
      S  random shear in [-shear, shear] degrees on each axis,
      T  moves the centre to (0.5 +/- translate) of the output width/height.
    Random values are drawn from the ``random`` module in exactly that order (P, R, S, T),
    so changing the statement order changes seeded results.

    Args:
        img: image array whose first two dimensions are (height, width).
        targets: label array with rows [cls, x1, y1, x2, y2] in pixels.
        segments: per-target polygons, each an array of (x, y) points; when any segment
            contains a non-zero value the new boxes are derived from the warped polygons
            instead of the warped box corners.
        degrees, translate, scale, shear, perspective: augmentation magnitudes (see above).
        border: (y, x) pixels added to each side of the output; output size is
            (img height + 2 * border[0], img width + 2 * border[1]), so negative
            values shrink the output.

    Returns:
        (img, targets): the warped image (grey 114 fill outside the source) and the
        targets that survive ``box_candidates`` with their boxes replaced by the warped,
        clipped ones. Targets are returned unchanged when empty.
    """
    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-10, 10))
    # targets = [cls, xyxy]

    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # Center: shift the image centre to the origin so P/R/S act about the centre
    C = np.eye(3)
    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

    # Perspective
    P = np.eye(3)
    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-degrees, degrees)
    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
    s = random.uniform(1 - scale, 1 + scale)
    # s = 2 ** random.uniform(-scale, scale)
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)

    # Translation
    T = np.eye(3)
    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)

    # Combined transformation matrix
    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
        else:  # affine
            img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))

    # Transform label coordinates
    n = len(targets)
    if n:
        use_segments = any(x.any() for x in segments)
        new = np.zeros((n, 4))
        if use_segments:  # warp segments
            segments = resample_segments(segments)  # upsample
            for i, segment in enumerate(segments):
                # homogeneous coordinates (x, y, 1) so M can be applied as one matmul
                xy = np.ones((len(segment), 3))
                xy[:, :2] = segment
                xy = xy @ M.T  # transform
                xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]  # perspective rescale or affine

                # clip
                new[i] = segment2box(xy, width, height)

        else:  # warp boxes
            # warp all four corners of each box, then take the axis-aligned hull
            xy = np.ones((n * 4, 3))
            xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
            xy = xy @ M.T  # transform
            xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine

            # create new boxes
            x = xy[:, [0, 2, 4, 6]]
            y = xy[:, [1, 3, 5, 7]]
            new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

            # clip
            new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
            new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)

        # filter candidates; pre-warp boxes are multiplied by the random scale s so the
        # area-ratio test compares against the intended zoom rather than the raw box
        i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
        targets = targets[i]
        targets[:, 1:5] = new[i]

    return img, targets
  786. def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
  787. # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
  788. w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
  789. w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
  790. ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
  791. return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
  792. def cutout(image, labels):
  793. # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
  794. h, w = image.shape[:2]
  795. def bbox_ioa(box1, box2):
  796. # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2
  797. box2 = box2.transpose()
  798. # Get the coordinates of bounding boxes
  799. b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
  800. b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
  801. # Intersection area
  802. inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
  803. (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)
  804. # box2 area
  805. box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16
  806. # Intersection over box2 area
  807. return inter_area / box2_area
  808. # create random masks
  809. scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction
  810. for s in scales:
  811. mask_h = random.randint(1, int(h * s))
  812. mask_w = random.randint(1, int(w * s))
  813. # box
  814. xmin = max(0, random.randint(0, w) - mask_w // 2)
  815. ymin = max(0, random.randint(0, h) - mask_h // 2)
  816. xmax = min(w, xmin + mask_w)
  817. ymax = min(h, ymin + mask_h)
  818. # apply random color mask
  819. image[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]
  820. # return unobscured labels
  821. if len(labels) and s > 0.03:
  822. box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32)
  823. ioa = bbox_ioa(box, labels[:, 1:5]) # intersection over area
  824. labels = labels[ioa < 0.60] # remove >60% obscured labels
  825. return labels
  826. def create_folder(path='./new'):
  827. # Create folder
  828. if os.path.exists(path):
  829. shutil.rmtree(path) # delete output folder
  830. os.makedirs(path) # make new output folder
  831. def flatten_recursive(path='../coco128'):
  832. # Flatten a recursive directory by bringing all files to top level
  833. new_path = Path(path + '_flat')
  834. create_folder(new_path)
  835. for file in tqdm(glob.glob(str(Path(path)) + '/**/*.*', recursive=True)):
  836. shutil.copyfile(file, new_path / Path(file).name)
  837. def extract_boxes(path='../coco128/'): # from utils.datasets import *; extract_boxes('../coco128')
  838. # Convert detection dataset into classification dataset, with one directory per class
  839. path = Path(path) # images dir
  840. shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None # remove existing
  841. files = list(path.rglob('*.*'))
  842. n = len(files) # number of files
  843. for im_file in tqdm(files, total=n):
  844. if im_file.suffix[1:] in img_formats:
  845. # image
  846. im = cv2.imread(str(im_file))[..., ::-1] # BGR to RGB
  847. h, w = im.shape[:2]
  848. # labels
  849. lb_file = Path(img2label_paths([str(im_file)])[0])
  850. if Path(lb_file).exists():
  851. with open(lb_file, 'r') as f:
  852. lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32) # labels
  853. for j, x in enumerate(lb):
  854. c = int(x[0]) # class
  855. f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg' # new filename
  856. if not f.parent.is_dir():
  857. f.parent.mkdir(parents=True)
  858. b = x[1:] * [w, h, w, h] # box
  859. # b[2:] = b[2:].max() # rectangle to square
  860. b[2:] = b[2:] * 1.2 + 3 # pad
  861. b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(np.int)
  862. b[[0, 2]] = np.clip(b[[0, 2]], 0, w) # clip boxes outside of image
  863. b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
  864. assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'
  865. def autosplit(path='../coco128', weights=(0.9, 0.1, 0.0), annotated_only=False):
  866. """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
  867. Usage: from utils.datasets import *; autosplit('../coco128')
  868. Arguments
  869. path: Path to images directory
  870. weights: Train, val, test weights (list)
  871. annotated_only: Only use images with an annotated txt file
  872. """
  873. path = Path(path) # images dir
  874. files = sum([list(path.rglob(f"*.{img_ext}")) for img_ext in img_formats], []) # image files only
  875. n = len(files) # number of files
  876. indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
  877. txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
  878. [(path / x).unlink() for x in txt if (path / x).exists()] # remove existing
  879. print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
  880. for i, img in tqdm(zip(indices, files), total=n):
  881. if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
  882. with open(path / txt[i], 'a') as f:
  883. f.write(str(img) + '\n') # add image to txt file