You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

823 lines
33KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. General utils
  4. """
  5. import contextlib
  6. import glob
  7. import logging
  8. import math
  9. import os
  10. import platform
  11. import random
  12. import re
  13. import signal
  14. import time
  15. import urllib
  16. from itertools import repeat
  17. from multiprocessing.pool import ThreadPool
  18. from pathlib import Path
  19. from subprocess import check_output
  20. from zipfile import ZipFile
  21. import cv2
  22. import numpy as np
  23. import pandas as pd
  24. import pkg_resources as pkg
  25. import torch
  26. import torchvision
  27. import yaml
  28. from utils.downloads import gsutil_getsize
  29. from utils.metrics import box_iou, fitness
  30. # Settings
  31. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  32. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  33. pd.options.display.max_columns = 10
  34. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  35. os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
  36. FILE = Path(__file__).resolve()
  37. ROOT = FILE.parents[1] # YOLOv5 root directory
  38. class Profile(contextlib.ContextDecorator):
  39. # Usage: @Profile() decorator or 'with Profile():' context manager
  40. def __enter__(self):
  41. self.start = time.time()
  42. def __exit__(self, type, value, traceback):
  43. print(f'Profile results: {time.time() - self.start:.5f}s')
  44. class Timeout(contextlib.ContextDecorator):
  45. # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
  46. def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
  47. self.seconds = int(seconds)
  48. self.timeout_message = timeout_msg
  49. self.suppress = bool(suppress_timeout_errors)
  50. def _timeout_handler(self, signum, frame):
  51. raise TimeoutError(self.timeout_message)
  52. def __enter__(self):
  53. signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
  54. signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
  55. def __exit__(self, exc_type, exc_val, exc_tb):
  56. signal.alarm(0) # Cancel SIGALRM if it's scheduled
  57. if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
  58. return True
  59. def try_except(func):
  60. # try-except function. Usage: @try_except decorator
  61. def handler(*args, **kwargs):
  62. try:
  63. func(*args, **kwargs)
  64. except Exception as e:
  65. print(e)
  66. return handler
  67. def methods(instance):
  68. # Get class/instance methods
  69. return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
  70. def set_logging(rank=-1, verbose=True):
  71. logging.basicConfig(
  72. format="%(message)s",
  73. level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)
  74. def print_args(name, opt):
  75. # Print argparser arguments
  76. print(colorstr(f'{name}: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
  77. def init_seeds(seed=0):
  78. # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
  79. # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
  80. import torch.backends.cudnn as cudnn
  81. random.seed(seed)
  82. np.random.seed(seed)
  83. torch.manual_seed(seed)
  84. cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
  85. def get_latest_run(search_dir='.'):
  86. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  87. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  88. return max(last_list, key=os.path.getctime) if last_list else ''
  89. def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
  90. # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
  91. env = os.getenv(env_var)
  92. if env:
  93. path = Path(env) # use environment variable
  94. else:
  95. cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs
  96. path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir
  97. path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable
  98. path.mkdir(exist_ok=True) # make if required
  99. return path
  100. def is_writeable(dir, test=False):
  101. # Return True if directory has write permissions, test opening a file with write permissions if test=True
  102. if test: # method 1
  103. file = Path(dir) / 'tmp.txt'
  104. try:
  105. with open(file, 'w'): # open file with write permissions
  106. pass
  107. file.unlink() # remove file
  108. return True
  109. except IOError:
  110. return False
  111. else: # method 2
  112. return os.access(dir, os.R_OK) # possible issues on Windows
  113. def is_docker():
  114. # Is environment a Docker container?
  115. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  116. def is_colab():
  117. # Is environment a Google Colab instance?
  118. try:
  119. import google.colab
  120. return True
  121. except ImportError:
  122. return False
  123. def is_pip():
  124. # Is file in a pip package?
  125. return 'site-packages' in Path(__file__).resolve().parts
  126. def is_ascii(s=''):
  127. # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
  128. s = str(s) # convert list, tuple, None, etc. to str
  129. return len(s.encode().decode('ascii', 'ignore')) == len(s)
  130. def is_chinese(s='人工智能'):
  131. # Is string composed of any Chinese characters?
  132. return re.search('[\u4e00-\u9fff]', s)
  133. def emojis(str=''):
  134. # Return platform-dependent emoji-safe version of string
  135. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  136. def file_size(path):
  137. # Return file/dir size (MB)
  138. path = Path(path)
  139. if path.is_file():
  140. return path.stat().st_size / 1E6
  141. elif path.is_dir():
  142. return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / 1E6
  143. else:
  144. return 0.0
  145. def check_online():
  146. # Check internet connectivity
  147. import socket
  148. try:
  149. socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
  150. return True
  151. except OSError:
  152. return False
  153. @try_except
  154. def check_git_status():
  155. # Recommend 'git pull' if code is out of date
  156. msg = ', for updates see https://github.com/ultralytics/yolov5'
  157. print(colorstr('github: '), end='')
  158. assert Path('.git').exists(), 'skipping check (not a git repository)' + msg
  159. assert not is_docker(), 'skipping check (Docker image)' + msg
  160. assert check_online(), 'skipping check (offline)' + msg
  161. cmd = 'git fetch && git config --get remote.origin.url'
  162. url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
  163. branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  164. n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  165. if n > 0:
  166. s = f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
  167. else:
  168. s = f'up to date with {url} ✅'
  169. print(emojis(s)) # emoji-safe
  170. def check_python(minimum='3.6.2'):
  171. # Check current python version vs. required python version
  172. check_version(platform.python_version(), minimum, name='Python ')
  173. def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False):
  174. # Check version vs. required version
  175. current, minimum = (pkg.parse_version(x) for x in (current, minimum))
  176. result = (current == minimum) if pinned else (current >= minimum)
  177. assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
  178. @try_except
  179. def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True):
  180. # Check installed dependencies meet requirements (pass *.txt file or list of packages)
  181. prefix = colorstr('red', 'bold', 'requirements:')
  182. check_python() # check python version
  183. if isinstance(requirements, (str, Path)): # requirements.txt file
  184. file = Path(requirements)
  185. assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
  186. requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
  187. else: # list or tuple of packages
  188. requirements = [x for x in requirements if x not in exclude]
  189. n = 0 # number of packages updates
  190. for r in requirements:
  191. try:
  192. pkg.require(r)
  193. except Exception as e: # DistributionNotFound or VersionConflict if requirements not met
  194. s = f"{prefix} {r} not found and is required by YOLOv5"
  195. if install:
  196. print(f"{s}, attempting auto-update...")
  197. try:
  198. assert check_online(), f"'pip install {r}' skipped (offline)"
  199. print(check_output(f"pip install '{r}'", shell=True).decode())
  200. n += 1
  201. except Exception as e:
  202. print(f'{prefix} {e}')
  203. else:
  204. print(f'{s}. Please install and rerun your command.')
  205. if n: # if packages updated
  206. source = file.resolve() if 'file' in locals() else requirements
  207. s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
  208. f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
  209. print(emojis(s))
  210. def check_img_size(imgsz, s=32, floor=0):
  211. # Verify image size is a multiple of stride s in each dimension
  212. if isinstance(imgsz, int): # integer i.e. img_size=640
  213. new_size = max(make_divisible(imgsz, int(s)), floor)
  214. else: # list i.e. img_size=[640, 480]
  215. new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
  216. if new_size != imgsz:
  217. print(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
  218. return new_size
  219. def check_imshow():
  220. # Check if environment supports image displays
  221. try:
  222. assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
  223. assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
  224. cv2.imshow('test', np.zeros((1, 1, 3)))
  225. cv2.waitKey(1)
  226. cv2.destroyAllWindows()
  227. cv2.waitKey(1)
  228. return True
  229. except Exception as e:
  230. print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
  231. return False
  232. def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
  233. # Check file(s) for acceptable suffix
  234. if file and suffix:
  235. if isinstance(suffix, str):
  236. suffix = [suffix]
  237. for f in file if isinstance(file, (list, tuple)) else [file]:
  238. s = Path(f).suffix.lower() # file suffix
  239. if len(s):
  240. assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
  241. def check_yaml(file, suffix=('.yaml', '.yml')):
  242. # Search/download YAML file (if necessary) and return path, checking suffix
  243. return check_file(file, suffix)
  244. def check_file(file, suffix=''):
  245. # Search/download file (if necessary) and return path
  246. check_suffix(file, suffix) # optional
  247. file = str(file) # convert to str()
  248. if Path(file).is_file() or file == '': # exists
  249. return file
  250. elif file.startswith(('http:/', 'https:/')): # download
  251. url = str(Path(file)).replace(':/', '://') # Pathlib turns :// -> :/
  252. file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
  253. print(f'Downloading {url} to {file}...')
  254. torch.hub.download_url_to_file(url, file)
  255. assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
  256. return file
  257. else: # search
  258. files = []
  259. for d in 'data', 'models', 'utils': # search directories
  260. files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
  261. assert len(files), f'File not found: {file}' # assert file was found
  262. assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
  263. return files[0] # return file
  264. def check_dataset(data, autodownload=True):
  265. # Download and/or unzip dataset if not found locally
  266. # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip
  267. # Download (optional)
  268. extract_dir = ''
  269. if isinstance(data, (str, Path)) and str(data).endswith('.zip'): # i.e. gs://bucket/dir/coco128.zip
  270. download(data, dir='../datasets', unzip=True, delete=False, curl=False, threads=1)
  271. data = next((Path('../datasets') / Path(data).stem).rglob('*.yaml'))
  272. extract_dir, autodownload = data.parent, False
  273. # Read yaml (optional)
  274. if isinstance(data, (str, Path)):
  275. with open(data, errors='ignore') as f:
  276. data = yaml.safe_load(f) # dictionary
  277. # Parse yaml
  278. path = extract_dir or Path(data.get('path') or '') # optional 'path' default to '.'
  279. for k in 'train', 'val', 'test':
  280. if data.get(k): # prepend path
  281. data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]
  282. assert 'nc' in data, "Dataset 'nc' key missing."
  283. if 'names' not in data:
  284. data['names'] = [f'class{i}' for i in range(data['nc'])] # assign class names if missing
  285. train, val, test, s = [data.get(x) for x in ('train', 'val', 'test', 'download')]
  286. if val:
  287. val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
  288. if not all(x.exists() for x in val):
  289. print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
  290. if s and autodownload: # download script
  291. root = path.parent if 'path' in data else '..' # unzip directory i.e. '../'
  292. if s.startswith('http') and s.endswith('.zip'): # URL
  293. f = Path(s).name # filename
  294. print(f'Downloading {s} to {f}...')
  295. torch.hub.download_url_to_file(s, f)
  296. Path(root).mkdir(parents=True, exist_ok=True) # create root
  297. ZipFile(f).extractall(path=root) # unzip
  298. Path(f).unlink() # remove zip
  299. r = None # success
  300. elif s.startswith('bash '): # bash script
  301. print(f'Running {s} ...')
  302. r = os.system(s)
  303. else: # python script
  304. r = exec(s, {'yaml': data}) # return None
  305. print(f"Dataset autodownload {f'success, saved to {root}' if r in (0, None) else 'failure'}\n")
  306. else:
  307. raise Exception('Dataset not found.')
  308. return data # dictionary
  309. def url2file(url):
  310. # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
  311. url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
  312. file = Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
  313. return file
  314. def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
  315. # Multi-threaded file download and unzip function, used in data.yaml for autodownload
  316. def download_one(url, dir):
  317. # Download 1 file
  318. f = dir / Path(url).name # filename
  319. if Path(url).is_file(): # exists in current path
  320. Path(url).rename(f) # move to dir
  321. elif not f.exists():
  322. print(f'Downloading {url} to {f}...')
  323. if curl:
  324. os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -") # curl download, retry and resume on fail
  325. else:
  326. torch.hub.download_url_to_file(url, f, progress=True) # torch download
  327. if unzip and f.suffix in ('.zip', '.gz'):
  328. print(f'Unzipping {f}...')
  329. if f.suffix == '.zip':
  330. ZipFile(f).extractall(path=dir) # unzip
  331. elif f.suffix == '.gz':
  332. os.system(f'tar xfz {f} --directory {f.parent}') # unzip
  333. if delete:
  334. f.unlink() # remove zip
  335. dir = Path(dir)
  336. dir.mkdir(parents=True, exist_ok=True) # make directory
  337. if threads > 1:
  338. pool = ThreadPool(threads)
  339. pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
  340. pool.close()
  341. pool.join()
  342. else:
  343. for u in [url] if isinstance(url, (str, Path)) else url:
  344. download_one(u, dir)
  345. def make_divisible(x, divisor):
  346. # Returns x evenly divisible by divisor
  347. return math.ceil(x / divisor) * divisor
  348. def clean_str(s):
  349. # Cleans a string by replacing special characters with underscore _
  350. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  351. def one_cycle(y1=0.0, y2=1.0, steps=100):
  352. # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
  353. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  354. def colorstr(*input):
  355. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  356. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  357. colors = {'black': '\033[30m', # basic colors
  358. 'red': '\033[31m',
  359. 'green': '\033[32m',
  360. 'yellow': '\033[33m',
  361. 'blue': '\033[34m',
  362. 'magenta': '\033[35m',
  363. 'cyan': '\033[36m',
  364. 'white': '\033[37m',
  365. 'bright_black': '\033[90m', # bright colors
  366. 'bright_red': '\033[91m',
  367. 'bright_green': '\033[92m',
  368. 'bright_yellow': '\033[93m',
  369. 'bright_blue': '\033[94m',
  370. 'bright_magenta': '\033[95m',
  371. 'bright_cyan': '\033[96m',
  372. 'bright_white': '\033[97m',
  373. 'end': '\033[0m', # misc
  374. 'bold': '\033[1m',
  375. 'underline': '\033[4m'}
  376. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  377. def labels_to_class_weights(labels, nc=80):
  378. # Get class weights (inverse frequency) from training labels
  379. if labels[0] is None: # no labels loaded
  380. return torch.Tensor()
  381. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  382. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  383. weights = np.bincount(classes, minlength=nc) # occurrences per class
  384. # Prepend gridpoint count (for uCE training)
  385. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  386. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  387. weights[weights == 0] = 1 # replace empty bins with 1
  388. weights = 1 / weights # number of targets per class
  389. weights /= weights.sum() # normalize
  390. return torch.from_numpy(weights)
  391. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  392. # Produces image weights based on class_weights and image contents
  393. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  394. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  395. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  396. return image_weights
  397. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  398. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  399. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  400. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  401. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  402. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  403. x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  404. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  405. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  406. return x
  407. def xyxy2xywh(x):
  408. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  409. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  410. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  411. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  412. y[:, 2] = x[:, 2] - x[:, 0] # width
  413. y[:, 3] = x[:, 3] - x[:, 1] # height
  414. return y
  415. def xywh2xyxy(x):
  416. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  417. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  418. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  419. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  420. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  421. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  422. return y
  423. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  424. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  425. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  426. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  427. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  428. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  429. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  430. return y
  431. def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
  432. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
  433. if clip:
  434. clip_coords(x, (h - eps, w - eps)) # warning: inplace clip
  435. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  436. y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
  437. y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
  438. y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
  439. y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
  440. return y
  441. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  442. # Convert normalized segments into pixel segments, shape (n,2)
  443. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  444. y[:, 0] = w * x[:, 0] + padw # top left x
  445. y[:, 1] = h * x[:, 1] + padh # top left y
  446. return y
  447. def segment2box(segment, width=640, height=640):
  448. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  449. x, y = segment.T # segment xy
  450. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  451. x, y, = x[inside], y[inside]
  452. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  453. def segments2boxes(segments):
  454. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  455. boxes = []
  456. for s in segments:
  457. x, y = s.T # segment xy
  458. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  459. return xyxy2xywh(np.array(boxes)) # cls, xywh
  460. def resample_segments(segments, n=1000):
  461. # Up-sample an (n,2) segment
  462. for i, s in enumerate(segments):
  463. x = np.linspace(0, len(s) - 1, n)
  464. xp = np.arange(len(s))
  465. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  466. return segments
  467. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  468. # Rescale coords (xyxy) from img1_shape to img0_shape
  469. if ratio_pad is None: # calculate from img0_shape
  470. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  471. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  472. else:
  473. gain = ratio_pad[0][0]
  474. pad = ratio_pad[1]
  475. coords[:, [0, 2]] -= pad[0] # x padding
  476. coords[:, [1, 3]] -= pad[1] # y padding
  477. coords[:, :4] /= gain
  478. clip_coords(coords, img0_shape)
  479. return coords
  480. def clip_coords(boxes, shape):
  481. # Clip bounding xyxy bounding boxes to image shape (height, width)
  482. if isinstance(boxes, torch.Tensor): # faster individually
  483. boxes[:, 0].clamp_(0, shape[1]) # x1
  484. boxes[:, 1].clamp_(0, shape[0]) # y1
  485. boxes[:, 2].clamp_(0, shape[1]) # x2
  486. boxes[:, 3].clamp_(0, shape[0]) # y2
  487. else: # np.array (faster grouped)
  488. boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
  489. boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Args:
        prediction: batched raw detections with prediction.shape[2] == 5 + nc; per box, columns 0-3 are
            (center x, center y, width, height), column 4 is obj_conf and columns 5+ are per-class cls_conf.
        conf_thres: candidates need obj_conf > conf_thres; kept detections need obj_conf * cls_conf > conf_thres.
        iou_thres: IoU threshold passed to torchvision.ops.nms().
        classes: optional iterable of class indices to keep; detections of other classes are dropped.
        agnostic: if True, boxes of different classes suppress each other (no per-class offset).
        multi_label: emit one detection per (box, class) pair above conf_thres; forced off when nc == 1.
        labels: optional per-image a-priori labels for autolabelling, rows [cls, box(4)] appended as
            confidence-1.0 candidates; box presumably in the same (center x, center y, w, h) format as
            prediction - TODO confirm at caller.
        max_det: maximum number of detections kept per image.

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
         Images not processed before the time limit is exceeded keep an empty (0, 6) tensor.
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates (obj_conf above threshold)

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    # NOTE: min_wh is only referenced by the commented-out width-height constraint below; max_wh is also
    # the per-class coordinate offset used for batched NMS
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections (only used by merge-NMS)
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    # Every slot initially shares one empty tensor; slots are reassigned (never mutated) below
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls (one-hot)
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            # i: box (row) indices, j: class indices of every class score above conf_thres
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS: shifting boxes by class * max_wh separates classes (assuming coordinates stay below
        # max_wh), so one nms() call only suppresses within a class; the offset is 0 when agnostic
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
  566. def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
  567. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  568. x = torch.load(f, map_location=torch.device('cpu'))
  569. if x.get('ema'):
  570. x['model'] = x['ema'] # replace model with ema
  571. for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates': # keys
  572. x[k] = None
  573. x['epoch'] = -1
  574. x['model'].half() # to FP16
  575. for p in x['model'].parameters():
  576. p.requires_grad = False
  577. torch.save(x, s or f)
  578. mb = os.path.getsize(s or f) / 1E6 # filesize
  579. print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
def print_mutation(results, hyp, save_dir, bucket):
    """Log one hyperparameter-evolution generation and write the evolution summary YAML.

    Appends `results` + hyperparameter values to save_dir/'evolve.csv' (writing a header row first when the
    file is new), prints them, and rewrites save_dir/'hyp_evolve.yaml' with a comment header describing the
    best row so far (np.argmax of fitness() over the first 7 CSV columns) followed by the current `hyp`.
    If `bucket` is set, evolve.csv is first pulled from gs://{bucket} when the remote copy is larger than the
    local one, and both files are uploaded at the end.

    Args:
        results: tuple of 7 metric values, presumably ordered like the first 7 `keys` below - TODO confirm at caller.
        hyp: dict of hyperparameters for this generation.
        save_dir: pathlib.Path output directory.
        bucket: Google Cloud Storage bucket name, or falsy to disable syncing.
    """
    # NOTE(review): results_csv is assigned but never used in this function
    evolve_csv, results_csv, evolve_yaml = save_dir / 'evolve.csv', save_dir / 'results.csv', save_dir / 'hyp_evolve.yaml'
    keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
            'val/box_loss', 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys())  # [results + hyps]
    keys = tuple(x.strip() for x in keys)
    vals = results + tuple(hyp.values())
    n = len(keys)

    # Download (optional)
    if bucket:
        url = f'gs://{bucket}/evolve.csv'
        if gsutil_getsize(url) > (os.path.getsize(evolve_csv) if os.path.exists(evolve_csv) else 0):
            os.system(f'gsutil cp {url} {save_dir}')  # download evolve.csv if larger than local

    # Log to evolve.csv (fixed-width 20-char columns)
    s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n')  # add header
    with open(evolve_csv, 'a') as f:
        f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')

    # Print to screen
    print(colorstr('evolve: ') + ', '.join(f'{x.strip():>20s}' for x in keys))
    print(colorstr('evolve: ') + ', '.join(f'{x:20.5g}' for x in vals), end='\n\n\n')

    # Save yaml
    with open(evolve_yaml, 'w') as f:
        data = pd.read_csv(evolve_csv)
        data = data.rename(columns=lambda x: x.strip())  # strip keys
        i = np.argmax(fitness(data.values[:, :7]))  # 0-based row index of the best generation so far
        # NOTE(review): the header reports the best row, but the YAML body below dumps the current `hyp`
        f.write('# YOLOv5 Hyperparameter Evolution Results\n' +
                f'# Best generation: {i}\n' +
                f'# Last generation: {len(data)}\n' +
                '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) + '\n' +
                '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
        yaml.safe_dump(hyp, f, sort_keys=False)

    if bucket:
        os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}')  # upload
  612. def apply_classifier(x, model, img, im0):
  613. # Apply a second stage classifier to yolo outputs
  614. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  615. for i, d in enumerate(x): # per image
  616. if d is not None and len(d):
  617. d = d.clone()
  618. # Reshape and pad cutouts
  619. b = xyxy2xywh(d[:, :4]) # boxes
  620. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  621. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  622. d[:, :4] = xywh2xyxy(b).long()
  623. # Rescale boxes from img_size to im0 size
  624. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  625. # Classes
  626. pred_cls1 = d[:, 5].long()
  627. ims = []
  628. for j, a in enumerate(d): # per item
  629. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  630. im = cv2.resize(cutout, (224, 224)) # BGR
  631. # cv2.imwrite('example%i.jpg' % j, cutout)
  632. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  633. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  634. im /= 255.0 # 0 - 255 to 0.0 - 1.0
  635. ims.append(im)
  636. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  637. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  638. return x
  639. def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):
  640. # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
  641. xyxy = torch.tensor(xyxy).view(-1, 4)
  642. b = xyxy2xywh(xyxy) # boxes
  643. if square:
  644. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
  645. b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
  646. xyxy = xywh2xyxy(b).long()
  647. clip_coords(xyxy, im.shape)
  648. crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
  649. if save:
  650. cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop)
  651. return crop
  652. def increment_path(path, exist_ok=False, sep='', mkdir=False):
  653. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  654. path = Path(path) # os-agnostic
  655. if path.exists() and not exist_ok:
  656. suffix = path.suffix
  657. path = path.with_suffix('')
  658. dirs = glob.glob(f"{path}{sep}*") # similar paths
  659. matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
  660. i = [int(m.groups()[0]) for m in matches if m] # indices
  661. n = max(i) + 1 if i else 2 # increment number
  662. path = Path(f"{path}{sep}{n}{suffix}") # update path
  663. dir = path if path.suffix == '' else path.parent # directory
  664. if not dir.exists() and mkdir:
  665. dir.mkdir(parents=True, exist_ok=True) # make directory
  666. return path