You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

689 lines
28KB

  1. # YOLOv5 general utils
  2. import glob
  3. import logging
  4. import math
  5. import os
  6. import platform
  7. import random
  8. import re
  9. import subprocess
  10. import time
  11. from itertools import repeat
  12. from multiprocessing.pool import ThreadPool
  13. from pathlib import Path
  14. import cv2
  15. import numpy as np
  16. import pandas as pd
  17. import pkg_resources as pkg
  18. import torch
  19. import torchvision
  20. import yaml
  21. from utils.google_utils import gsutil_getsize
  22. from utils.metrics import fitness
  23. from utils.torch_utils import init_torch_seeds
  24. # Settings
  25. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  26. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  27. pd.options.display.max_columns = 10
  28. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  29. os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
  30. def set_logging(rank=-1, verbose=True):
  31. logging.basicConfig(
  32. format="%(message)s",
  33. level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)
  34. def init_seeds(seed=0):
  35. # Initialize random number generator (RNG) seeds
  36. random.seed(seed)
  37. np.random.seed(seed)
  38. init_torch_seeds(seed)
  39. def get_latest_run(search_dir='.'):
  40. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  41. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  42. return max(last_list, key=os.path.getctime) if last_list else ''
  43. def is_docker():
  44. # Is environment a Docker container
  45. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  46. def is_colab():
  47. # Is environment a Google Colab instance
  48. try:
  49. import google.colab
  50. return True
  51. except Exception as e:
  52. return False
  53. def emojis(str=''):
  54. # Return platform-dependent emoji-safe version of string
  55. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  56. def file_size(file):
  57. # Return file size in MB
  58. return Path(file).stat().st_size / 1e6
  59. def check_online():
  60. # Check internet connectivity
  61. import socket
  62. try:
  63. socket.create_connection(("1.1.1.1", 443), 5) # check host accesability
  64. return True
  65. except OSError:
  66. return False
  67. def check_git_status():
  68. # Recommend 'git pull' if code is out of date
  69. print(colorstr('github: '), end='')
  70. try:
  71. assert Path('.git').exists(), 'skipping check (not a git repository)'
  72. assert not is_docker(), 'skipping check (Docker image)'
  73. assert check_online(), 'skipping check (offline)'
  74. cmd = 'git fetch && git config --get remote.origin.url'
  75. url = subprocess.check_output(cmd, shell=True).decode().strip().rstrip('.git') # github repo url
  76. branch = subprocess.check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  77. n = int(subprocess.check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  78. if n > 0:
  79. s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
  80. f"Use 'git pull' to update or 'git clone {url}' to download latest."
  81. else:
  82. s = f'up to date with {url} ✅'
  83. print(emojis(s)) # emoji-safe
  84. except Exception as e:
  85. print(e)
  86. def check_python(minimum='3.7.0', required=True):
  87. # Check current python version vs. required python version
  88. current = platform.python_version()
  89. result = pkg.parse_version(current) >= pkg.parse_version(minimum)
  90. if required:
  91. assert result, f'Python {minimum} required by YOLOv5, but Python {current} is currently installed'
  92. return result
  93. def check_requirements(requirements='requirements.txt', exclude=()):
  94. # Check installed dependencies meet requirements (pass *.txt file or list of packages)
  95. prefix = colorstr('red', 'bold', 'requirements:')
  96. check_python() # check python version
  97. if isinstance(requirements, (str, Path)): # requirements.txt file
  98. file = Path(requirements)
  99. if not file.exists():
  100. print(f"{prefix} {file.resolve()} not found, check failed.")
  101. return
  102. requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
  103. else: # list or tuple of packages
  104. requirements = [x for x in requirements if x not in exclude]
  105. n = 0 # number of packages updates
  106. for r in requirements:
  107. try:
  108. pkg.require(r)
  109. except Exception as e: # DistributionNotFound or VersionConflict if requirements not met
  110. n += 1
  111. print(f"{prefix} {r} not found and is required by YOLOv5, attempting auto-update...")
  112. print(subprocess.check_output(f"pip install '{r}'", shell=True).decode())
  113. if n: # if packages updated
  114. source = file.resolve() if 'file' in locals() else requirements
  115. s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
  116. f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
  117. print(emojis(s)) # emoji-safe
  118. def check_img_size(img_size, s=32):
  119. # Verify img_size is a multiple of stride s
  120. new_size = make_divisible(img_size, int(s)) # ceil gs-multiple
  121. if new_size != img_size:
  122. print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
  123. return new_size
  124. def check_imshow():
  125. # Check if environment supports image displays
  126. try:
  127. assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
  128. assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
  129. cv2.imshow('test', np.zeros((1, 1, 3)))
  130. cv2.waitKey(1)
  131. cv2.destroyAllWindows()
  132. cv2.waitKey(1)
  133. return True
  134. except Exception as e:
  135. print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
  136. return False
  137. def check_file(file):
  138. # Search for file if not found
  139. if Path(file).is_file() or file == '':
  140. return file
  141. else:
  142. files = glob.glob('./**/' + file, recursive=True) # find file
  143. assert len(files), f'File Not Found: {file}' # assert file was found
  144. assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
  145. return files[0] # return file
  146. def check_dataset(dict):
  147. # Download dataset if not found locally
  148. val, s = dict.get('val'), dict.get('download')
  149. if val and len(val):
  150. val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
  151. if not all(x.exists() for x in val):
  152. print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
  153. if s and len(s): # download script
  154. if s.startswith('http') and s.endswith('.zip'): # URL
  155. f = Path(s).name # filename
  156. print(f'Downloading {s} ...')
  157. torch.hub.download_url_to_file(s, f)
  158. r = os.system(f'unzip -q {f} -d ../ && rm {f}') # unzip
  159. elif s.startswith('bash '): # bash script
  160. print(f'Running {s} ...')
  161. r = os.system(s)
  162. else: # python script
  163. r = exec(s) # return None
  164. print('Dataset autodownload %s\n' % ('success' if r in (0, None) else 'failure')) # print result
  165. else:
  166. raise Exception('Dataset not found.')
  167. def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
  168. # Multi-threaded file download and unzip function
  169. def download_one(url, dir):
  170. # Download 1 file
  171. f = dir / Path(url).name # filename
  172. if not f.exists():
  173. print(f'Downloading {url} to {f}...')
  174. if curl:
  175. os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -") # curl download, retry and resume on fail
  176. else:
  177. torch.hub.download_url_to_file(url, f, progress=True) # torch download
  178. if unzip and f.suffix in ('.zip', '.gz'):
  179. print(f'Unzipping {f}...')
  180. if f.suffix == '.zip':
  181. s = f'unzip -qo {f} -d {dir} && rm {f}' # unzip -quiet -overwrite
  182. elif f.suffix == '.gz':
  183. s = f'tar xfz {f} --directory {f.parent}' # unzip
  184. if delete: # delete zip file after unzip
  185. s += f' && rm {f}'
  186. os.system(s)
  187. dir = Path(dir)
  188. dir.mkdir(parents=True, exist_ok=True) # make directory
  189. if threads > 1:
  190. pool = ThreadPool(threads)
  191. pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
  192. pool.close()
  193. pool.join()
  194. else:
  195. for u in tuple(url) if isinstance(url, str) else url:
  196. download_one(u, dir)
  197. def make_divisible(x, divisor):
  198. # Returns x evenly divisible by divisor
  199. return math.ceil(x / divisor) * divisor
  200. def clean_str(s):
  201. # Cleans a string by replacing special characters with underscore _
  202. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  203. def one_cycle(y1=0.0, y2=1.0, steps=100):
  204. # lambda function for sinusoidal ramp from y1 to y2
  205. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  206. def colorstr(*input):
  207. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  208. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  209. colors = {'black': '\033[30m', # basic colors
  210. 'red': '\033[31m',
  211. 'green': '\033[32m',
  212. 'yellow': '\033[33m',
  213. 'blue': '\033[34m',
  214. 'magenta': '\033[35m',
  215. 'cyan': '\033[36m',
  216. 'white': '\033[37m',
  217. 'bright_black': '\033[90m', # bright colors
  218. 'bright_red': '\033[91m',
  219. 'bright_green': '\033[92m',
  220. 'bright_yellow': '\033[93m',
  221. 'bright_blue': '\033[94m',
  222. 'bright_magenta': '\033[95m',
  223. 'bright_cyan': '\033[96m',
  224. 'bright_white': '\033[97m',
  225. 'end': '\033[0m', # misc
  226. 'bold': '\033[1m',
  227. 'underline': '\033[4m'}
  228. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  229. def labels_to_class_weights(labels, nc=80):
  230. # Get class weights (inverse frequency) from training labels
  231. if labels[0] is None: # no labels loaded
  232. return torch.Tensor()
  233. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  234. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  235. weights = np.bincount(classes, minlength=nc) # occurrences per class
  236. # Prepend gridpoint count (for uCE training)
  237. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  238. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  239. weights[weights == 0] = 1 # replace empty bins with 1
  240. weights = 1 / weights # number of targets per class
  241. weights /= weights.sum() # normalize
  242. return torch.from_numpy(weights)
  243. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  244. # Produces image weights based on class_weights and image contents
  245. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  246. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  247. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  248. return image_weights
  249. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  250. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  251. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  252. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  253. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  254. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  255. x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  256. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  257. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  258. return x
  259. def xyxy2xywh(x):
  260. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  261. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  262. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  263. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  264. y[:, 2] = x[:, 2] - x[:, 0] # width
  265. y[:, 3] = x[:, 3] - x[:, 1] # height
  266. return y
  267. def xywh2xyxy(x):
  268. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  269. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  270. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  271. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  272. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  273. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  274. return y
  275. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  276. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  277. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  278. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  279. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  280. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  281. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  282. return y
  283. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  284. # Convert normalized segments into pixel segments, shape (n,2)
  285. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  286. y[:, 0] = w * x[:, 0] + padw # top left x
  287. y[:, 1] = h * x[:, 1] + padh # top left y
  288. return y
  289. def segment2box(segment, width=640, height=640):
  290. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  291. x, y = segment.T # segment xy
  292. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  293. x, y, = x[inside], y[inside]
  294. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  295. def segments2boxes(segments):
  296. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  297. boxes = []
  298. for s in segments:
  299. x, y = s.T # segment xy
  300. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  301. return xyxy2xywh(np.array(boxes)) # cls, xywh
  302. def resample_segments(segments, n=1000):
  303. # Up-sample an (n,2) segment
  304. for i, s in enumerate(segments):
  305. x = np.linspace(0, len(s) - 1, n)
  306. xp = np.arange(len(s))
  307. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  308. return segments
  309. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  310. # Rescale coords (xyxy) from img1_shape to img0_shape
  311. if ratio_pad is None: # calculate from img0_shape
  312. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  313. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  314. else:
  315. gain = ratio_pad[0][0]
  316. pad = ratio_pad[1]
  317. coords[:, [0, 2]] -= pad[0] # x padding
  318. coords[:, [1, 3]] -= pad[1] # y padding
  319. coords[:, :4] /= gain
  320. clip_coords(coords, img0_shape)
  321. return coords
  322. def clip_coords(boxes, img_shape):
  323. # Clip bounding xyxy bounding boxes to image shape (height, width)
  324. boxes[:, 0].clamp_(0, img_shape[1]) # x1
  325. boxes[:, 1].clamp_(0, img_shape[0]) # y1
  326. boxes[:, 2].clamp_(0, img_shape[1]) # x2
  327. boxes[:, 3].clamp_(0, img_shape[0]) # y2
def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    """Return the IoU of box1 against each box in box2, optionally as GIoU, DIoU or CIoU.

    Args:
        box1: one box, indexed as box1[0..3] (4 coordinates).
        box2: boxes of shape (n, 4); transposed below so that box2[k] is the k-th coordinate of every box.
        x1y1x2y2: True if boxes are (x1, y1, x2, y2); False if (x_center, y_center, width, height).
        GIoU, DIoU, CIoU: return a generalized IoU instead of plain IoU. If several are set, DIoU takes
            precedence over CIoU, and CIoU over GIoU.
        eps: small constant added to heights, the union and enclosing-box terms to avoid division by zero.

    Returns:
        Tensor of IoU (or GIoU / DIoU / CIoU) values, one per box in box2.
    """
    box2 = box2.T

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area (clamp(0) makes non-overlapping pairs contribute zero)
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area (eps on the heights also keeps the w / h ratios used by CIoU finite)
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if GIoU or DIoU or CIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                # v: aspect-ratio consistency term; alpha: its trade-off weight, computed without gradient
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + eps  # convex area
            return iou - (c_area - union) / c_area  # GIoU
    else:
        return iou  # IoU
  367. def box_iou(box1, box2):
  368. # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
  369. """
  370. Return intersection-over-union (Jaccard index) of boxes.
  371. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
  372. Arguments:
  373. box1 (Tensor[N, 4])
  374. box2 (Tensor[M, 4])
  375. Returns:
  376. iou (Tensor[N, M]): the NxM matrix containing the pairwise
  377. IoU values for every element in boxes1 and boxes2
  378. """
  379. def box_area(box):
  380. # box = 4xn
  381. return (box[2] - box[0]) * (box[3] - box[1])
  382. area1 = box_area(box1.T)
  383. area2 = box_area(box2.T)
  384. # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
  385. inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
  386. return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
  387. def wh_iou(wh1, wh2):
  388. # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
  389. wh1 = wh1[:, None] # [N,1,2]
  390. wh2 = wh2[None] # [1,M,2]
  391. inter = torch.min(wh1, wh2).prod(2) # [N,M]
  392. return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=()):
    """Runs Non-Maximum Suppression (NMS) on inference results.

    Args:
        prediction: tensor indexed [image, candidate, field]. The fields are
            (x_center, y_center, w, h, obj_conf, class scores...), with nc = prediction.shape[2] - 5.
        conf_thres: a candidate is kept only if obj_conf > conf_thres and obj_conf * cls_conf > conf_thres.
        iou_thres: IoU threshold passed to torchvision.ops.nms.
        classes: optional list of class indices to keep.
        agnostic: if True, NMS ignores class. Otherwise boxes are offset by class * max_wh, so boxes of
            different classes (presumably all within max_wh pixels) do not suppress each other.
        multi_label: keep every class scoring above conf_thres for a box, not only the best (forced off when nc == 1).
        labels: optional per-image tensors of rows (cls, x_center, y_center, w, h), appended as
            confidence-1.0 candidates (autolabelling).

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls], with at most max_det (300) per image.
         If the time limit is exceeded, images not yet processed keep empty (0, 6) tensors.
    """
    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates (objectness above threshold)

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height; min_wh is only used in a disabled filter below
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    # Default result per image; entries are replaced (never mutated), so sharing one empty tensor is safe
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            # one row per (box, class) pair whose combined confidence clears the threshold
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
  470. def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
  471. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  472. x = torch.load(f, map_location=torch.device('cpu'))
  473. if x.get('ema'):
  474. x['model'] = x['ema'] # replace model with ema
  475. for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates': # keys
  476. x[k] = None
  477. x['epoch'] = -1
  478. x['model'].half() # to FP16
  479. for p in x['model'].parameters():
  480. p.requires_grad = False
  481. torch.save(x, s or f)
  482. mb = os.path.getsize(s or f) / 1E6 # filesize
  483. print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
  484. def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
  485. # Print mutation results to evolve.txt (for use with train.py --evolve)
  486. a = '%10s' * len(hyp) % tuple(hyp.keys()) # hyperparam keys
  487. b = '%10.3g' * len(hyp) % tuple(hyp.values()) # hyperparam values
  488. c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
  489. print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))
  490. if bucket:
  491. url = 'gs://%s/evolve.txt' % bucket
  492. if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
  493. os.system('gsutil cp %s .' % url) # download evolve.txt if larger than local
  494. with open('evolve.txt', 'a') as f: # append result
  495. f.write(c + b + '\n')
  496. x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0) # load unique rows
  497. x = x[np.argsort(-fitness(x))] # sort
  498. np.savetxt('evolve.txt', x, '%10.3g') # save sort by fitness
  499. # Save yaml
  500. for i, k in enumerate(hyp.keys()):
  501. hyp[k] = float(x[0, i + 7])
  502. with open(yaml_file, 'w') as f:
  503. results = tuple(x[0, :7])
  504. c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
  505. f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
  506. yaml.safe_dump(hyp, f, sort_keys=False)
  507. if bucket:
  508. os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket)) # upload
  509. def apply_classifier(x, model, img, im0):
  510. # Apply a second stage classifier to yolo outputs
  511. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  512. for i, d in enumerate(x): # per image
  513. if d is not None and len(d):
  514. d = d.clone()
  515. # Reshape and pad cutouts
  516. b = xyxy2xywh(d[:, :4]) # boxes
  517. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  518. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  519. d[:, :4] = xywh2xyxy(b).long()
  520. # Rescale boxes from img_size to im0 size
  521. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  522. # Classes
  523. pred_cls1 = d[:, 5].long()
  524. ims = []
  525. for j, a in enumerate(d): # per item
  526. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  527. im = cv2.resize(cutout, (224, 224)) # BGR
  528. # cv2.imwrite('test%i.jpg' % j, cutout)
  529. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  530. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  531. im /= 255.0 # 0 - 255 to 0.0 - 1.0
  532. ims.append(im)
  533. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  534. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  535. return x
  536. def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False):
  537. # Save an image crop as {file} with crop size multiplied by {gain} and padded by {pad} pixels
  538. xyxy = torch.tensor(xyxy).view(-1, 4)
  539. b = xyxy2xywh(xyxy) # boxes
  540. if square:
  541. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
  542. b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
  543. xyxy = xywh2xyxy(b).long()
  544. clip_coords(xyxy, im.shape)
  545. crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2])]
  546. cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop if BGR else crop[..., ::-1])
  547. def increment_path(path, exist_ok=False, sep='', mkdir=False):
  548. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  549. path = Path(path) # os-agnostic
  550. if path.exists() and not exist_ok:
  551. suffix = path.suffix
  552. path = path.with_suffix('')
  553. dirs = glob.glob(f"{path}{sep}*") # similar paths
  554. matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
  555. i = [int(m.groups()[0]) for m in matches if m] # indices
  556. n = max(i) + 1 if i else 2 # increment number
  557. path = Path(f"{path}{sep}{n}{suffix}") # update path
  558. dir = path if path.suffix == '' else path.parent # directory
  559. if not dir.exists() and mkdir:
  560. dir.mkdir(parents=True, exist_ok=True) # make directory
  561. return path