# General utils

import glob
import logging
import math
import os
import platform
import random
import re
import subprocess
import time
from pathlib import Path

import cv2
import numpy as np
import torch
import torchvision
import yaml

from utils.google_utils import gsutil_getsize
from utils.metrics import fitness
from utils.torch_utils import init_torch_seeds

# Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
def set_logging(rank=-1):
    logging.basicConfig(
        format="%(message)s",
        level=logging.INFO if rank in [-1, 0] else logging.WARN)


def init_seeds(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    init_torch_seeds(seed)


def get_latest_run(search_dir='.'):
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''
def check_git_status():
    # Suggest 'git pull' if repo is out of date
    if platform.system() in ['Linux', 'Darwin'] and not os.path.isfile('/.dockerenv'):
        s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
        if 'Your branch is behind' in s:
            print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')


def check_requirements(file='requirements.txt'):
    # Check installed dependencies meet requirements
    import pkg_resources
    requirements = pkg_resources.parse_requirements(Path(file).open())
    requirements = [x.name + ''.join(*x.specs) if len(x.specs) else x.name for x in requirements]
    pkg_resources.require(requirements)  # DistributionNotFound or VersionConflict exception if requirements not met
def check_img_size(img_size, s=32):
    # Verify img_size is a multiple of stride s
    new_size = make_divisible(img_size, int(s))  # ceil gs-multiple
    if new_size != img_size:
        print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
    return new_size
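# Example (illustrative, not part of the original module):
#   check_img_size(636, s=32)  # prints the warning and returns 640, the next multiple of stride 32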
def check_file(file):
    # Search for file if not found
    if os.path.isfile(file) or file == '':
        return file
    else:
        files = glob.glob('./**/' + file, recursive=True)  # find file
        assert len(files), 'File Not Found: %s' % file  # assert file was found
        assert len(files) == 1, "Multiple files match '%s', specify exact path: %s" % (file, files)  # assert unique
        return files[0]  # return file


def check_dataset(dict):
    # Download dataset if not found locally
    val, s = dict.get('val'), dict.get('download')
    if val and len(val):
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and len(s):  # download script
                print('Downloading %s ...' % s)
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    torch.hub.download_url_to_file(s, f)
                    r = os.system('unzip -q %s -d ../ && rm %s' % (f, f))  # unzip
                else:  # bash script
                    r = os.system(s)
                print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure'))  # analyze return value
            else:
                raise Exception('Dataset not found.')
def make_divisible(x, divisor):
    # Returns x evenly divisible by divisor
    return math.ceil(x / divisor) * divisor


def clean_str(s):
    # Cleans a string by replacing special characters with underscore _
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)


def one_cycle(y1=0.0, y2=1.0, steps=100):
    # lambda function for sinusoidal ramp from y1 to y2
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
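# Example (illustrative sketch; `optimizer` is an assumed name, not defined in this file):
#   lf = one_cycle(1, 0.2, steps=300)  # epoch -> multiplier, cosine ramp from 1.0 down to 0.2
#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)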
def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(np.int)  # labels = [class xywh]
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # number of targets per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)


def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class_weights and image contents
    class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
    return image_weights
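# Example (illustrative sketch; `dataset` is an assumed name, not defined in this file):
#   cw = labels_to_class_weights(dataset.labels, nc=80).numpy()            # inverse-frequency class weights
#   iw = labels_to_image_weights(dataset.labels, nc=80, class_weights=cw)  # one weight per image
#   idx = random.choices(range(len(iw)), weights=iw, k=8)                  # oversample images with rare classes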
def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
         64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
    return x


def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y
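# Example (illustrative): a 10x10 box centered at (50, 50) round-trips between the two formats:
#   xywh2xyxy(np.array([[50., 50., 10., 10.]]))  # -> [[45., 45., 55., 55.]]
#   xyxy2xywh(np.array([[45., 45., 55., 55.]]))  # -> [[50., 50., 10., 10.]]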
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords
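# Example (illustrative sketch; `img` is the letterboxed network input, `im0` the original frame and
# `det` a detections tensor, all assumed names not defined in this file):
#   det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()  # map boxes back to im0 pixels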
def clip_coords(boxes, img_shape):
    # Clip bounding xyxy bounding boxes to image shape (height, width)
    boxes[:, 0].clamp_(0, img_shape[1])  # x1
    boxes[:, 1].clamp_(0, img_shape[0])  # y1
    boxes[:, 2].clamp_(0, img_shape[1])  # x2
    boxes[:, 3].clamp_(0, img_shape[0])  # y2


def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-9):
    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
    box2 = box2.T

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if GIoU or DIoU or CIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / ((1 + eps) - iou + v)
                return iou - (rho2 / c2 + v * alpha)  # CIoU
        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + eps  # convex area
            return iou - (c_area - union) / c_area  # GIoU
    else:
        return iou  # IoU
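# Example (illustrative): CIoU between one prediction and one target, both given in xywh format:
#   p = torch.tensor([50., 50., 20., 20.])       # single box, shape (4,)
#   t = torch.tensor([[55., 55., 20., 20.]])     # n x 4 targets
#   bbox_iou(p, t, x1y1x2y2=False, CIoU=True)    # one value, smaller than the plain IoU because centers differ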
def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """

    def box_area(box):
        # box = 4xn
        return (box[2] - box[0]) * (box[3] - box[1])

    area1 = box_area(box1.T)
    area2 = box_area(box2.T)

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)
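# Example (illustrative): pairwise IoU of 2 boxes against 3 boxes yields a 2x3 matrix:
#   a = torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])                        # 2 boxes, xyxy
#   b = torch.tensor([[0., 0., 10., 10.], [20., 20., 30., 30.], [0., 0., 5., 5.]])    # 3 boxes, xyxy
#   box_iou(a, b)  # entry [0, 0] == 1.0 (identical boxes), entry [0, 1] == 0.0 (disjoint boxes)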
def wh_iou(wh1, wh2):
    # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
    wh1 = wh1[:, None]  # [N,1,2]
    wh2 = wh2[None]  # [1,M,2]
    inter = torch.min(wh1, wh2).prod(2)  # [N,M]
    return inter / (wh1.prod(2) + wh2.prod(2) - inter)  # iou = inter / (area1 + area2 - inter)


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
    """Performs Non-Maximum Suppression (NMS) on inference results

    Returns:
         detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label = nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
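# Example (illustrative sketch; `model` and `img` are assumed names, not defined in this file):
#   pred = model(img)[0]                                         # raw output, shape (batch, num_boxes, 5 + nc)
#   pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
#   for det in pred:                                             # one (n, 6) tensor of xyxy, conf, cls per image
#       pass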
def strip_optimizer(f='weights/best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    for key in 'optimizer', 'training_results', 'wandb_id':
        x[key] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    print('Optimizer stripped from %s,%s %.1fMB' % (f, (' saved as %s,' % s) if s else '', mb))
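# Example (illustrative; the checkpoint paths below are hypothetical):
#   strip_optimizer('best.pt')                     # overwrite the checkpoint in place
#   strip_optimizer('best.pt', 'best_stripped.pt') # or write the stripped copy to a new file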
def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
    # Print mutation results to evolve.txt (for use with train.py --evolve)
    a = '%10s' * len(hyp) % tuple(hyp.keys())  # hyperparam keys
    b = '%10.3g' * len(hyp) % tuple(hyp.values())  # hyperparam values
    c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
    print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))

    if bucket:
        url = 'gs://%s/evolve.txt' % bucket
        if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
            os.system('gsutil cp %s .' % url)  # download evolve.txt if larger than local

    with open('evolve.txt', 'a') as f:  # append result
        f.write(c + b + '\n')
    x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # load unique rows
    x = x[np.argsort(-fitness(x))]  # sort
    np.savetxt('evolve.txt', x, '%10.3g')  # save sort by fitness

    # Save yaml
    for i, k in enumerate(hyp.keys()):
        hyp[k] = float(x[0, i + 7])
    with open(yaml_file, 'w') as f:
        results = tuple(x[0, :7])
        c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
        f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
        yaml.dump(hyp, f, sort_keys=False)

    if bucket:
        os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket))  # upload
def apply_classifier(x, model, img, im0):
    # Apply a second-stage classifier to YOLO detections, keeping only class-consistent boxes
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                # cv2.imwrite('test%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x224x224
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255.0  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x
def increment_path(path, exist_ok=True, sep=''):
    # Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc.
    path = Path(path)  # os-agnostic
    if (path.exists() and exist_ok) or (not path.exists()):
        return str(path)
    else:
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
        matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]  # indices
        n = max(i) + 1 if i else 2  # increment number
        return f"{path}{sep}{n}"  # update path
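# Example (illustrative; the run directory name is hypothetical):
#   save_dir = Path(increment_path('runs/train/exp', exist_ok=False))  # e.g. 'runs/train/exp2' if 'exp' exists
#   save_dir.mkdir(parents=True, exist_ok=True)                        # create the fresh run directory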