You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

730 lines
30KB

  1. # YOLOv5 general utils
  2. import contextlib
  3. import glob
  4. import logging
  5. import math
  6. import os
  7. import platform
  8. import random
  9. import re
  10. import signal
  11. import time
  12. import urllib
  13. from itertools import repeat
  14. from multiprocessing.pool import ThreadPool
  15. from pathlib import Path
  16. from subprocess import check_output
  17. import cv2
  18. import numpy as np
  19. import pandas as pd
  20. import pkg_resources as pkg
  21. import torch
  22. import torchvision
  23. import yaml
  24. from utils.google_utils import gsutil_getsize
  25. from utils.metrics import fitness
  26. from utils.torch_utils import init_torch_seeds
  27. # Settings
  28. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  29. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  30. pd.options.display.max_columns = 10
  31. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  32. os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
  33. class timeout(contextlib.ContextDecorator):
  34. # Usage: @timeout(seconds) decorator or 'with timeout(seconds):' context manager
  35. def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
  36. self.seconds = int(seconds)
  37. self.timeout_message = timeout_msg
  38. self.suppress = bool(suppress_timeout_errors)
  39. def _timeout_handler(self, signum, frame):
  40. raise TimeoutError(self.timeout_message)
  41. def __enter__(self):
  42. signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
  43. signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
  44. def __exit__(self, exc_type, exc_val, exc_tb):
  45. signal.alarm(0) # Cancel SIGALRM if it's scheduled
  46. if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
  47. return True
  48. def set_logging(rank=-1, verbose=True):
  49. logging.basicConfig(
  50. format="%(message)s",
  51. level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)
  52. def init_seeds(seed=0):
  53. # Initialize random number generator (RNG) seeds
  54. random.seed(seed)
  55. np.random.seed(seed)
  56. init_torch_seeds(seed)
  57. def get_latest_run(search_dir='.'):
  58. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  59. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  60. return max(last_list, key=os.path.getctime) if last_list else ''
  61. def is_docker():
  62. # Is environment a Docker container?
  63. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  64. def is_colab():
  65. # Is environment a Google Colab instance?
  66. try:
  67. import google.colab
  68. return True
  69. except Exception as e:
  70. return False
  71. def is_pip():
  72. # Is file in a pip package?
  73. return 'site-packages' in Path(__file__).absolute().parts
  74. def emojis(str=''):
  75. # Return platform-dependent emoji-safe version of string
  76. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  77. def file_size(file):
  78. # Return file size in MB
  79. return Path(file).stat().st_size / 1e6
  80. def check_online():
  81. # Check internet connectivity
  82. import socket
  83. try:
  84. socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
  85. return True
  86. except OSError:
  87. return False
  88. def check_git_status(err_msg=', for updates see https://github.com/ultralytics/yolov5'):
  89. # Recommend 'git pull' if code is out of date
  90. print(colorstr('github: '), end='')
  91. try:
  92. assert Path('.git').exists(), 'skipping check (not a git repository)'
  93. assert not is_docker(), 'skipping check (Docker image)'
  94. assert check_online(), 'skipping check (offline)'
  95. cmd = 'git fetch && git config --get remote.origin.url'
  96. url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
  97. branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  98. n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  99. if n > 0:
  100. s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
  101. f"Use 'git pull' to update or 'git clone {url}' to download latest."
  102. else:
  103. s = f'up to date with {url} ✅'
  104. print(emojis(s)) # emoji-safe
  105. except Exception as e:
  106. print(f'{e}{err_msg}')
  107. def check_python(minimum='3.7.0', required=True):
  108. # Check current python version vs. required python version
  109. current = platform.python_version()
  110. result = pkg.parse_version(current) >= pkg.parse_version(minimum)
  111. if required:
  112. assert result, f'Python {minimum} required by YOLOv5, but Python {current} is currently installed'
  113. return result
  114. def check_requirements(requirements='requirements.txt', exclude=()):
  115. # Check installed dependencies meet requirements (pass *.txt file or list of packages)
  116. prefix = colorstr('red', 'bold', 'requirements:')
  117. check_python() # check python version
  118. if isinstance(requirements, (str, Path)): # requirements.txt file
  119. file = Path(requirements)
  120. if not file.exists():
  121. print(f"{prefix} {file.resolve()} not found, check failed.")
  122. return
  123. requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
  124. else: # list or tuple of packages
  125. requirements = [x for x in requirements if x not in exclude]
  126. n = 0 # number of packages updates
  127. for r in requirements:
  128. try:
  129. pkg.require(r)
  130. except Exception as e: # DistributionNotFound or VersionConflict if requirements not met
  131. print(f"{prefix} {r} not found and is required by YOLOv5, attempting auto-update...")
  132. try:
  133. assert check_online(), f"'pip install {r}' skipped (offline)"
  134. print(check_output(f"pip install '{r}'", shell=True).decode())
  135. n += 1
  136. except Exception as e:
  137. print(f'{prefix} {e}')
  138. if n: # if packages updated
  139. source = file.resolve() if 'file' in locals() else requirements
  140. s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
  141. f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
  142. print(emojis(s)) # emoji-safe
  143. def check_img_size(img_size, s=32):
  144. # Verify img_size is a multiple of stride s
  145. new_size = make_divisible(img_size, int(s)) # ceil gs-multiple
  146. if new_size != img_size:
  147. print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
  148. return new_size
  149. def check_imshow():
  150. # Check if environment supports image displays
  151. try:
  152. assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
  153. assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
  154. cv2.imshow('test', np.zeros((1, 1, 3)))
  155. cv2.waitKey(1)
  156. cv2.destroyAllWindows()
  157. cv2.waitKey(1)
  158. return True
  159. except Exception as e:
  160. print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
  161. return False
  162. def check_file(file):
  163. # Search/download file (if necessary) and return path
  164. file = str(file) # convert to str()
  165. if Path(file).is_file() or file == '': # exists
  166. return file
  167. elif file.startswith(('http://', 'https://')): # download
  168. url, file = file, Path(urllib.parse.unquote(str(file))).name # url, file (decode '%2F' to '/' etc.)
  169. file = file.split('?')[0] # parse authentication https://url.com/file.txt?auth...
  170. print(f'Downloading {url} to {file}...')
  171. torch.hub.download_url_to_file(url, file)
  172. assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
  173. return file
  174. else: # search
  175. files = glob.glob('./**/' + file, recursive=True) # find file
  176. assert len(files), f'File not found: {file}' # assert file was found
  177. assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
  178. return files[0] # return file
  179. def check_dataset(dict):
  180. # Download dataset if not found locally
  181. val, s = dict.get('val'), dict.get('download')
  182. if val and len(val):
  183. val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
  184. if not all(x.exists() for x in val):
  185. print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
  186. if s and len(s): # download script
  187. if s.startswith('http') and s.endswith('.zip'): # URL
  188. f = Path(s).name # filename
  189. print(f'Downloading {s} ...')
  190. torch.hub.download_url_to_file(s, f)
  191. r = os.system(f'unzip -q {f} -d ../ && rm {f}') # unzip
  192. elif s.startswith('bash '): # bash script
  193. print(f'Running {s} ...')
  194. r = os.system(s)
  195. else: # python script
  196. r = exec(s) # return None
  197. print('Dataset autodownload %s\n' % ('success' if r in (0, None) else 'failure')) # print result
  198. else:
  199. raise Exception('Dataset not found.')
  200. def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
  201. # Multi-threaded file download and unzip function
  202. def download_one(url, dir):
  203. # Download 1 file
  204. f = dir / Path(url).name # filename
  205. if not f.exists():
  206. print(f'Downloading {url} to {f}...')
  207. if curl:
  208. os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -") # curl download, retry and resume on fail
  209. else:
  210. torch.hub.download_url_to_file(url, f, progress=True) # torch download
  211. if unzip and f.suffix in ('.zip', '.gz'):
  212. print(f'Unzipping {f}...')
  213. if f.suffix == '.zip':
  214. s = f'unzip -qo {f} -d {dir} && rm {f}' # unzip -quiet -overwrite
  215. elif f.suffix == '.gz':
  216. s = f'tar xfz {f} --directory {f.parent}' # unzip
  217. if delete: # delete zip file after unzip
  218. s += f' && rm {f}'
  219. os.system(s)
  220. dir = Path(dir)
  221. dir.mkdir(parents=True, exist_ok=True) # make directory
  222. if threads > 1:
  223. pool = ThreadPool(threads)
  224. pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
  225. pool.close()
  226. pool.join()
  227. else:
  228. for u in tuple(url) if isinstance(url, str) else url:
  229. download_one(u, dir)
  230. def make_divisible(x, divisor):
  231. # Returns x evenly divisible by divisor
  232. return math.ceil(x / divisor) * divisor
  233. def clean_str(s):
  234. # Cleans a string by replacing special characters with underscore _
  235. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  236. def one_cycle(y1=0.0, y2=1.0, steps=100):
  237. # lambda function for sinusoidal ramp from y1 to y2
  238. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  239. def colorstr(*input):
  240. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  241. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  242. colors = {'black': '\033[30m', # basic colors
  243. 'red': '\033[31m',
  244. 'green': '\033[32m',
  245. 'yellow': '\033[33m',
  246. 'blue': '\033[34m',
  247. 'magenta': '\033[35m',
  248. 'cyan': '\033[36m',
  249. 'white': '\033[37m',
  250. 'bright_black': '\033[90m', # bright colors
  251. 'bright_red': '\033[91m',
  252. 'bright_green': '\033[92m',
  253. 'bright_yellow': '\033[93m',
  254. 'bright_blue': '\033[94m',
  255. 'bright_magenta': '\033[95m',
  256. 'bright_cyan': '\033[96m',
  257. 'bright_white': '\033[97m',
  258. 'end': '\033[0m', # misc
  259. 'bold': '\033[1m',
  260. 'underline': '\033[4m'}
  261. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  262. def labels_to_class_weights(labels, nc=80):
  263. # Get class weights (inverse frequency) from training labels
  264. if labels[0] is None: # no labels loaded
  265. return torch.Tensor()
  266. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  267. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  268. weights = np.bincount(classes, minlength=nc) # occurrences per class
  269. # Prepend gridpoint count (for uCE training)
  270. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  271. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  272. weights[weights == 0] = 1 # replace empty bins with 1
  273. weights = 1 / weights # number of targets per class
  274. weights /= weights.sum() # normalize
  275. return torch.from_numpy(weights)
  276. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  277. # Produces image weights based on class_weights and image contents
  278. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  279. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  280. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  281. return image_weights
  282. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  283. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  284. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  285. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  286. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  287. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  288. x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  289. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  290. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  291. return x
  292. def xyxy2xywh(x):
  293. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  294. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  295. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  296. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  297. y[:, 2] = x[:, 2] - x[:, 0] # width
  298. y[:, 3] = x[:, 3] - x[:, 1] # height
  299. return y
  300. def xywh2xyxy(x):
  301. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  302. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  303. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  304. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  305. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  306. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  307. return y
  308. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  309. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  310. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  311. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  312. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  313. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  314. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  315. return y
  316. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  317. # Convert normalized segments into pixel segments, shape (n,2)
  318. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  319. y[:, 0] = w * x[:, 0] + padw # top left x
  320. y[:, 1] = h * x[:, 1] + padh # top left y
  321. return y
  322. def segment2box(segment, width=640, height=640):
  323. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  324. x, y = segment.T # segment xy
  325. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  326. x, y, = x[inside], y[inside]
  327. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  328. def segments2boxes(segments):
  329. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  330. boxes = []
  331. for s in segments:
  332. x, y = s.T # segment xy
  333. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  334. return xyxy2xywh(np.array(boxes)) # cls, xywh
  335. def resample_segments(segments, n=1000):
  336. # Up-sample an (n,2) segment
  337. for i, s in enumerate(segments):
  338. x = np.linspace(0, len(s) - 1, n)
  339. xp = np.arange(len(s))
  340. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  341. return segments
  342. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  343. # Rescale coords (xyxy) from img1_shape to img0_shape
  344. if ratio_pad is None: # calculate from img0_shape
  345. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  346. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  347. else:
  348. gain = ratio_pad[0][0]
  349. pad = ratio_pad[1]
  350. coords[:, [0, 2]] -= pad[0] # x padding
  351. coords[:, [1, 3]] -= pad[1] # y padding
  352. coords[:, :4] /= gain
  353. clip_coords(coords, img0_shape)
  354. return coords
  355. def clip_coords(boxes, img_shape):
  356. # Clip bounding xyxy bounding boxes to image shape (height, width)
  357. boxes[:, 0].clamp_(0, img_shape[1]) # x1
  358. boxes[:, 1].clamp_(0, img_shape[0]) # y1
  359. boxes[:, 2].clamp_(0, img_shape[1]) # x2
  360. boxes[:, 3].clamp_(0, img_shape[0]) # y2
  361. def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
  362. # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
  363. box2 = box2.T
  364. # Get the coordinates of bounding boxes
  365. if x1y1x2y2: # x1, y1, x2, y2 = box1
  366. b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
  367. b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
  368. else: # transform from xywh to xyxy
  369. b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
  370. b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
  371. b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
  372. b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
  373. # Intersection area
  374. inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
  375. (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
  376. # Union Area
  377. w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
  378. w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
  379. union = w1 * h1 + w2 * h2 - inter + eps
  380. iou = inter / union
  381. if GIoU or DIoU or CIoU:
  382. cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
  383. ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
  384. if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
  385. c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
  386. rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
  387. (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center distance squared
  388. if DIoU:
  389. return iou - rho2 / c2 # DIoU
  390. elif CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
  391. v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
  392. with torch.no_grad():
  393. alpha = v / (v - iou + (1 + eps))
  394. return iou - (rho2 / c2 + v * alpha) # CIoU
  395. else: # GIoU https://arxiv.org/pdf/1902.09630.pdf
  396. c_area = cw * ch + eps # convex area
  397. return iou - (c_area - union) / c_area # GIoU
  398. else:
  399. return iou # IoU
  400. def box_iou(box1, box2):
  401. # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
  402. """
  403. Return intersection-over-union (Jaccard index) of boxes.
  404. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
  405. Arguments:
  406. box1 (Tensor[N, 4])
  407. box2 (Tensor[M, 4])
  408. Returns:
  409. iou (Tensor[N, M]): the NxM matrix containing the pairwise
  410. IoU values for every element in boxes1 and boxes2
  411. """
  412. def box_area(box):
  413. # box = 4xn
  414. return (box[2] - box[0]) * (box[3] - box[1])
  415. area1 = box_area(box1.T)
  416. area2 = box_area(box2.T)
  417. # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
  418. inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
  419. return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
  420. def wh_iou(wh1, wh2):
  421. # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
  422. wh1 = wh1[:, None] # [N,1,2]
  423. wh2 = wh2[None] # [1,M,2]
  424. inter = torch.min(wh1, wh2).prod(2) # [N,M]
  425. return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
  426. def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
  427. labels=(), max_det=300):
  428. """Runs Non-Maximum Suppression (NMS) on inference results
  429. Returns:
  430. list of detections, on (n,6) tensor per image [xyxy, conf, cls]
  431. """
  432. nc = prediction.shape[2] - 5 # number of classes
  433. xc = prediction[..., 4] > conf_thres # candidates
  434. # Checks
  435. assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
  436. assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
  437. # Settings
  438. min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
  439. max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
  440. time_limit = 10.0 # seconds to quit after
  441. redundant = True # require redundant detections
  442. multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
  443. merge = False # use merge-NMS
  444. t = time.time()
  445. output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
  446. for xi, x in enumerate(prediction): # image index, image inference
  447. # Apply constraints
  448. # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
  449. x = x[xc[xi]] # confidence
  450. # Cat apriori labels if autolabelling
  451. if labels and len(labels[xi]):
  452. l = labels[xi]
  453. v = torch.zeros((len(l), nc + 5), device=x.device)
  454. v[:, :4] = l[:, 1:5] # box
  455. v[:, 4] = 1.0 # conf
  456. v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls
  457. x = torch.cat((x, v), 0)
  458. # If none remain process next image
  459. if not x.shape[0]:
  460. continue
  461. # Compute conf
  462. x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
  463. # Box (center x, center y, width, height) to (x1, y1, x2, y2)
  464. box = xywh2xyxy(x[:, :4])
  465. # Detections matrix nx6 (xyxy, conf, cls)
  466. if multi_label:
  467. i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
  468. x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
  469. else: # best class only
  470. conf, j = x[:, 5:].max(1, keepdim=True)
  471. x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
  472. # Filter by class
  473. if classes is not None:
  474. x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
  475. # Apply finite constraint
  476. # if not torch.isfinite(x).all():
  477. # x = x[torch.isfinite(x).all(1)]
  478. # Check shape
  479. n = x.shape[0] # number of boxes
  480. if not n: # no boxes
  481. continue
  482. elif n > max_nms: # excess boxes
  483. x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
  484. # Batched NMS
  485. c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
  486. boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
  487. i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
  488. if i.shape[0] > max_det: # limit detections
  489. i = i[:max_det]
  490. if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
  491. # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
  492. iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
  493. weights = iou * scores[None] # box weights
  494. x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
  495. if redundant:
  496. i = i[iou.sum(1) > 1] # require redundancy
  497. output[xi] = x[i]
  498. if (time.time() - t) > time_limit:
  499. print(f'WARNING: NMS time limit {time_limit}s exceeded')
  500. break # time limit exceeded
  501. return output
  502. def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
  503. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  504. x = torch.load(f, map_location=torch.device('cpu'))
  505. if x.get('ema'):
  506. x['model'] = x['ema'] # replace model with ema
  507. for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates': # keys
  508. x[k] = None
  509. x['epoch'] = -1
  510. x['model'].half() # to FP16
  511. for p in x['model'].parameters():
  512. p.requires_grad = False
  513. torch.save(x, s or f)
  514. mb = os.path.getsize(s or f) / 1E6 # filesize
  515. print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
  516. def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
  517. # Print mutation results to evolve.txt (for use with train.py --evolve)
  518. a = '%10s' * len(hyp) % tuple(hyp.keys()) # hyperparam keys
  519. b = '%10.3g' * len(hyp) % tuple(hyp.values()) # hyperparam values
  520. c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
  521. print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))
  522. if bucket:
  523. url = 'gs://%s/evolve.txt' % bucket
  524. if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
  525. os.system('gsutil cp %s .' % url) # download evolve.txt if larger than local
  526. with open('evolve.txt', 'a') as f: # append result
  527. f.write(c + b + '\n')
  528. x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0) # load unique rows
  529. x = x[np.argsort(-fitness(x))] # sort
  530. np.savetxt('evolve.txt', x, '%10.3g') # save sort by fitness
  531. # Save yaml
  532. for i, k in enumerate(hyp.keys()):
  533. hyp[k] = float(x[0, i + 7])
  534. with open(yaml_file, 'w') as f:
  535. results = tuple(x[0, :7])
  536. c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
  537. f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
  538. yaml.safe_dump(hyp, f, sort_keys=False)
  539. if bucket:
  540. os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket)) # upload
  541. def apply_classifier(x, model, img, im0):
  542. # Apply a second stage classifier to yolo outputs
  543. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  544. for i, d in enumerate(x): # per image
  545. if d is not None and len(d):
  546. d = d.clone()
  547. # Reshape and pad cutouts
  548. b = xyxy2xywh(d[:, :4]) # boxes
  549. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  550. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  551. d[:, :4] = xywh2xyxy(b).long()
  552. # Rescale boxes from img_size to im0 size
  553. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  554. # Classes
  555. pred_cls1 = d[:, 5].long()
  556. ims = []
  557. for j, a in enumerate(d): # per item
  558. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  559. im = cv2.resize(cutout, (224, 224)) # BGR
  560. # cv2.imwrite('test%i.jpg' % j, cutout)
  561. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  562. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  563. im /= 255.0 # 0 - 255 to 0.0 - 1.0
  564. ims.append(im)
  565. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  566. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  567. return x
  568. def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):
  569. # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
  570. xyxy = torch.tensor(xyxy).view(-1, 4)
  571. b = xyxy2xywh(xyxy) # boxes
  572. if square:
  573. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
  574. b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
  575. xyxy = xywh2xyxy(b).long()
  576. clip_coords(xyxy, im.shape)
  577. crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
  578. if save:
  579. cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop)
  580. return crop
  581. def increment_path(path, exist_ok=False, sep='', mkdir=False):
  582. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  583. path = Path(path) # os-agnostic
  584. if path.exists() and not exist_ok:
  585. suffix = path.suffix
  586. path = path.with_suffix('')
  587. dirs = glob.glob(f"{path}{sep}*") # similar paths
  588. matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
  589. i = [int(m.groups()[0]) for m in matches if m] # indices
  590. n = max(i) + 1 if i else 2 # increment number
  591. path = Path(f"{path}{sep}{n}{suffix}") # update path
  592. dir = path if path.suffix == '' else path.parent # directory
  593. if not dir.exists() and mkdir:
  594. dir.mkdir(parents=True, exist_ok=True) # make directory
  595. return path