You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

747 lines
30KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. General utils
  4. """
  5. import contextlib
  6. import glob
  7. import logging
  8. import math
  9. import os
  10. import platform
  11. import random
  12. import re
  13. import signal
  14. import time
  15. import urllib
  16. from itertools import repeat
  17. from multiprocessing.pool import ThreadPool
  18. from pathlib import Path
  19. from subprocess import check_output
  20. import cv2
  21. import numpy as np
  22. import pandas as pd
  23. import pkg_resources as pkg
  24. import torch
  25. import torchvision
  26. import yaml
  27. from utils.downloads import gsutil_getsize
  28. from utils.metrics import box_iou, fitness
  29. from utils.torch_utils import init_torch_seeds
  30. # Settings
  31. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  32. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  33. pd.options.display.max_columns = 10
  34. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  35. os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
  36. class Profile(contextlib.ContextDecorator):
  37. # Usage: @Profile() decorator or 'with Profile():' context manager
  38. def __enter__(self):
  39. self.start = time.time()
  40. def __exit__(self, type, value, traceback):
  41. print(f'Profile results: {time.time() - self.start:.5f}s')
  42. class Timeout(contextlib.ContextDecorator):
  43. # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
  44. def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
  45. self.seconds = int(seconds)
  46. self.timeout_message = timeout_msg
  47. self.suppress = bool(suppress_timeout_errors)
  48. def _timeout_handler(self, signum, frame):
  49. raise TimeoutError(self.timeout_message)
  50. def __enter__(self):
  51. signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
  52. signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
  53. def __exit__(self, exc_type, exc_val, exc_tb):
  54. signal.alarm(0) # Cancel SIGALRM if it's scheduled
  55. if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
  56. return True
  57. def try_except(func):
  58. # try-except function. Usage: @try_except decorator
  59. def handler(*args, **kwargs):
  60. try:
  61. func(*args, **kwargs)
  62. except Exception as e:
  63. print(e)
  64. return handler
  65. def methods(instance):
  66. # Get class/instance methods
  67. return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
  68. def set_logging(rank=-1, verbose=True):
  69. logging.basicConfig(
  70. format="%(message)s",
  71. level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)
  72. def init_seeds(seed=0):
  73. # Initialize random number generator (RNG) seeds
  74. random.seed(seed)
  75. np.random.seed(seed)
  76. init_torch_seeds(seed)
  77. def get_latest_run(search_dir='.'):
  78. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  79. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  80. return max(last_list, key=os.path.getctime) if last_list else ''
  81. def is_docker():
  82. # Is environment a Docker container?
  83. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  84. def is_colab():
  85. # Is environment a Google Colab instance?
  86. try:
  87. import google.colab
  88. return True
  89. except Exception as e:
  90. return False
  91. def is_pip():
  92. # Is file in a pip package?
  93. return 'site-packages' in Path(__file__).absolute().parts
  94. def is_ascii(s=''):
  95. # Is string composed of all ASCII (no UTF) characters?
  96. s = str(s) # convert list, tuple, None, etc. to str
  97. return len(s.encode().decode('ascii', 'ignore')) == len(s)
  98. def emojis(str=''):
  99. # Return platform-dependent emoji-safe version of string
  100. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  101. def file_size(file):
  102. # Return file size in MB
  103. return Path(file).stat().st_size / 1e6
  104. def check_online():
  105. # Check internet connectivity
  106. import socket
  107. try:
  108. socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
  109. return True
  110. except OSError:
  111. return False
  112. @try_except
  113. def check_git_status():
  114. # Recommend 'git pull' if code is out of date
  115. msg = ', for updates see https://github.com/ultralytics/yolov5'
  116. print(colorstr('github: '), end='')
  117. assert Path('.git').exists(), 'skipping check (not a git repository)' + msg
  118. assert not is_docker(), 'skipping check (Docker image)' + msg
  119. assert check_online(), 'skipping check (offline)' + msg
  120. cmd = 'git fetch && git config --get remote.origin.url'
  121. url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
  122. branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  123. n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  124. if n > 0:
  125. s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
  126. f"Use 'git pull' to update or 'git clone {url}' to download latest."
  127. else:
  128. s = f'up to date with {url} ✅'
  129. print(emojis(s)) # emoji-safe
  130. def check_python(minimum='3.6.2'):
  131. # Check current python version vs. required python version
  132. check_version(platform.python_version(), minimum, name='Python ')
  133. def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False):
  134. # Check version vs. required version
  135. current, minimum = (pkg.parse_version(x) for x in (current, minimum))
  136. result = (current == minimum) if pinned else (current >= minimum)
  137. assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
  138. @try_except
  139. def check_requirements(requirements='requirements.txt', exclude=(), install=True):
  140. # Check installed dependencies meet requirements (pass *.txt file or list of packages)
  141. prefix = colorstr('red', 'bold', 'requirements:')
  142. check_python() # check python version
  143. if isinstance(requirements, (str, Path)): # requirements.txt file
  144. file = Path(requirements)
  145. assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
  146. requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
  147. else: # list or tuple of packages
  148. requirements = [x for x in requirements if x not in exclude]
  149. n = 0 # number of packages updates
  150. for r in requirements:
  151. try:
  152. pkg.require(r)
  153. except Exception as e: # DistributionNotFound or VersionConflict if requirements not met
  154. s = f"{prefix} {r} not found and is required by YOLOv5"
  155. if install:
  156. print(f"{s}, attempting auto-update...")
  157. try:
  158. assert check_online(), f"'pip install {r}' skipped (offline)"
  159. print(check_output(f"pip install '{r}'", shell=True).decode())
  160. n += 1
  161. except Exception as e:
  162. print(f'{prefix} {e}')
  163. else:
  164. print(f'{s}. Please install and rerun your command.')
  165. if n: # if packages updated
  166. source = file.resolve() if 'file' in locals() else requirements
  167. s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
  168. f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
  169. print(emojis(s))
  170. def check_img_size(imgsz, s=32, floor=0):
  171. # Verify image size is a multiple of stride s in each dimension
  172. if isinstance(imgsz, int): # integer i.e. img_size=640
  173. new_size = max(make_divisible(imgsz, int(s)), floor)
  174. else: # list i.e. img_size=[640, 480]
  175. new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
  176. if new_size != imgsz:
  177. print(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
  178. return new_size
  179. def check_imshow():
  180. # Check if environment supports image displays
  181. try:
  182. assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
  183. assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
  184. cv2.imshow('test', np.zeros((1, 1, 3)))
  185. cv2.waitKey(1)
  186. cv2.destroyAllWindows()
  187. cv2.waitKey(1)
  188. return True
  189. except Exception as e:
  190. print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
  191. return False
  192. def check_file(file):
  193. # Search/download file (if necessary) and return path
  194. file = str(file) # convert to str()
  195. if Path(file).is_file() or file == '': # exists
  196. return file
  197. elif file.startswith(('http:/', 'https:/')): # download
  198. url = str(Path(file)).replace(':/', '://') # Pathlib turns :// -> :/
  199. file = Path(urllib.parse.unquote(file)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
  200. print(f'Downloading {url} to {file}...')
  201. torch.hub.download_url_to_file(url, file)
  202. assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
  203. return file
  204. else: # search
  205. files = glob.glob('./**/' + file, recursive=True) # find file
  206. assert len(files), f'File not found: {file}' # assert file was found
  207. assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
  208. return files[0] # return file
  209. def check_dataset(data, autodownload=True):
  210. # Download and/or unzip dataset if not found locally
  211. # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip
  212. # Download (optional)
  213. extract_dir = ''
  214. if isinstance(data, (str, Path)) and str(data).endswith('.zip'): # i.e. gs://bucket/dir/coco128.zip
  215. download(data, dir='../datasets', unzip=True, delete=False, curl=False, threads=1)
  216. data = next((Path('../datasets') / Path(data).stem).rglob('*.yaml'))
  217. extract_dir, autodownload = data.parent, False
  218. # Read yaml (optional)
  219. if isinstance(data, (str, Path)):
  220. with open(data, errors='ignore') as f:
  221. data = yaml.safe_load(f) # dictionary
  222. # Parse yaml
  223. path = extract_dir or Path(data.get('path') or '') # optional 'path' default to '.'
  224. for k in 'train', 'val', 'test':
  225. if data.get(k): # prepend path
  226. data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]
  227. assert 'nc' in data, "Dataset 'nc' key missing."
  228. if 'names' not in data:
  229. data['names'] = [f'class{i}' for i in range(data['nc'])] # assign class names if missing
  230. train, val, test, s = [data.get(x) for x in ('train', 'val', 'test', 'download')]
  231. if val:
  232. val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
  233. if not all(x.exists() for x in val):
  234. print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
  235. if s and autodownload: # download script
  236. if s.startswith('http') and s.endswith('.zip'): # URL
  237. f = Path(s).name # filename
  238. print(f'Downloading {s} ...')
  239. torch.hub.download_url_to_file(s, f)
  240. root = path.parent if 'path' in data else '..' # unzip directory i.e. '../'
  241. Path(root).mkdir(parents=True, exist_ok=True) # create root
  242. r = os.system(f'unzip -q {f} -d {root} && rm {f}') # unzip
  243. elif s.startswith('bash '): # bash script
  244. print(f'Running {s} ...')
  245. r = os.system(s)
  246. else: # python script
  247. r = exec(s, {'yaml': data}) # return None
  248. print('Dataset autodownload %s\n' % ('success' if r in (0, None) else 'failure')) # print result
  249. else:
  250. raise Exception('Dataset not found.')
  251. return data # dictionary
  252. def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
  253. # Multi-threaded file download and unzip function, used in data.yaml for autodownload
  254. def download_one(url, dir):
  255. # Download 1 file
  256. f = dir / Path(url).name # filename
  257. if Path(url).is_file(): # exists in current path
  258. Path(url).rename(f) # move to dir
  259. elif not f.exists():
  260. print(f'Downloading {url} to {f}...')
  261. if curl:
  262. os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -") # curl download, retry and resume on fail
  263. else:
  264. torch.hub.download_url_to_file(url, f, progress=True) # torch download
  265. if unzip and f.suffix in ('.zip', '.gz'):
  266. print(f'Unzipping {f}...')
  267. if f.suffix == '.zip':
  268. s = f'unzip -qo {f} -d {dir}' # unzip -quiet -overwrite
  269. elif f.suffix == '.gz':
  270. s = f'tar xfz {f} --directory {f.parent}' # unzip
  271. if delete: # delete zip file after unzip
  272. s += f' && rm {f}'
  273. os.system(s)
  274. dir = Path(dir)
  275. dir.mkdir(parents=True, exist_ok=True) # make directory
  276. if threads > 1:
  277. pool = ThreadPool(threads)
  278. pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
  279. pool.close()
  280. pool.join()
  281. else:
  282. for u in [url] if isinstance(url, (str, Path)) else url:
  283. download_one(u, dir)
  284. def make_divisible(x, divisor):
  285. # Returns x evenly divisible by divisor
  286. return math.ceil(x / divisor) * divisor
  287. def clean_str(s):
  288. # Cleans a string by replacing special characters with underscore _
  289. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  290. def one_cycle(y1=0.0, y2=1.0, steps=100):
  291. # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
  292. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  293. def colorstr(*input):
  294. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  295. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  296. colors = {'black': '\033[30m', # basic colors
  297. 'red': '\033[31m',
  298. 'green': '\033[32m',
  299. 'yellow': '\033[33m',
  300. 'blue': '\033[34m',
  301. 'magenta': '\033[35m',
  302. 'cyan': '\033[36m',
  303. 'white': '\033[37m',
  304. 'bright_black': '\033[90m', # bright colors
  305. 'bright_red': '\033[91m',
  306. 'bright_green': '\033[92m',
  307. 'bright_yellow': '\033[93m',
  308. 'bright_blue': '\033[94m',
  309. 'bright_magenta': '\033[95m',
  310. 'bright_cyan': '\033[96m',
  311. 'bright_white': '\033[97m',
  312. 'end': '\033[0m', # misc
  313. 'bold': '\033[1m',
  314. 'underline': '\033[4m'}
  315. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  316. def labels_to_class_weights(labels, nc=80):
  317. # Get class weights (inverse frequency) from training labels
  318. if labels[0] is None: # no labels loaded
  319. return torch.Tensor()
  320. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  321. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  322. weights = np.bincount(classes, minlength=nc) # occurrences per class
  323. # Prepend gridpoint count (for uCE training)
  324. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  325. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  326. weights[weights == 0] = 1 # replace empty bins with 1
  327. weights = 1 / weights # number of targets per class
  328. weights /= weights.sum() # normalize
  329. return torch.from_numpy(weights)
  330. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  331. # Produces image weights based on class_weights and image contents
  332. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  333. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  334. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  335. return image_weights
  336. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  337. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  338. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  339. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  340. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  341. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  342. x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  343. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  344. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  345. return x
  346. def xyxy2xywh(x):
  347. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  348. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  349. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  350. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  351. y[:, 2] = x[:, 2] - x[:, 0] # width
  352. y[:, 3] = x[:, 3] - x[:, 1] # height
  353. return y
  354. def xywh2xyxy(x):
  355. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  356. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  357. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  358. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  359. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  360. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  361. return y
  362. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  363. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  364. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  365. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  366. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  367. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  368. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  369. return y
  370. def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
  371. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
  372. if clip:
  373. clip_coords(x, (h - eps, w - eps)) # warning: inplace clip
  374. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  375. y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
  376. y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
  377. y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
  378. y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
  379. return y
  380. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  381. # Convert normalized segments into pixel segments, shape (n,2)
  382. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  383. y[:, 0] = w * x[:, 0] + padw # top left x
  384. y[:, 1] = h * x[:, 1] + padh # top left y
  385. return y
  386. def segment2box(segment, width=640, height=640):
  387. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  388. x, y = segment.T # segment xy
  389. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  390. x, y, = x[inside], y[inside]
  391. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  392. def segments2boxes(segments):
  393. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  394. boxes = []
  395. for s in segments:
  396. x, y = s.T # segment xy
  397. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  398. return xyxy2xywh(np.array(boxes)) # cls, xywh
  399. def resample_segments(segments, n=1000):
  400. # Up-sample an (n,2) segment
  401. for i, s in enumerate(segments):
  402. x = np.linspace(0, len(s) - 1, n)
  403. xp = np.arange(len(s))
  404. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  405. return segments
  406. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  407. # Rescale coords (xyxy) from img1_shape to img0_shape
  408. if ratio_pad is None: # calculate from img0_shape
  409. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  410. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  411. else:
  412. gain = ratio_pad[0][0]
  413. pad = ratio_pad[1]
  414. coords[:, [0, 2]] -= pad[0] # x padding
  415. coords[:, [1, 3]] -= pad[1] # y padding
  416. coords[:, :4] /= gain
  417. clip_coords(coords, img0_shape)
  418. return coords
  419. def clip_coords(boxes, shape):
  420. # Clip bounding xyxy bounding boxes to image shape (height, width)
  421. if isinstance(boxes, torch.Tensor): # faster individually
  422. boxes[:, 0].clamp_(0, shape[1]) # x1
  423. boxes[:, 1].clamp_(0, shape[0]) # y1
  424. boxes[:, 2].clamp_(0, shape[1]) # x2
  425. boxes[:, 3].clamp_(0, shape[0]) # y2
  426. else: # np.array (faster grouped)
  427. boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
  428. boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
  429. def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
  430. labels=(), max_det=300):
  431. """Runs Non-Maximum Suppression (NMS) on inference results
  432. Returns:
  433. list of detections, on (n,6) tensor per image [xyxy, conf, cls]
  434. """
  435. nc = prediction.shape[2] - 5 # number of classes
  436. xc = prediction[..., 4] > conf_thres # candidates
  437. # Checks
  438. assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
  439. assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
  440. # Settings
  441. min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
  442. max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
  443. time_limit = 10.0 # seconds to quit after
  444. redundant = True # require redundant detections
  445. multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
  446. merge = False # use merge-NMS
  447. t = time.time()
  448. output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
  449. for xi, x in enumerate(prediction): # image index, image inference
  450. # Apply constraints
  451. # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
  452. x = x[xc[xi]] # confidence
  453. # Cat apriori labels if autolabelling
  454. if labels and len(labels[xi]):
  455. l = labels[xi]
  456. v = torch.zeros((len(l), nc + 5), device=x.device)
  457. v[:, :4] = l[:, 1:5] # box
  458. v[:, 4] = 1.0 # conf
  459. v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls
  460. x = torch.cat((x, v), 0)
  461. # If none remain process next image
  462. if not x.shape[0]:
  463. continue
  464. # Compute conf
  465. x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
  466. # Box (center x, center y, width, height) to (x1, y1, x2, y2)
  467. box = xywh2xyxy(x[:, :4])
  468. # Detections matrix nx6 (xyxy, conf, cls)
  469. if multi_label:
  470. i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
  471. x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
  472. else: # best class only
  473. conf, j = x[:, 5:].max(1, keepdim=True)
  474. x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
  475. # Filter by class
  476. if classes is not None:
  477. x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
  478. # Apply finite constraint
  479. # if not torch.isfinite(x).all():
  480. # x = x[torch.isfinite(x).all(1)]
  481. # Check shape
  482. n = x.shape[0] # number of boxes
  483. if not n: # no boxes
  484. continue
  485. elif n > max_nms: # excess boxes
  486. x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
  487. # Batched NMS
  488. c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
  489. boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
  490. i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
  491. if i.shape[0] > max_det: # limit detections
  492. i = i[:max_det]
  493. if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
  494. # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
  495. iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
  496. weights = iou * scores[None] # box weights
  497. x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
  498. if redundant:
  499. i = i[iou.sum(1) > 1] # require redundancy
  500. output[xi] = x[i]
  501. if (time.time() - t) > time_limit:
  502. print(f'WARNING: NMS time limit {time_limit}s exceeded')
  503. break # time limit exceeded
  504. return output
  505. def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
  506. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  507. x = torch.load(f, map_location=torch.device('cpu'))
  508. if x.get('ema'):
  509. x['model'] = x['ema'] # replace model with ema
  510. for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates': # keys
  511. x[k] = None
  512. x['epoch'] = -1
  513. x['model'].half() # to FP16
  514. for p in x['model'].parameters():
  515. p.requires_grad = False
  516. torch.save(x, s or f)
  517. mb = os.path.getsize(s or f) / 1E6 # filesize
  518. print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
  519. def print_mutation(results, hyp, save_dir, bucket):
  520. evolve_csv, results_csv, evolve_yaml = save_dir / 'evolve.csv', save_dir / 'results.csv', save_dir / 'hyp_evolve.yaml'
  521. keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
  522. 'val/box_loss', 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys()) # [results + hyps]
  523. keys = tuple(x.strip() for x in keys)
  524. vals = results + tuple(hyp.values())
  525. n = len(keys)
  526. # Download (optional)
  527. if bucket:
  528. url = f'gs://{bucket}/evolve.csv'
  529. if gsutil_getsize(url) > (os.path.getsize(evolve_csv) if os.path.exists(evolve_csv) else 0):
  530. os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local
  531. # Log to evolve.csv
  532. s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header
  533. with open(evolve_csv, 'a') as f:
  534. f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')
  535. # Print to screen
  536. print(colorstr('evolve: ') + ', '.join(f'{x.strip():>20s}' for x in keys))
  537. print(colorstr('evolve: ') + ', '.join(f'{x:20.5g}' for x in vals), end='\n\n\n')
  538. # Save yaml
  539. with open(evolve_yaml, 'w') as f:
  540. data = pd.read_csv(evolve_csv)
  541. data = data.rename(columns=lambda x: x.strip()) # strip keys
  542. i = np.argmax(fitness(data.values[:, :7])) #
  543. f.write(f'# YOLOv5 Hyperparameter Evolution Results\n' +
  544. f'# Best generation: {i}\n' +
  545. f'# Last generation: {len(data)}\n' +
  546. f'# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) + '\n' +
  547. f'# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
  548. yaml.safe_dump(hyp, f, sort_keys=False)
  549. if bucket:
  550. os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload
  551. def apply_classifier(x, model, img, im0):
  552. # Apply a second stage classifier to yolo outputs
  553. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  554. for i, d in enumerate(x): # per image
  555. if d is not None and len(d):
  556. d = d.clone()
  557. # Reshape and pad cutouts
  558. b = xyxy2xywh(d[:, :4]) # boxes
  559. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  560. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  561. d[:, :4] = xywh2xyxy(b).long()
  562. # Rescale boxes from img_size to im0 size
  563. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  564. # Classes
  565. pred_cls1 = d[:, 5].long()
  566. ims = []
  567. for j, a in enumerate(d): # per item
  568. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  569. im = cv2.resize(cutout, (224, 224)) # BGR
  570. # cv2.imwrite('example%i.jpg' % j, cutout)
  571. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  572. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  573. im /= 255.0 # 0 - 255 to 0.0 - 1.0
  574. ims.append(im)
  575. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  576. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  577. return x
  578. def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):
  579. # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
  580. xyxy = torch.tensor(xyxy).view(-1, 4)
  581. b = xyxy2xywh(xyxy) # boxes
  582. if square:
  583. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
  584. b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
  585. xyxy = xywh2xyxy(b).long()
  586. clip_coords(xyxy, im.shape)
  587. crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
  588. if save:
  589. cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop)
  590. return crop
  591. def increment_path(path, exist_ok=False, sep='', mkdir=False):
  592. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  593. path = Path(path) # os-agnostic
  594. if path.exists() and not exist_ok:
  595. suffix = path.suffix
  596. path = path.with_suffix('')
  597. dirs = glob.glob(f"{path}{sep}*") # similar paths
  598. matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
  599. i = [int(m.groups()[0]) for m in matches if m] # indices
  600. n = max(i) + 1 if i else 2 # increment number
  601. path = Path(f"{path}{sep}{n}{suffix}") # update path
  602. dir = path if path.suffix == '' else path.parent # directory
  603. if not dir.exists() and mkdir:
  604. dir.mkdir(parents=True, exist_ok=True) # make directory
  605. return path