You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

761 lines
31KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. General utils
  4. """
  5. import contextlib
  6. import glob
  7. import logging
  8. import math
  9. import os
  10. import platform
  11. import random
  12. import re
  13. import signal
  14. import time
  15. import urllib
  16. from itertools import repeat
  17. from multiprocessing.pool import ThreadPool
  18. from pathlib import Path
  19. from subprocess import check_output
  20. import cv2
  21. import numpy as np
  22. import pandas as pd
  23. import pkg_resources as pkg
  24. import torch
  25. import torchvision
  26. import yaml
  27. from utils.downloads import gsutil_getsize
  28. from utils.metrics import box_iou, fitness
  29. from utils.torch_utils import init_torch_seeds
  30. # Settings
  31. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  32. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  33. pd.options.display.max_columns = 10
  34. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  35. os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
  36. class Profile(contextlib.ContextDecorator):
  37. # Usage: @Profile() decorator or 'with Profile():' context manager
  38. def __enter__(self):
  39. self.start = time.time()
  40. def __exit__(self, type, value, traceback):
  41. print(f'Profile results: {time.time() - self.start:.5f}s')
  42. class Timeout(contextlib.ContextDecorator):
  43. # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
  44. def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
  45. self.seconds = int(seconds)
  46. self.timeout_message = timeout_msg
  47. self.suppress = bool(suppress_timeout_errors)
  48. def _timeout_handler(self, signum, frame):
  49. raise TimeoutError(self.timeout_message)
  50. def __enter__(self):
  51. signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
  52. signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
  53. def __exit__(self, exc_type, exc_val, exc_tb):
  54. signal.alarm(0) # Cancel SIGALRM if it's scheduled
  55. if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
  56. return True
  57. def try_except(func):
  58. # try-except function. Usage: @try_except decorator
  59. def handler(*args, **kwargs):
  60. try:
  61. func(*args, **kwargs)
  62. except Exception as e:
  63. print(e)
  64. return handler
  65. def methods(instance):
  66. # Get class/instance methods
  67. return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
  68. def set_logging(rank=-1, verbose=True):
  69. logging.basicConfig(
  70. format="%(message)s",
  71. level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)
  72. def init_seeds(seed=0):
  73. # Initialize random number generator (RNG) seeds
  74. random.seed(seed)
  75. np.random.seed(seed)
  76. init_torch_seeds(seed)
  77. def get_latest_run(search_dir='.'):
  78. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  79. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  80. return max(last_list, key=os.path.getctime) if last_list else ''
  81. def is_docker():
  82. # Is environment a Docker container?
  83. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  84. def is_colab():
  85. # Is environment a Google Colab instance?
  86. try:
  87. import google.colab
  88. return True
  89. except Exception as e:
  90. return False
  91. def is_pip():
  92. # Is file in a pip package?
  93. return 'site-packages' in Path(__file__).absolute().parts
  94. def is_ascii(s=''):
  95. # Is string composed of all ASCII (no UTF) characters?
  96. s = str(s) # convert list, tuple, None, etc. to str
  97. return len(s.encode().decode('ascii', 'ignore')) == len(s)
  98. def emojis(str=''):
  99. # Return platform-dependent emoji-safe version of string
  100. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  101. def file_size(file):
  102. # Return file size in MB
  103. return Path(file).stat().st_size / 1e6
  104. def check_online():
  105. # Check internet connectivity
  106. import socket
  107. try:
  108. socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
  109. return True
  110. except OSError:
  111. return False
  112. @try_except
  113. def check_git_status():
  114. # Recommend 'git pull' if code is out of date
  115. msg = ', for updates see https://github.com/ultralytics/yolov5'
  116. print(colorstr('github: '), end='')
  117. assert Path('.git').exists(), 'skipping check (not a git repository)' + msg
  118. assert not is_docker(), 'skipping check (Docker image)' + msg
  119. assert check_online(), 'skipping check (offline)' + msg
  120. cmd = 'git fetch && git config --get remote.origin.url'
  121. url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
  122. branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  123. n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  124. if n > 0:
  125. s = f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
  126. else:
  127. s = f'up to date with {url} ✅'
  128. print(emojis(s)) # emoji-safe
  129. def check_python(minimum='3.6.2'):
  130. # Check current python version vs. required python version
  131. check_version(platform.python_version(), minimum, name='Python ')
  132. def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False):
  133. # Check version vs. required version
  134. current, minimum = (pkg.parse_version(x) for x in (current, minimum))
  135. result = (current == minimum) if pinned else (current >= minimum)
  136. assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
  137. @try_except
  138. def check_requirements(requirements='requirements.txt', exclude=(), install=True):
  139. # Check installed dependencies meet requirements (pass *.txt file or list of packages)
  140. prefix = colorstr('red', 'bold', 'requirements:')
  141. check_python() # check python version
  142. if isinstance(requirements, (str, Path)): # requirements.txt file
  143. file = Path(requirements)
  144. assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
  145. requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
  146. else: # list or tuple of packages
  147. requirements = [x for x in requirements if x not in exclude]
  148. n = 0 # number of packages updates
  149. for r in requirements:
  150. try:
  151. pkg.require(r)
  152. except Exception as e: # DistributionNotFound or VersionConflict if requirements not met
  153. s = f"{prefix} {r} not found and is required by YOLOv5"
  154. if install:
  155. print(f"{s}, attempting auto-update...")
  156. try:
  157. assert check_online(), f"'pip install {r}' skipped (offline)"
  158. print(check_output(f"pip install '{r}'", shell=True).decode())
  159. n += 1
  160. except Exception as e:
  161. print(f'{prefix} {e}')
  162. else:
  163. print(f'{s}. Please install and rerun your command.')
  164. if n: # if packages updated
  165. source = file.resolve() if 'file' in locals() else requirements
  166. s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
  167. f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
  168. print(emojis(s))
  169. def check_img_size(imgsz, s=32, floor=0):
  170. # Verify image size is a multiple of stride s in each dimension
  171. if isinstance(imgsz, int): # integer i.e. img_size=640
  172. new_size = max(make_divisible(imgsz, int(s)), floor)
  173. else: # list i.e. img_size=[640, 480]
  174. new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
  175. if new_size != imgsz:
  176. print(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
  177. return new_size
  178. def check_imshow():
  179. # Check if environment supports image displays
  180. try:
  181. assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
  182. assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
  183. cv2.imshow('test', np.zeros((1, 1, 3)))
  184. cv2.waitKey(1)
  185. cv2.destroyAllWindows()
  186. cv2.waitKey(1)
  187. return True
  188. except Exception as e:
  189. print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
  190. return False
  191. def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
  192. # Check file(s) for acceptable suffixes
  193. if file and suffix:
  194. if isinstance(suffix, str):
  195. suffix = [suffix]
  196. for f in file if isinstance(file, (list, tuple)) else [file]:
  197. assert Path(f).suffix.lower() in suffix, f"{msg}{f} acceptable suffix is {suffix}"
  198. def check_yaml(file, suffix=('.yaml', '.yml')):
  199. # Check YAML file(s) for acceptable suffixes
  200. return check_file(file, suffix)
  201. def check_file(file, suffix=''):
  202. # Search/download file (if necessary) and return path
  203. check_suffix(file, suffix) # optional
  204. file = str(file) # convert to str()
  205. if Path(file).is_file() or file == '': # exists
  206. return file
  207. elif file.startswith(('http:/', 'https:/')): # download
  208. url = str(Path(file)).replace(':/', '://') # Pathlib turns :// -> :/
  209. file = Path(urllib.parse.unquote(file)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
  210. print(f'Downloading {url} to {file}...')
  211. torch.hub.download_url_to_file(url, file)
  212. assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
  213. return file
  214. else: # search
  215. files = glob.glob('./**/' + file, recursive=True) # find file
  216. assert len(files), f'File not found: {file}' # assert file was found
  217. assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
  218. return files[0] # return file
def check_dataset(data, autodownload=True):
    """Load a dataset description and make sure its validation data exists, downloading it if allowed.

    Download and/or unzip dataset if not found locally.
    Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip

    Args:
        data: a dataset ``*.zip`` (downloaded/unzipped into ``../datasets``; the first ``*.yaml`` found inside
            is used), a path to a YAML file, or an already-loaded dict.
        autodownload: if True and any 'val' path is missing, run the dict's 'download' entry: a ``.zip`` URL,
            a ``bash ...`` command, or Python source executed with the dict bound to the name ``yaml``.

    Returns:
        The dataset dict, with 'train'/'val'/'test' prefixed by 'path' (or the extract dir) and 'names'
        filled with ``class{i}`` placeholders when missing.

    Raises:
        AssertionError: if the dict has no 'nc' key.
        Exception: 'Dataset not found.' when 'val' paths are missing and no download is performed.
    """
    # Download (optional)
    extract_dir = ''
    if isinstance(data, (str, Path)) and str(data).endswith('.zip'):  # i.e. gs://bucket/dir/coco128.zip
        download(data, dir='../datasets', unzip=True, delete=False, curl=False, threads=1)
        # First YAML inside the extracted folder; next() raises StopIteration if there is none
        data = next((Path('../datasets') / Path(data).stem).rglob('*.yaml'))
        extract_dir, autodownload = data.parent, False  # already downloaded, so never re-download below

    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        with open(data, errors='ignore') as f:
            data = yaml.safe_load(f)  # dictionary

    # Parse yaml
    path = extract_dir or Path(data.get('path') or '')  # optional 'path' default to '.'
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]

    assert 'nc' in data, "Dataset 'nc' key missing."
    if 'names' not in data:
        data['names'] = [f'class{i}' for i in range(data['nc'])]  # assign class names if missing
    train, val, test, s = [data.get(x) for x in ('train', 'val', 'test', 'download')]
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and autodownload:  # download script
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    print(f'Downloading {s} ...')
                    torch.hub.download_url_to_file(s, f)
                    root = path.parent if 'path' in data else '..'  # unzip directory i.e. '../'
                    Path(root).mkdir(parents=True, exist_ok=True)  # create root
                    r = os.system(f'unzip -q {f} -d {root} && rm {f}')  # unzip
                elif s.startswith('bash '):  # bash script
                    print(f'Running {s} ...')
                    r = os.system(s)
                else:  # python script
                    # NOTE(review): exec() runs arbitrary Python taken from the dataset YAML;
                    # only use dataset files from trusted sources
                    r = exec(s, {'yaml': data})  # return None
                # os.system returns 0 on success; exec returns None
                print('Dataset autodownload %s\n' % ('success' if r in (0, None) else 'failure'))  # print result
            else:
                raise Exception('Dataset not found.')

    return data  # dictionary
  262. def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
  263. # Multi-threaded file download and unzip function, used in data.yaml for autodownload
  264. def download_one(url, dir):
  265. # Download 1 file
  266. f = dir / Path(url).name # filename
  267. if Path(url).is_file(): # exists in current path
  268. Path(url).rename(f) # move to dir
  269. elif not f.exists():
  270. print(f'Downloading {url} to {f}...')
  271. if curl:
  272. os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -") # curl download, retry and resume on fail
  273. else:
  274. torch.hub.download_url_to_file(url, f, progress=True) # torch download
  275. if unzip and f.suffix in ('.zip', '.gz'):
  276. print(f'Unzipping {f}...')
  277. if f.suffix == '.zip':
  278. s = f'unzip -qo {f} -d {dir}' # unzip -quiet -overwrite
  279. elif f.suffix == '.gz':
  280. s = f'tar xfz {f} --directory {f.parent}' # unzip
  281. if delete: # delete zip file after unzip
  282. s += f' && rm {f}'
  283. os.system(s)
  284. dir = Path(dir)
  285. dir.mkdir(parents=True, exist_ok=True) # make directory
  286. if threads > 1:
  287. pool = ThreadPool(threads)
  288. pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
  289. pool.close()
  290. pool.join()
  291. else:
  292. for u in [url] if isinstance(url, (str, Path)) else url:
  293. download_one(u, dir)
  294. def make_divisible(x, divisor):
  295. # Returns x evenly divisible by divisor
  296. return math.ceil(x / divisor) * divisor
  297. def clean_str(s):
  298. # Cleans a string by replacing special characters with underscore _
  299. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  300. def one_cycle(y1=0.0, y2=1.0, steps=100):
  301. # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
  302. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  303. def colorstr(*input):
  304. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  305. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  306. colors = {'black': '\033[30m', # basic colors
  307. 'red': '\033[31m',
  308. 'green': '\033[32m',
  309. 'yellow': '\033[33m',
  310. 'blue': '\033[34m',
  311. 'magenta': '\033[35m',
  312. 'cyan': '\033[36m',
  313. 'white': '\033[37m',
  314. 'bright_black': '\033[90m', # bright colors
  315. 'bright_red': '\033[91m',
  316. 'bright_green': '\033[92m',
  317. 'bright_yellow': '\033[93m',
  318. 'bright_blue': '\033[94m',
  319. 'bright_magenta': '\033[95m',
  320. 'bright_cyan': '\033[96m',
  321. 'bright_white': '\033[97m',
  322. 'end': '\033[0m', # misc
  323. 'bold': '\033[1m',
  324. 'underline': '\033[4m'}
  325. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  326. def labels_to_class_weights(labels, nc=80):
  327. # Get class weights (inverse frequency) from training labels
  328. if labels[0] is None: # no labels loaded
  329. return torch.Tensor()
  330. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  331. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  332. weights = np.bincount(classes, minlength=nc) # occurrences per class
  333. # Prepend gridpoint count (for uCE training)
  334. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  335. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  336. weights[weights == 0] = 1 # replace empty bins with 1
  337. weights = 1 / weights # number of targets per class
  338. weights /= weights.sum() # normalize
  339. return torch.from_numpy(weights)
  340. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  341. # Produces image weights based on class_weights and image contents
  342. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  343. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  344. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  345. return image_weights
  346. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  347. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  348. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  349. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  350. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  351. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  352. x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  353. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  354. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  355. return x
  356. def xyxy2xywh(x):
  357. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  358. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  359. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  360. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  361. y[:, 2] = x[:, 2] - x[:, 0] # width
  362. y[:, 3] = x[:, 3] - x[:, 1] # height
  363. return y
  364. def xywh2xyxy(x):
  365. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  366. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  367. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  368. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  369. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  370. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  371. return y
  372. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  373. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  374. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  375. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  376. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  377. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  378. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  379. return y
  380. def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
  381. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
  382. if clip:
  383. clip_coords(x, (h - eps, w - eps)) # warning: inplace clip
  384. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  385. y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
  386. y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
  387. y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
  388. y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
  389. return y
  390. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  391. # Convert normalized segments into pixel segments, shape (n,2)
  392. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  393. y[:, 0] = w * x[:, 0] + padw # top left x
  394. y[:, 1] = h * x[:, 1] + padh # top left y
  395. return y
  396. def segment2box(segment, width=640, height=640):
  397. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  398. x, y = segment.T # segment xy
  399. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  400. x, y, = x[inside], y[inside]
  401. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  402. def segments2boxes(segments):
  403. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  404. boxes = []
  405. for s in segments:
  406. x, y = s.T # segment xy
  407. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  408. return xyxy2xywh(np.array(boxes)) # cls, xywh
  409. def resample_segments(segments, n=1000):
  410. # Up-sample an (n,2) segment
  411. for i, s in enumerate(segments):
  412. x = np.linspace(0, len(s) - 1, n)
  413. xp = np.arange(len(s))
  414. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  415. return segments
  416. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  417. # Rescale coords (xyxy) from img1_shape to img0_shape
  418. if ratio_pad is None: # calculate from img0_shape
  419. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  420. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  421. else:
  422. gain = ratio_pad[0][0]
  423. pad = ratio_pad[1]
  424. coords[:, [0, 2]] -= pad[0] # x padding
  425. coords[:, [1, 3]] -= pad[1] # y padding
  426. coords[:, :4] /= gain
  427. clip_coords(coords, img0_shape)
  428. return coords
  429. def clip_coords(boxes, shape):
  430. # Clip bounding xyxy bounding boxes to image shape (height, width)
  431. if isinstance(boxes, torch.Tensor): # faster individually
  432. boxes[:, 0].clamp_(0, shape[1]) # x1
  433. boxes[:, 1].clamp_(0, shape[0]) # y1
  434. boxes[:, 2].clamp_(0, shape[1]) # x2
  435. boxes[:, 3].clamp_(0, shape[0]) # y2
  436. else: # np.array (faster grouped)
  437. boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
  438. boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Args:
        prediction: model output indexed as (batch, boxes, 5 + nc). Per box the columns are
            [x_center, y_center, w, h, obj_conf, cls_conf_0, ..., cls_conf_{nc-1}].
        conf_thres: a box must exceed this in objectness and in final (obj_conf * cls_conf) confidence; in [0, 1].
        iou_thres: IoU threshold passed to ``torchvision.ops.nms``; in [0, 1].
        classes: optional list of class indices to keep; detections of other classes are dropped.
        agnostic: if True, suppress across all classes together instead of per class.
        multi_label: if True (and nc > 1), emit one detection per (box, class) pair above ``conf_thres``.
        labels: optional per-image apriori label tensors with rows [cls, x, y, w, h] (autolabelling);
            they are appended to the candidates with objectness and class confidence 1.0.
        max_det: maximum number of detections kept per image.

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    # min_wh is only referenced by the commented-out width-height constraint below
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    # NOTE: list multiplication makes images without detections share one empty tensor object
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            # i = box index, j = class index for every class score above the threshold
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        # Shifting each box by class_index * max_wh separates classes spatially, so one NMS call suppresses
        # only within a class; agnostic mode uses a zero offset
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        # Wall-clock budget across the whole batch; remaining images keep their empty default
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
  515. def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
  516. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  517. x = torch.load(f, map_location=torch.device('cpu'))
  518. if x.get('ema'):
  519. x['model'] = x['ema'] # replace model with ema
  520. for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates': # keys
  521. x[k] = None
  522. x['epoch'] = -1
  523. x['model'].half() # to FP16
  524. for p in x['model'].parameters():
  525. p.requires_grad = False
  526. torch.save(x, s or f)
  527. mb = os.path.getsize(s or f) / 1E6 # filesize
  528. print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
def print_mutation(results, hyp, save_dir, bucket):
    """Record one hyperparameter-evolution generation in evolve.csv, on the console and in hyp_evolve.yaml.

    Args:
        results: tuple of the 7 metric/loss values named first in ``keys`` below (concatenated with a tuple,
            so it must be a tuple).
        hyp: dict of this generation's hyperparameters; its keys/values extend the CSV columns.
        save_dir: ``pathlib.Path`` output directory (paths are built with ``/``).
        bucket: optional Google Cloud Storage bucket name. When set, a larger remote evolve.csv is pulled
            first, and evolve.csv plus hyp_evolve.yaml are uploaded at the end.
    """
    # NOTE(review): results_csv is computed but not used in this function
    evolve_csv, results_csv, evolve_yaml = save_dir / 'evolve.csv', save_dir / 'results.csv', save_dir / 'hyp_evolve.yaml'
    keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
            'val/box_loss', 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys())  # [results + hyps]
    keys = tuple(x.strip() for x in keys)
    vals = results + tuple(hyp.values())
    n = len(keys)

    # Download (optional)
    if bucket:
        url = f'gs://{bucket}/evolve.csv'
        if gsutil_getsize(url) > (os.path.getsize(evolve_csv) if os.path.exists(evolve_csv) else 0):
            os.system(f'gsutil cp {url} {save_dir}')  # download evolve.csv if larger than local

    # Log to evolve.csv
    # Columns are right-aligned to width 20; the header is written only when the file is new
    s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n')  # add header
    with open(evolve_csv, 'a') as f:
        f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')

    # Print to screen
    print(colorstr('evolve: ') + ', '.join(f'{x.strip():>20s}' for x in keys))
    print(colorstr('evolve: ') + ', '.join(f'{x:20.5g}' for x in vals), end='\n\n\n')

    # Save yaml
    with open(evolve_yaml, 'w') as f:
        data = pd.read_csv(evolve_csv)
        data = data.rename(columns=lambda x: x.strip())  # strip keys (header cells are space-padded)
        # Row index of the best generation; fitness() presumably scores rows from the 7 result columns
        i = np.argmax(fitness(data.values[:, :7]))  #
        f.write(f'# YOLOv5 Hyperparameter Evolution Results\n' +
                f'# Best generation: {i}\n' +
                f'# Last generation: {len(data)}\n' +
                f'# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) + '\n' +
                f'# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
        # NOTE(review): this dumps the current generation's hyp, not necessarily the best row i above — confirm intent
        yaml.safe_dump(hyp, f, sort_keys=False)

    if bucket:
        os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}')  # upload
  561. def apply_classifier(x, model, img, im0):
  562. # Apply a second stage classifier to yolo outputs
  563. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  564. for i, d in enumerate(x): # per image
  565. if d is not None and len(d):
  566. d = d.clone()
  567. # Reshape and pad cutouts
  568. b = xyxy2xywh(d[:, :4]) # boxes
  569. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  570. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  571. d[:, :4] = xywh2xyxy(b).long()
  572. # Rescale boxes from img_size to im0 size
  573. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  574. # Classes
  575. pred_cls1 = d[:, 5].long()
  576. ims = []
  577. for j, a in enumerate(d): # per item
  578. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  579. im = cv2.resize(cutout, (224, 224)) # BGR
  580. # cv2.imwrite('example%i.jpg' % j, cutout)
  581. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  582. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  583. im /= 255.0 # 0 - 255 to 0.0 - 1.0
  584. ims.append(im)
  585. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  586. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  587. return x
  588. def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):
  589. # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
  590. xyxy = torch.tensor(xyxy).view(-1, 4)
  591. b = xyxy2xywh(xyxy) # boxes
  592. if square:
  593. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
  594. b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
  595. xyxy = xywh2xyxy(b).long()
  596. clip_coords(xyxy, im.shape)
  597. crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
  598. if save:
  599. cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop)
  600. return crop
  601. def increment_path(path, exist_ok=False, sep='', mkdir=False):
  602. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  603. path = Path(path) # os-agnostic
  604. if path.exists() and not exist_ok:
  605. suffix = path.suffix
  606. path = path.with_suffix('')
  607. dirs = glob.glob(f"{path}{sep}*") # similar paths
  608. matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
  609. i = [int(m.groups()[0]) for m in matches if m] # indices
  610. n = max(i) + 1 if i else 2 # increment number
  611. path = Path(f"{path}{sep}{n}{suffix}") # update path
  612. dir = path if path.suffix == '' else path.parent # directory
  613. if not dir.exists() and mkdir:
  614. dir.mkdir(parents=True, exist_ok=True) # make directory
  615. return path