You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

446 lines
18KB

# General utils
import glob
import logging
import math
import os
import platform
import random
import re
import subprocess
import time
import zipfile
from pathlib import Path

import cv2
import numpy as np
import torch
import torchvision
import yaml

from utils.google_utils import gsutil_getsize
from utils.metrics import fitness
from utils.torch_utils import init_torch_seeds
  20. # Settings
  21. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  22. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  23. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  24. def set_logging(rank=-1):
  25. logging.basicConfig(
  26. format="%(message)s",
  27. level=logging.INFO if rank in [-1, 0] else logging.WARN)
  28. def init_seeds(seed=0):
  29. random.seed(seed)
  30. np.random.seed(seed)
  31. init_torch_seeds(seed)
  32. def get_latest_run(search_dir='.'):
  33. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  34. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  35. return max(last_list, key=os.path.getctime) if last_list else ''
  36. def check_git_status():
  37. # Suggest 'git pull' if repo is out of date
  38. if platform.system() in ['Linux', 'Darwin'] and not os.path.isfile('/.dockerenv'):
  39. s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
  40. if 'Your branch is behind' in s:
  41. print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')
  42. def check_img_size(img_size, s=32):
  43. # Verify img_size is a multiple of stride s
  44. new_size = make_divisible(img_size, int(s)) # ceil gs-multiple
  45. if new_size != img_size:
  46. print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
  47. return new_size
  48. def check_file(file):
  49. # Search for file if not found
  50. if os.path.isfile(file) or file == '':
  51. return file
  52. else:
  53. files = glob.glob('./**/' + file, recursive=True) # find file
  54. assert len(files), 'File Not Found: %s' % file # assert file was found
  55. assert len(files) == 1, "Multiple files match '%s', specify exact path: %s" % (file, files) # assert unique
  56. return files[0] # return file
  57. def check_dataset(dict):
  58. # Download dataset if not found locally
  59. val, s = dict.get('val'), dict.get('download')
  60. if val and len(val):
  61. val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
  62. if not all(x.exists() for x in val):
  63. print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
  64. if s and len(s): # download script
  65. print('Downloading %s ...' % s)
  66. if s.startswith('http') and s.endswith('.zip'): # URL
  67. f = Path(s).name # filename
  68. torch.hub.download_url_to_file(s, f)
  69. r = os.system('unzip -q %s -d ../ && rm %s' % (f, f)) # unzip
  70. else: # bash script
  71. r = os.system(s)
  72. print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure')) # analyze return value
  73. else:
  74. raise Exception('Dataset not found.')
  75. def make_divisible(x, divisor):
  76. # Returns x evenly divisible by divisor
  77. return math.ceil(x / divisor) * divisor
  78. def clean_str(s):
  79. # Cleans a string by replacing special characters with underscore _
  80. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  81. def labels_to_class_weights(labels, nc=80):
  82. # Get class weights (inverse frequency) from training labels
  83. if labels[0] is None: # no labels loaded
  84. return torch.Tensor()
  85. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  86. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  87. weights = np.bincount(classes, minlength=nc) # occurrences per class
  88. # Prepend gridpoint count (for uCE training)
  89. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  90. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  91. weights[weights == 0] = 1 # replace empty bins with 1
  92. weights = 1 / weights # number of targets per class
  93. weights /= weights.sum() # normalize
  94. return torch.from_numpy(weights)
  95. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  96. # Produces image weights based on class_weights and image contents
  97. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  98. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  99. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  100. return image_weights
  101. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  102. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  103. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  104. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  105. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  106. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  107. x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  108. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  109. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  110. return x
  111. def xyxy2xywh(x):
  112. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  113. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  114. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  115. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  116. y[:, 2] = x[:, 2] - x[:, 0] # width
  117. y[:, 3] = x[:, 3] - x[:, 1] # height
  118. return y
  119. def xywh2xyxy(x):
  120. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  121. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  122. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  123. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  124. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  125. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  126. return y
  127. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  128. # Rescale coords (xyxy) from img1_shape to img0_shape
  129. if ratio_pad is None: # calculate from img0_shape
  130. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  131. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  132. else:
  133. gain = ratio_pad[0][0]
  134. pad = ratio_pad[1]
  135. coords[:, [0, 2]] -= pad[0] # x padding
  136. coords[:, [1, 3]] -= pad[1] # y padding
  137. coords[:, :4] /= gain
  138. clip_coords(coords, img0_shape)
  139. return coords
  140. def clip_coords(boxes, img_shape):
  141. # Clip bounding xyxy bounding boxes to image shape (height, width)
  142. boxes[:, 0].clamp_(0, img_shape[1]) # x1
  143. boxes[:, 1].clamp_(0, img_shape[0]) # y1
  144. boxes[:, 2].clamp_(0, img_shape[1]) # x2
  145. boxes[:, 3].clamp_(0, img_shape[0]) # y2
  146. def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-9):
  147. # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
  148. box2 = box2.T
  149. # Get the coordinates of bounding boxes
  150. if x1y1x2y2: # x1, y1, x2, y2 = box1
  151. b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
  152. b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
  153. else: # transform from xywh to xyxy
  154. b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
  155. b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
  156. b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
  157. b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
  158. # Intersection area
  159. inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
  160. (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
  161. # Union Area
  162. w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
  163. w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
  164. union = w1 * h1 + w2 * h2 - inter + eps
  165. iou = inter / union
  166. if GIoU or DIoU or CIoU:
  167. cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
  168. ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
  169. if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
  170. c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
  171. rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
  172. (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center distance squared
  173. if DIoU:
  174. return iou - rho2 / c2 # DIoU
  175. elif CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
  176. v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
  177. with torch.no_grad():
  178. alpha = v / ((1 + eps) - iou + v)
  179. return iou - (rho2 / c2 + v * alpha) # CIoU
  180. else: # GIoU https://arxiv.org/pdf/1902.09630.pdf
  181. c_area = cw * ch + eps # convex area
  182. return iou - (c_area - union) / c_area # GIoU
  183. else:
  184. return iou # IoU
  185. def box_iou(box1, box2):
  186. # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
  187. """
  188. Return intersection-over-union (Jaccard index) of boxes.
  189. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
  190. Arguments:
  191. box1 (Tensor[N, 4])
  192. box2 (Tensor[M, 4])
  193. Returns:
  194. iou (Tensor[N, M]): the NxM matrix containing the pairwise
  195. IoU values for every element in boxes1 and boxes2
  196. """
  197. def box_area(box):
  198. # box = 4xn
  199. return (box[2] - box[0]) * (box[3] - box[1])
  200. area1 = box_area(box1.T)
  201. area2 = box_area(box2.T)
  202. # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
  203. inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
  204. return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
  205. def wh_iou(wh1, wh2):
  206. # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
  207. wh1 = wh1[:, None] # [N,1,2]
  208. wh2 = wh2[None] # [1,M,2]
  209. inter = torch.min(wh1, wh2).prod(2) # [N,M]
  210. return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
  211. def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
  212. """Performs Non-Maximum Suppression (NMS) on inference results
  213. Returns:
  214. detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
  215. """
  216. nc = prediction.shape[2] - 5 # number of classes
  217. xc = prediction[..., 4] > conf_thres # candidates
  218. # Settings
  219. min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
  220. max_det = 300 # maximum number of detections per image
  221. time_limit = 10.0 # seconds to quit after
  222. redundant = True # require redundant detections
  223. multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
  224. merge = False # use merge-NMS
  225. t = time.time()
  226. output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
  227. for xi, x in enumerate(prediction): # image index, image inference
  228. # Apply constraints
  229. # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
  230. x = x[xc[xi]] # confidence
  231. # Cat apriori labels if autolabelling
  232. if labels and len(labels[xi]):
  233. l = labels[xi]
  234. v = torch.zeros((len(l), nc + 5), device=x.device)
  235. v[:, :4] = l[:, 1:5] # box
  236. v[:, 4] = 1.0 # conf
  237. v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls
  238. x = torch.cat((x, v), 0)
  239. # If none remain process next image
  240. if not x.shape[0]:
  241. continue
  242. # Compute conf
  243. x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
  244. # Box (center x, center y, width, height) to (x1, y1, x2, y2)
  245. box = xywh2xyxy(x[:, :4])
  246. # Detections matrix nx6 (xyxy, conf, cls)
  247. if multi_label:
  248. i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
  249. x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
  250. else: # best class only
  251. conf, j = x[:, 5:].max(1, keepdim=True)
  252. x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
  253. # Filter by class
  254. if classes is not None:
  255. x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
  256. # Apply finite constraint
  257. # if not torch.isfinite(x).all():
  258. # x = x[torch.isfinite(x).all(1)]
  259. # If none remain process next image
  260. n = x.shape[0] # number of boxes
  261. if not n:
  262. continue
  263. # Sort by confidence
  264. # x = x[x[:, 4].argsort(descending=True)]
  265. # Batched NMS
  266. c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
  267. boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
  268. i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
  269. if i.shape[0] > max_det: # limit detections
  270. i = i[:max_det]
  271. if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
  272. # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
  273. iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
  274. weights = iou * scores[None] # box weights
  275. x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
  276. if redundant:
  277. i = i[iou.sum(1) > 1] # require redundancy
  278. output[xi] = x[i]
  279. if (time.time() - t) > time_limit:
  280. break # time limit exceeded
  281. return output
  282. def strip_optimizer(f='weights/best.pt', s=''): # from utils.general import *; strip_optimizer()
  283. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  284. x = torch.load(f, map_location=torch.device('cpu'))
  285. x['optimizer'] = None
  286. x['training_results'] = None
  287. x['epoch'] = -1
  288. x['model'].half() # to FP16
  289. for p in x['model'].parameters():
  290. p.requires_grad = False
  291. torch.save(x, s or f)
  292. mb = os.path.getsize(s or f) / 1E6 # filesize
  293. print('Optimizer stripped from %s,%s %.1fMB' % (f, (' saved as %s,' % s) if s else '', mb))
  294. def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
  295. # Print mutation results to evolve.txt (for use with train.py --evolve)
  296. a = '%10s' * len(hyp) % tuple(hyp.keys()) # hyperparam keys
  297. b = '%10.3g' * len(hyp) % tuple(hyp.values()) # hyperparam values
  298. c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
  299. print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))
  300. if bucket:
  301. url = 'gs://%s/evolve.txt' % bucket
  302. if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
  303. os.system('gsutil cp %s .' % url) # download evolve.txt if larger than local
  304. with open('evolve.txt', 'a') as f: # append result
  305. f.write(c + b + '\n')
  306. x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0) # load unique rows
  307. x = x[np.argsort(-fitness(x))] # sort
  308. np.savetxt('evolve.txt', x, '%10.3g') # save sort by fitness
  309. # Save yaml
  310. for i, k in enumerate(hyp.keys()):
  311. hyp[k] = float(x[0, i + 7])
  312. with open(yaml_file, 'w') as f:
  313. results = tuple(x[0, :7])
  314. c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
  315. f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
  316. yaml.dump(hyp, f, sort_keys=False)
  317. if bucket:
  318. os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket)) # upload
  319. def apply_classifier(x, model, img, im0):
  320. # applies a second stage classifier to yolo outputs
  321. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  322. for i, d in enumerate(x): # per image
  323. if d is not None and len(d):
  324. d = d.clone()
  325. # Reshape and pad cutouts
  326. b = xyxy2xywh(d[:, :4]) # boxes
  327. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  328. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  329. d[:, :4] = xywh2xyxy(b).long()
  330. # Rescale boxes from img_size to im0 size
  331. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  332. # Classes
  333. pred_cls1 = d[:, 5].long()
  334. ims = []
  335. for j, a in enumerate(d): # per item
  336. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  337. im = cv2.resize(cutout, (224, 224)) # BGR
  338. # cv2.imwrite('test%i.jpg' % j, cutout)
  339. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  340. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  341. im /= 255.0 # 0 - 255 to 0.0 - 1.0
  342. ims.append(im)
  343. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  344. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  345. return x
  346. def increment_path(path, exist_ok=True, sep=''):
  347. # Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc.
  348. path = Path(path) # os-agnostic
  349. if (path.exists() and exist_ok) or (not path.exists()):
  350. return str(path)
  351. else:
  352. dirs = glob.glob(f"{path}{sep}*") # similar paths
  353. matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
  354. i = [int(m.groups()[0]) for m in matches if m] # indices
  355. n = max(i) + 1 if i else 2 # increment number
  356. return f"{path}{sep}{n}" # update path