# YOLOv5 general utils

import glob
import logging
import math
import os
import platform
import random
import re
import subprocess
import time
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import torch
import torchvision
import yaml

from utils.google_utils import gsutil_getsize
from utils.metrics import fitness
from utils.torch_utils import init_torch_seeds

# Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads

def set_logging(rank=-1):
    logging.basicConfig(
        format="%(message)s",
        level=logging.INFO if rank in [-1, 0] else logging.WARN)

def init_seeds(seed=0):
    # Initialize random number generator (RNG) seeds
    random.seed(seed)
    np.random.seed(seed)
    init_torch_seeds(seed)

def get_latest_run(search_dir='.'):
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''

def isdocker():
    # Is environment a Docker container?
    return Path('/workspace').exists()  # or Path('/.dockerenv').exists()

def emojis(str=''):
    # Return platform-dependent emoji-safe version of string
    return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str

def check_online():
    # Check internet connectivity
    import socket
    try:
        socket.create_connection(("1.1.1.1", 443), 5)  # check host accessibility
        return True
    except OSError:
        return False

def check_git_status():
    # Recommend 'git pull' if code is out of date
    print(colorstr('github: '), end='')
    try:
        assert Path('.git').exists(), 'skipping check (not a git repository)'
        assert not isdocker(), 'skipping check (Docker image)'
        assert check_online(), 'skipping check (offline)'

        cmd = 'git fetch && git config --get remote.origin.url'
        url = re.sub(r'\.git$', '', subprocess.check_output(cmd, shell=True).decode().strip())  # repo url; note rstrip('.git') would strip a character set, not the suffix
        branch = subprocess.check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
        n = int(subprocess.check_output(f'git rev-list {branch}..origin/master --count', shell=True))  # commits behind
        if n > 0:
            s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
                f"Use 'git pull' to update or 'git clone {url}' to download latest."
        else:
            s = f'up to date with {url} ✅'
        print(emojis(s))  # emoji-safe
    except Exception as e:
        print(e)

def check_requirements(file='requirements.txt', exclude=()):
    # Check that installed dependencies meet requirements
    import pkg_resources as pkg
    prefix = colorstr('red', 'bold', 'requirements:')
    file = Path(file)
    if not file.exists():
        print(f"{prefix} {file.resolve()} not found, check failed.")
        return

    n = 0  # number of packages updated
    requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
    for r in requirements:
        try:
            pkg.require(r)
        except Exception as e:  # DistributionNotFound or VersionConflict if requirements not met
            n += 1
            print(f"{prefix} {e.req} not found and is required by YOLOv5, attempting auto-update...")
            print(subprocess.check_output(f"pip install '{e.req}'", shell=True).decode())

    if n:  # if packages updated
        s = f"{prefix} {n} package{'s' * (n > 1)} updated per {file.resolve()}\n" \
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
        print(emojis(s))  # emoji-safe

def check_img_size(img_size, s=32):
    # Verify img_size is a multiple of stride s
    new_size = make_divisible(img_size, int(s))  # ceil gs-multiple
    if new_size != img_size:
        print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
    return new_size

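# Illustrative usage (added sketch, not part of the original file):
#   check_img_size(640, s=32)  # -> 640, unchanged
#   check_img_size(100, s=32)  # warns, -> 128 (next multiple of 32)
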
def check_imshow():
    # Check if environment supports image displays
    try:
        assert not isdocker(), 'cv2.imshow() is disabled in Docker environments'
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
        return False

def check_file(file):
    # Search for file if not found
    if os.path.isfile(file) or file == '':
        return file
    else:
        files = glob.glob('./**/' + file, recursive=True)  # find file
        assert len(files), 'File Not Found: %s' % file  # assert file was found
        assert len(files) == 1, "Multiple files match '%s', specify exact path: %s" % (file, files)  # assert unique
        return files[0]  # return file

def check_dataset(dict):
    # Download dataset if not found locally
    val, s = dict.get('val'), dict.get('download')
    if val and len(val):
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and len(s):  # download script
                print('Downloading %s ...' % s)
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    torch.hub.download_url_to_file(s, f)
                    r = os.system('unzip -q %s -d ../ && rm %s' % (f, f))  # unzip
                else:  # bash script
                    r = os.system(s)
                print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure'))  # analyze return value
            else:
                raise Exception('Dataset not found.')

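# Illustrative usage (added sketch, not part of the original file; the 'val'
# and 'download' values are assumed to come from a dataset yaml):
#   check_dataset({'val': 'coco/val2017.txt', 'download': 'bash data/scripts/get_coco.sh'})
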
def make_divisible(x, divisor):
    # Returns x rounded up to the nearest multiple of divisor
    return math.ceil(x / divisor) * divisor

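# Illustrative usage (added sketch, not part of the original file):
#   make_divisible(100, 32)  # -> 128
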
def clean_str(s):
    # Cleans a string by replacing special characters with underscore _
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)

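# Illustrative usage (added sketch, not part of the original file):
#   clean_str('rtsp://user@host:554')  # -> 'rtsp_//user_host_554'
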
def one_cycle(y1=0.0, y2=1.0, steps=100):
    # lambda function for sinusoidal ramp from y1 to y2
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1

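# Illustrative usage (added sketch, not part of the original file): a cosine
# learning-rate schedule; 'optimizer' is a placeholder torch.optim optimizer:
#   lf = one_cycle(1.0, 0.1, 100)  # lf(0) == 1.0, lf(100) == 0.1
#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
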
def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {'black': '\033[30m',  # basic colors
              'red': '\033[31m',
              'green': '\033[32m',
              'yellow': '\033[33m',
              'blue': '\033[34m',
              'magenta': '\033[35m',
              'cyan': '\033[36m',
              'white': '\033[37m',
              'bright_black': '\033[90m',  # bright colors
              'bright_red': '\033[91m',
              'bright_green': '\033[92m',
              'bright_yellow': '\033[93m',
              'bright_blue': '\033[94m',
              'bright_magenta': '\033[95m',
              'bright_cyan': '\033[96m',
              'bright_white': '\033[97m',
              'end': '\033[0m',  # misc
              'bold': '\033[1m',
              'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']

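# Illustrative usage (added sketch, not part of the original file):
#   colorstr('hello')                 # blue bold 'hello' (default styles)
#   colorstr('red', 'bold', 'error')  # red bold 'error'
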
def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(int)  # labels = [class xywh]; plain int (np.int is deprecated)
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # number of targets per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)

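# Illustrative usage (added sketch, not part of the original file): one image
# with one class-0 label and two class-1 labels; the rarer class 0 is weighted
# higher, and the weights sum to 1:
#   labels = [np.array([[0, .5, .5, .1, .1], [1, .5, .5, .1, .1], [1, .4, .4, .2, .2]])]
#   labels_to_class_weights(labels, nc=2)  # -> approximately tensor([0.6667, 0.3333])
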
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class_weights and image contents
    class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])  # plain int (np.int is deprecated)
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
    return image_weights

def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
         64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
    return x

def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y

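# Illustrative usage (added sketch, not part of the original file):
#   xyxy2xywh(np.array([[10., 20., 50., 80.]]))  # -> [[30., 50., 40., 60.]] (center x, center y, w, h)
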
def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y

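# Illustrative usage (added sketch, not part of the original file), the inverse
# of the xyxy2xywh example above:
#   xywh2xyxy(np.array([[30., 50., 40., 60.]]))  # -> [[10., 20., 50., 80.]]
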
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # top left x
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh  # top left y
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw  # bottom right x
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh  # bottom right y
    return y

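# Illustrative usage (added sketch, not part of the original file): a centered
# box covering a quarter of each dimension of a 640x640 image:
#   xywhn2xyxy(np.array([[0.5, 0.5, 0.25, 0.25]]))  # -> [[240., 240., 400., 400.]]
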
def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    # Convert normalized segments into pixel segments, shape (n,2)
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * x[:, 0] + padw  # pixel x
    y[:, 1] = h * x[:, 1] + padh  # pixel y
    return y

def segment2box(segment, width=640, height=640):
    # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4))  # xyxy

def segments2boxes(segments):
    # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
    return xyxy2xywh(np.array(boxes))  # cls, xywh

def resample_segments(segments, n=1000):
    # Up-sample an (n,2) segment
    for i, s in enumerate(segments):
        x = np.linspace(0, len(s) - 1, n)
        xp = np.arange(len(s))
        segments[i] = np.concatenate([np.interp(x, xp, s[:, j]) for j in range(2)]).reshape(2, -1).T  # segment xy
    return segments

def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords

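# Illustrative usage (added sketch, not part of the original file): map boxes
# detected on a 384x640 letterboxed input back to a 720x1280 source frame;
# 'det' is a placeholder (n,6) detections tensor, modified in place:
#   scale_coords((384, 640), det[:, :4], (720, 1280))
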
def clip_coords(boxes, img_shape):
    # Clip xyxy bounding boxes to image shape (height, width)
    boxes[:, 0].clamp_(0, img_shape[1])  # x1
    boxes[:, 1].clamp_(0, img_shape[0])  # y1
    boxes[:, 2].clamp_(0, img_shape[1])  # x2
    boxes[:, 3].clamp_(0, img_shape[0])  # y2

def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
    box2 = box2.T

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if GIoU or DIoU or CIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + eps  # convex area
            return iou - (c_area - union) / c_area  # GIoU
    else:
        return iou  # IoU

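# Illustrative usage (added sketch, not part of the original file): two 10x10
# boxes overlapping in a 5x5 region, IoU = 25 / (100 + 100 - 25) ≈ 0.143:
#   bbox_iou(torch.tensor([0., 0., 10., 10.]), torch.tensor([[5., 5., 15., 15.]]))  # -> tensor([0.1429])
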
def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.

    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in box1 and box2
    """

    def box_area(box):
        # box = 4xn
        return (box[2] - box[0]) * (box[3] - box[1])

    area1 = box_area(box1.T)
    area2 = box_area(box2.T)

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)

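# Illustrative usage (added sketch, not part of the original file): pairwise
# IoU of N=1 and M=1 boxes, same geometry as the bbox_iou example above:
#   box_iou(torch.tensor([[0., 0., 10., 10.]]), torch.tensor([[5., 5., 15., 15.]]))  # -> tensor([[0.1429]])
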
def wh_iou(wh1, wh2):
    # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
    wh1 = wh1[:, None]  # [N,1,2]
    wh2 = wh2[None]  # [1,M,2]
    inter = torch.min(wh1, wh2).prod(2)  # [N,M]
    return inter / (wh1.prod(2) + wh2.prod(2) - inter)  # iou = inter / (area1 + area2 - inter)

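# Illustrative usage (added sketch, not part of the original file): both shapes
# are treated as anchored at a shared origin, so intersection = min(w) * min(h):
#   wh_iou(torch.tensor([[10., 10.]]), torch.tensor([[5., 5.]]))  # -> tensor([[0.2500]])
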
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=()):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
         list of detections, one (n,6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output

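# Illustrative usage (added sketch, not part of the original file): filter raw
# model output into per-image (n,6) detections [xyxy, conf, cls]; 'model_out'
# is a placeholder tensor of shape (batch, num_boxes, 5 + nc):
#   pred = non_max_suppression(model_out, conf_thres=0.25, iou_thres=0.45)
#   det = pred[0]  # detections for the first image in the batch
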
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")

def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
    # Print mutation results to evolve.txt (for use with train.py --evolve)
    a = '%10s' * len(hyp) % tuple(hyp.keys())  # hyperparam keys
    b = '%10.3g' * len(hyp) % tuple(hyp.values())  # hyperparam values
    c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
    print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))

    if bucket:
        url = 'gs://%s/evolve.txt' % bucket
        if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
            os.system('gsutil cp %s .' % url)  # download evolve.txt if larger than local

    with open('evolve.txt', 'a') as f:  # append result
        f.write(c + b + '\n')
    x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # load unique rows
    x = x[np.argsort(-fitness(x))]  # sort
    np.savetxt('evolve.txt', x, '%10.3g')  # save sort by fitness

    # Save yaml
    for i, k in enumerate(hyp.keys()):
        hyp[k] = float(x[0, i + 7])
    with open(yaml_file, 'w') as f:
        results = tuple(x[0, :7])
        c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
        f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
        yaml.dump(hyp, f, sort_keys=False)

    if bucket:
        os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket))  # upload

def apply_classifier(x, model, img, im0):
    # Apply a second-stage classifier to YOLO outputs
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                # cv2.imwrite('test%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW (3x224x224)
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255.0  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x

def increment_path(path, exist_ok=True, sep=''):
    # Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc.
    path = Path(path)  # os-agnostic
    if (path.exists() and exist_ok) or (not path.exists()):
        return str(path)
    else:
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
        matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]  # indices
        n = max(i) + 1 if i else 2  # increment number
        return f"{path}{sep}{n}"  # update path

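# Illustrative usage (added sketch, not part of the original file):
#   increment_path('runs/exp', exist_ok=False)  # 'runs/exp' if free, else 'runs/exp2', 'runs/exp3', ...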