# General utils

import glob
import logging
import math
import os
import random
import re
import subprocess
import time
from pathlib import Path

import cv2
import numpy as np
import torch
import torchvision
import yaml

from utils.google_utils import gsutil_getsize
from utils.metrics import fitness
from utils.torch_utils import init_torch_seeds

# Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads


def set_logging(rank=-1):
    logging.basicConfig(
        format="%(message)s",
        level=logging.INFO if rank in [-1, 0] else logging.WARN)


def init_seeds(seed=0):
    # Initialize random number generator (RNG) seeds
    random.seed(seed)
    np.random.seed(seed)
    init_torch_seeds(seed)


def get_latest_run(search_dir='.'):
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''
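
# Usage sketch (illustrative, not part of the original file): resume logic can call
# this to locate the newest checkpoint; the returned path is an example only.
#   >>> ckpt = get_latest_run('runs')  # e.g. 'runs/exp3/weights/last.pt', or '' if none found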


def check_online():
    # Check internet connectivity
    import socket
    try:
        socket.create_connection(("1.1.1.1", 53))  # check host accessibility
        return True
    except OSError:
        return False


def check_git_status():
    # Recommend 'git pull' if code is out of date
    print(colorstr('github: '), end='')
    try:
        if Path('.git').exists() and check_online():
            url = subprocess.check_output(
                'git fetch && git config --get remote.origin.url', shell=True).decode('utf-8')[:-1]
            n = int(subprocess.check_output(
                'git rev-list $(git rev-parse --abbrev-ref HEAD)..origin/master --count', shell=True))  # commits behind
            if n > 0:
                print(f"⚠️ WARNING: code is out of date by {n} {'commits' if n > 1 else 'commit'}. "
                      f"Use 'git pull' to update or 'git clone {url}' to download latest.")
            else:
                print(f'up to date with {url} ✅')
    except Exception as e:
        print(e)


def check_requirements(file='requirements.txt'):
    # Check installed dependencies meet requirements
    import pkg_resources
    requirements = pkg_resources.parse_requirements(Path(file).open())
    requirements = [x.name + ''.join(*x.specs) if len(x.specs) else x.name for x in requirements]
    pkg_resources.require(requirements)  # DistributionNotFound or VersionConflict exception if requirements not met


def check_img_size(img_size, s=32):
    # Verify img_size is a multiple of stride s
    new_size = make_divisible(img_size, int(s))  # ceil gs-multiple
    if new_size != img_size:
        print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
    return new_size
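
# Illustrative example (not in the original file): a non-multiple size is rounded up
# so the detection grids tile the image exactly.
#   >>> check_img_size(641, s=32)  # prints a WARNING, then returns the next stride multiple
#   672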


def check_file(file):
    # Search for file if not found
    if os.path.isfile(file) or file == '':
        return file
    else:
        files = glob.glob('./**/' + file, recursive=True)  # find file
        assert len(files), 'File Not Found: %s' % file  # assert file was found
        assert len(files) == 1, "Multiple files match '%s', specify exact path: %s" % (file, files)  # assert unique
        return files[0]  # return file


def check_dataset(dict):
    # Download dataset if not found locally
    val, s = dict.get('val'), dict.get('download')
    if val and len(val):
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and len(s):  # download script
                print('Downloading %s ...' % s)
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    torch.hub.download_url_to_file(s, f)
                    r = os.system('unzip -q %s -d ../ && rm %s' % (f, f))  # unzip
                else:  # bash script
                    r = os.system(s)
                print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure'))  # analyze return value
            else:
                raise Exception('Dataset not found.')


def make_divisible(x, divisor):
    # Returns x evenly divisible by divisor
    return math.ceil(x / divisor) * divisor
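
# Illustrative example (not in the original file): rounding up to the nearest multiple.
#   >>> make_divisible(100, 32)
#   128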


def clean_str(s):
    # Cleans a string by replacing special characters with underscore _
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)


def one_cycle(y1=0.0, y2=1.0, steps=100):
    # lambda function for sinusoidal ramp from y1 to y2
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
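
# Usage sketch (illustrative, not part of the original file): a cosine learning-rate
# schedule via torch.optim.lr_scheduler.LambdaLR, ramping the LR multiplier from 1.0
# down to a final fraction of 0.1 over 100 steps; `optimizer` is assumed defined.
#   fn = one_cycle(1, 0.1, 100)   # fn(0) == 1.0, fn(50) == 0.55, fn(100) == 0.1
#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=fn)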


def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {'black': '\033[30m',  # basic colors
              'red': '\033[31m',
              'green': '\033[32m',
              'yellow': '\033[33m',
              'blue': '\033[34m',
              'magenta': '\033[35m',
              'cyan': '\033[36m',
              'white': '\033[37m',
              'bright_black': '\033[90m',  # bright colors
              'bright_red': '\033[91m',
              'bright_green': '\033[92m',
              'bright_yellow': '\033[93m',
              'bright_blue': '\033[94m',
              'bright_magenta': '\033[95m',
              'bright_cyan': '\033[96m',
              'bright_white': '\033[97m',
              'end': '\033[0m',  # misc
              'bold': '\033[1m',
              'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
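
# Illustrative example (not in the original file): any number of style arguments may
# precede the string; a bare string defaults to bold blue.
#   >>> colorstr('red', 'underline', 'warning')  # '\033[31m\033[4mwarning\033[0m'
#   >>> colorstr('hello')                        # bold blue 'hello'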


def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(int)  # labels = [class xywh]
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # number of targets per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)


def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class_weights and image contents
    class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
    return image_weights


def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
         64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
    return x


def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y
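
# Illustrative example (not in the original file): the two converters are inverses,
# and both accept torch Tensors or numpy arrays.
#   >>> xywh2xyxy(np.array([[5., 5., 10., 10.]]))
#   array([[ 0.,  0., 10., 10.]])
#   >>> xyxy2xywh(np.array([[0., 0., 10., 10.]]))
#   array([[ 5.,  5., 10., 10.]])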


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords
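
# Usage sketch (illustrative, not part of the original file): mapping detections on a
# 640x640 letterboxed input back to a 480x640 source frame. Here gain = min(640/480,
# 640/640) = 1.0 and pad = (0, 80), so 80 px of vertical padding is removed from y.
#   det = torch.tensor([[100., 80., 200., 180.]])  # xyxy on the 640x640 input
#   scale_coords((640, 640), det, (480, 640))      # -> [[100., 0., 200., 100.]]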


def clip_coords(boxes, img_shape):
    # Clip xyxy bounding boxes to image shape (height, width)
    boxes[:, 0].clamp_(0, img_shape[1])  # x1
    boxes[:, 1].clamp_(0, img_shape[0])  # y1
    boxes[:, 2].clamp_(0, img_shape[1])  # x2
    boxes[:, 3].clamp_(0, img_shape[0])  # y2


def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-9):
    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
    box2 = box2.T

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if GIoU or DIoU or CIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / ((1 + eps) - iou + v)
                return iou - (rho2 / c2 + v * alpha)  # CIoU
        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + eps  # convex area
            return iou - (c_area - union) / c_area  # GIoU
    else:
        return iou  # IoU
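
# Illustrative example (not in the original file): two 10x10 boxes overlapping in a
# 5x5 region intersect 25 px² out of a 175 px² union.
#   >>> bbox_iou(torch.tensor([0., 0., 10., 10.]), torch.tensor([[5., 5., 15., 15.]]))
#   tensor([0.1429])  # approximately 25 / 175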


def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """

    def box_area(box):
        # box = 4xn
        return (box[2] - box[0]) * (box[3] - box[1])

    area1 = box_area(box1.T)
    area2 = box_area(box2.T)

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)
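
# Illustrative example (not in the original file): pairwise IoU of 1 box against 2
# boxes returns a 1x2 matrix.
#   >>> box_iou(torch.tensor([[0., 0., 10., 10.]]),
#   ...         torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]]))
#   tensor([[1.0000, 0.1429]])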


def wh_iou(wh1, wh2):
    # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
    wh1 = wh1[:, None]  # [N,1,2]
    wh2 = wh2[None]  # [1,M,2]
    inter = torch.min(wh1, wh2).prod(2)  # [N,M]
    return inter / (wh1.prod(2) + wh2.prod(2) - inter)  # iou = inter / (area1 + area2 - inter)


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
    """Performs Non-Maximum Suppression (NMS) on inference results

    Returns:
        detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label = nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
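
# Usage sketch (illustrative, not part of the original file): `prediction` is a raw
# YOLO-style head output of shape (batch, boxes, 5 + nc) in xywh format, e.g.
# (1, 25200, 85) for an 80-class model at 640x640. The `model` and `img` names below
# are assumptions, not defined in this file.
#   pred = model(img)[0]                                                 # raw inference output
#   det = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)[0]  # (n, 6): xyxy, conf, cls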


def strip_optimizer(f='weights/best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    for key in 'optimizer', 'training_results', 'wandb_id':
        x[key] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    print('Optimizer stripped from %s,%s %.1fMB' % (f, (' saved as %s,' % s) if s else '', mb))
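
# Usage sketch (illustrative, not part of the original file): dropping the optimizer
# state and casting to FP16 roughly halves checkpoint size; the paths are hypothetical.
#   strip_optimizer('runs/exp/weights/best.pt')                      # overwrite in place
#   strip_optimizer('runs/exp/weights/best.pt', 'best_stripped.pt')  # save to a new file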


def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
    # Print mutation results to evolve.txt (for use with train.py --evolve)
    a = '%10s' * len(hyp) % tuple(hyp.keys())  # hyperparam keys
    b = '%10.3g' * len(hyp) % tuple(hyp.values())  # hyperparam values
    c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
    print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))

    if bucket:
        url = 'gs://%s/evolve.txt' % bucket
        if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
            os.system('gsutil cp %s .' % url)  # download evolve.txt if larger than local

    with open('evolve.txt', 'a') as f:  # append result
        f.write(c + b + '\n')
    x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # load unique rows
    x = x[np.argsort(-fitness(x))]  # sort
    np.savetxt('evolve.txt', x, '%10.3g')  # save sort by fitness

    # Save yaml
    for i, k in enumerate(hyp.keys()):
        hyp[k] = float(x[0, i + 7])
    with open(yaml_file, 'w') as f:
        results = tuple(x[0, :7])
        c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
        f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
        yaml.dump(hyp, f, sort_keys=False)

    if bucket:
        os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket))  # upload


def apply_classifier(x, model, img, im0):
    # Applies a second-stage classifier to YOLO outputs
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                # cv2.imwrite('test%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x224x224
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255.0  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x


def increment_path(path, exist_ok=True, sep=''):
    # Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc.
    path = Path(path)  # os-agnostic
    if (path.exists() and exist_ok) or (not path.exists()):
        return str(path)
    else:
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
        matches = [re.search(rf"{path.stem}{sep}(\d+)", d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]  # indices
        n = max(i) + 1 if i else 2  # increment number
        return f"{path}{sep}{n}"  # update path
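
# Illustrative example (not in the original file): with exist_ok=False, an existing
# run directory is never reused.
#   increment_path('runs/exp', exist_ok=False)  # 'runs/exp' if it does not exist yet
#                                               # 'runs/exp2' if only 'runs/exp' exists
#                                               # 'runs/exp4' if 'runs/exp3' is the highest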