You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

842 lines
34KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. General utils
  4. """
  5. import contextlib
  6. import glob
  7. import logging
  8. import math
  9. import os
  10. import platform
  11. import random
  12. import re
  13. import shutil
  14. import signal
  15. import time
  16. import urllib
  17. from itertools import repeat
  18. from multiprocessing.pool import ThreadPool
  19. from pathlib import Path
  20. from subprocess import check_output
  21. from zipfile import ZipFile
  22. import cv2
  23. import numpy as np
  24. import pandas as pd
  25. import pkg_resources as pkg
  26. import torch
  27. import torchvision
  28. import yaml
  29. from utils.downloads import gsutil_getsize
  30. from utils.metrics import box_iou, fitness
  31. # Settings
  32. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  33. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  34. pd.options.display.max_columns = 10
  35. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  36. os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
  37. FILE = Path(__file__).resolve()
  38. ROOT = FILE.parents[1] # YOLOv5 root directory
  39. def set_logging(name=None, verbose=True):
  40. # Sets level and returns logger
  41. rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
  42. logging.basicConfig(format="%(message)s", level=logging.INFO if (verbose and rank in (-1, 0)) else logging.WARNING)
  43. return logging.getLogger(name)
  44. LOGGER = set_logging(__name__) # define globally (used in train.py, val.py, detect.py, etc.)
  45. class Profile(contextlib.ContextDecorator):
  46. # Usage: @Profile() decorator or 'with Profile():' context manager
  47. def __enter__(self):
  48. self.start = time.time()
  49. def __exit__(self, type, value, traceback):
  50. print(f'Profile results: {time.time() - self.start:.5f}s')
  51. class Timeout(contextlib.ContextDecorator):
  52. # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
  53. def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
  54. self.seconds = int(seconds)
  55. self.timeout_message = timeout_msg
  56. self.suppress = bool(suppress_timeout_errors)
  57. def _timeout_handler(self, signum, frame):
  58. raise TimeoutError(self.timeout_message)
  59. def __enter__(self):
  60. signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
  61. signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
  62. def __exit__(self, exc_type, exc_val, exc_tb):
  63. signal.alarm(0) # Cancel SIGALRM if it's scheduled
  64. if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
  65. return True
  66. class WorkingDirectory(contextlib.ContextDecorator):
  67. # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
  68. def __init__(self, new_dir):
  69. self.dir = new_dir # new dir
  70. self.cwd = Path.cwd().resolve() # current dir
  71. def __enter__(self):
  72. os.chdir(self.dir)
  73. def __exit__(self, exc_type, exc_val, exc_tb):
  74. os.chdir(self.cwd)
  75. def try_except(func):
  76. # try-except function. Usage: @try_except decorator
  77. def handler(*args, **kwargs):
  78. try:
  79. func(*args, **kwargs)
  80. except Exception as e:
  81. print(e)
  82. return handler
  83. def methods(instance):
  84. # Get class/instance methods
  85. return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
  86. def print_args(name, opt):
  87. # Print argparser arguments
  88. LOGGER.info(colorstr(f'{name}: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
  89. def init_seeds(seed=0):
  90. # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
  91. # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
  92. import torch.backends.cudnn as cudnn
  93. random.seed(seed)
  94. np.random.seed(seed)
  95. torch.manual_seed(seed)
  96. cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
  97. def intersect_dicts(da, db, exclude=()):
  98. # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
  99. return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
  100. def get_latest_run(search_dir='.'):
  101. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  102. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  103. return max(last_list, key=os.path.getctime) if last_list else ''
  104. def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
  105. # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
  106. env = os.getenv(env_var)
  107. if env:
  108. path = Path(env) # use environment variable
  109. else:
  110. cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs
  111. path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir
  112. path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable
  113. path.mkdir(exist_ok=True) # make if required
  114. return path
  115. def is_writeable(dir, test=False):
  116. # Return True if directory has write permissions, test opening a file with write permissions if test=True
  117. if test: # method 1
  118. file = Path(dir) / 'tmp.txt'
  119. try:
  120. with open(file, 'w'): # open file with write permissions
  121. pass
  122. file.unlink() # remove file
  123. return True
  124. except OSError:
  125. return False
  126. else: # method 2
  127. return os.access(dir, os.R_OK) # possible issues on Windows
  128. def is_docker():
  129. # Is environment a Docker container?
  130. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  131. def is_colab():
  132. # Is environment a Google Colab instance?
  133. try:
  134. import google.colab
  135. return True
  136. except ImportError:
  137. return False
  138. def is_pip():
  139. # Is file in a pip package?
  140. return 'site-packages' in Path(__file__).resolve().parts
  141. def is_ascii(s=''):
  142. # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
  143. s = str(s) # convert list, tuple, None, etc. to str
  144. return len(s.encode().decode('ascii', 'ignore')) == len(s)
  145. def is_chinese(s='人工智能'):
  146. # Is string composed of any Chinese characters?
  147. return re.search('[\u4e00-\u9fff]', s)
  148. def emojis(str=''):
  149. # Return platform-dependent emoji-safe version of string
  150. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  151. def file_size(path):
  152. # Return file/dir size (MB)
  153. path = Path(path)
  154. if path.is_file():
  155. return path.stat().st_size / 1E6
  156. elif path.is_dir():
  157. return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / 1E6
  158. else:
  159. return 0.0
  160. def check_online():
  161. # Check internet connectivity
  162. import socket
  163. try:
  164. socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
  165. return True
  166. except OSError:
  167. return False
  168. @try_except
  169. @WorkingDirectory(ROOT)
  170. def check_git_status():
  171. # Recommend 'git pull' if code is out of date
  172. msg = ', for updates see https://github.com/ultralytics/yolov5'
  173. print(colorstr('github: '), end='')
  174. assert Path('.git').exists(), 'skipping check (not a git repository)' + msg
  175. assert not is_docker(), 'skipping check (Docker image)' + msg
  176. assert check_online(), 'skipping check (offline)' + msg
  177. cmd = 'git fetch && git config --get remote.origin.url'
  178. url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
  179. branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  180. n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  181. if n > 0:
  182. s = f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
  183. else:
  184. s = f'up to date with {url} ✅'
  185. print(emojis(s)) # emoji-safe
  186. def check_python(minimum='3.6.2'):
  187. # Check current python version vs. required python version
  188. check_version(platform.python_version(), minimum, name='Python ', hard=True)
  189. def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False):
  190. # Check version vs. required version
  191. current, minimum = (pkg.parse_version(x) for x in (current, minimum))
  192. result = (current == minimum) if pinned else (current >= minimum) # bool
  193. if hard: # assert min requirements met
  194. assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
  195. else:
  196. return result
  197. @try_except
  198. def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True):
  199. # Check installed dependencies meet requirements (pass *.txt file or list of packages)
  200. prefix = colorstr('red', 'bold', 'requirements:')
  201. check_python() # check python version
  202. if isinstance(requirements, (str, Path)): # requirements.txt file
  203. file = Path(requirements)
  204. assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
  205. with file.open() as f:
  206. requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
  207. else: # list or tuple of packages
  208. requirements = [x for x in requirements if x not in exclude]
  209. n = 0 # number of packages updates
  210. for r in requirements:
  211. try:
  212. pkg.require(r)
  213. except Exception as e: # DistributionNotFound or VersionConflict if requirements not met
  214. s = f"{prefix} {r} not found and is required by YOLOv5"
  215. if install:
  216. print(f"{s}, attempting auto-update...")
  217. try:
  218. assert check_online(), f"'pip install {r}' skipped (offline)"
  219. print(check_output(f"pip install '{r}'", shell=True).decode())
  220. n += 1
  221. except Exception as e:
  222. print(f'{prefix} {e}')
  223. else:
  224. print(f'{s}. Please install and rerun your command.')
  225. if n: # if packages updated
  226. source = file.resolve() if 'file' in locals() else requirements
  227. s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
  228. f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
  229. print(emojis(s))
  230. def check_img_size(imgsz, s=32, floor=0):
  231. # Verify image size is a multiple of stride s in each dimension
  232. if isinstance(imgsz, int): # integer i.e. img_size=640
  233. new_size = max(make_divisible(imgsz, int(s)), floor)
  234. else: # list i.e. img_size=[640, 480]
  235. new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
  236. if new_size != imgsz:
  237. print(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
  238. return new_size
  239. def check_imshow():
  240. # Check if environment supports image displays
  241. try:
  242. assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
  243. assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
  244. cv2.imshow('test', np.zeros((1, 1, 3)))
  245. cv2.waitKey(1)
  246. cv2.destroyAllWindows()
  247. cv2.waitKey(1)
  248. return True
  249. except Exception as e:
  250. print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
  251. return False
  252. def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
  253. # Check file(s) for acceptable suffix
  254. if file and suffix:
  255. if isinstance(suffix, str):
  256. suffix = [suffix]
  257. for f in file if isinstance(file, (list, tuple)) else [file]:
  258. s = Path(f).suffix.lower() # file suffix
  259. if len(s):
  260. assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
  261. def check_yaml(file, suffix=('.yaml', '.yml')):
  262. # Search/download YAML file (if necessary) and return path, checking suffix
  263. return check_file(file, suffix)
  264. def check_file(file, suffix=''):
  265. # Search/download file (if necessary) and return path
  266. check_suffix(file, suffix) # optional
  267. file = str(file) # convert to str()
  268. if Path(file).is_file() or file == '': # exists
  269. return file
  270. elif file.startswith(('http:/', 'https:/')): # download
  271. url = str(Path(file)).replace(':/', '://') # Pathlib turns :// -> :/
  272. file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
  273. if Path(file).is_file():
  274. print(f'Found {url} locally at {file}') # file already exists
  275. else:
  276. print(f'Downloading {url} to {file}...')
  277. torch.hub.download_url_to_file(url, file)
  278. assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
  279. return file
  280. else: # search
  281. files = []
  282. for d in 'data', 'models', 'utils': # search directories
  283. files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
  284. assert len(files), f'File not found: {file}' # assert file was found
  285. assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
  286. return files[0] # return file
def check_dataset(data, autodownload=True):
    """Resolve a dataset spec into a config dictionary, downloading/unzipping the data if it is missing.

    Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip

    Args:
        data: a dict, a path to a dataset YAML file, or a path/URL ending in '.zip'.
            A zip is downloaded into '../datasets' and the first '*.yaml' found
            in the extracted folder is used.
        autodownload: when any 'val' path is missing, run the config's 'download' entry.

    Returns:
        dict: the config. 'train'/'val'/'test' are prefixed with 'path' (or with
        the zip extraction dir), and 'names' defaults to ['class0', ...] when absent.

    Raises:
        AssertionError: if the config has no 'nc' key.
        Exception: if 'val' paths are missing and no download is attempted.
    """
    # Download (optional)
    extract_dir = ''
    if isinstance(data, (str, Path)) and str(data).endswith('.zip'):  # i.e. gs://bucket/dir/coco128.zip
        download(data, dir='../datasets', unzip=True, delete=False, curl=False, threads=1)
        # first YAML inside the extracted folder; next() raises StopIteration if there is none
        data = next((Path('../datasets') / Path(data).stem).rglob('*.yaml'))
        extract_dir, autodownload = data.parent, False  # never re-download a dataset we just extracted

    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        with open(data, errors='ignore') as f:
            data = yaml.safe_load(f)  # dictionary

    # Parse yaml
    path = extract_dir or Path(data.get('path') or '')  # optional 'path' default to '.'
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]

    assert 'nc' in data, "Dataset 'nc' key missing."
    if 'names' not in data:
        data['names'] = [f'class{i}' for i in range(data['nc'])]  # assign class names if missing
    train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and autodownload:  # download script
                root = path.parent if 'path' in data else '..'  # unzip directory i.e. '../'
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    print(f'Downloading {s} to {f}...')
                    torch.hub.download_url_to_file(s, f)
                    Path(root).mkdir(parents=True, exist_ok=True)  # create root
                    ZipFile(f).extractall(path=root)  # unzip
                    Path(f).unlink()  # remove zip
                    r = None  # success
                elif s.startswith('bash '):  # bash script
                    print(f'Running {s} ...')
                    r = os.system(s)  # shell exit status; 0 means success
                else:  # python script
                    # SECURITY: executes arbitrary Python taken from the dataset YAML; only use trusted dataset files
                    r = exec(s, {'yaml': data})  # return None
                print(f"Dataset autodownload {f'success, saved to {root}' if r in (0, None) else 'failure'}\n")
            else:
                raise Exception('Dataset not found.')

    return data  # dictionary
  332. def url2file(url):
  333. # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
  334. url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
  335. file = Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
  336. return file
  337. def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
  338. # Multi-threaded file download and unzip function, used in data.yaml for autodownload
  339. def download_one(url, dir):
  340. # Download 1 file
  341. f = dir / Path(url).name # filename
  342. if Path(url).is_file(): # exists in current path
  343. Path(url).rename(f) # move to dir
  344. elif not f.exists():
  345. print(f'Downloading {url} to {f}...')
  346. if curl:
  347. os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -") # curl download, retry and resume on fail
  348. else:
  349. torch.hub.download_url_to_file(url, f, progress=True) # torch download
  350. if unzip and f.suffix in ('.zip', '.gz'):
  351. print(f'Unzipping {f}...')
  352. if f.suffix == '.zip':
  353. ZipFile(f).extractall(path=dir) # unzip
  354. elif f.suffix == '.gz':
  355. os.system(f'tar xfz {f} --directory {f.parent}') # unzip
  356. if delete:
  357. f.unlink() # remove zip
  358. dir = Path(dir)
  359. dir.mkdir(parents=True, exist_ok=True) # make directory
  360. if threads > 1:
  361. pool = ThreadPool(threads)
  362. pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
  363. pool.close()
  364. pool.join()
  365. else:
  366. for u in [url] if isinstance(url, (str, Path)) else url:
  367. download_one(u, dir)
  368. def make_divisible(x, divisor):
  369. # Returns x evenly divisible by divisor
  370. return math.ceil(x / divisor) * divisor
  371. def clean_str(s):
  372. # Cleans a string by replacing special characters with underscore _
  373. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  374. def one_cycle(y1=0.0, y2=1.0, steps=100):
  375. # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
  376. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  377. def colorstr(*input):
  378. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  379. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  380. colors = {'black': '\033[30m', # basic colors
  381. 'red': '\033[31m',
  382. 'green': '\033[32m',
  383. 'yellow': '\033[33m',
  384. 'blue': '\033[34m',
  385. 'magenta': '\033[35m',
  386. 'cyan': '\033[36m',
  387. 'white': '\033[37m',
  388. 'bright_black': '\033[90m', # bright colors
  389. 'bright_red': '\033[91m',
  390. 'bright_green': '\033[92m',
  391. 'bright_yellow': '\033[93m',
  392. 'bright_blue': '\033[94m',
  393. 'bright_magenta': '\033[95m',
  394. 'bright_cyan': '\033[96m',
  395. 'bright_white': '\033[97m',
  396. 'end': '\033[0m', # misc
  397. 'bold': '\033[1m',
  398. 'underline': '\033[4m'}
  399. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  400. def labels_to_class_weights(labels, nc=80):
  401. # Get class weights (inverse frequency) from training labels
  402. if labels[0] is None: # no labels loaded
  403. return torch.Tensor()
  404. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  405. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  406. weights = np.bincount(classes, minlength=nc) # occurrences per class
  407. # Prepend gridpoint count (for uCE training)
  408. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  409. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  410. weights[weights == 0] = 1 # replace empty bins with 1
  411. weights = 1 / weights # number of targets per class
  412. weights /= weights.sum() # normalize
  413. return torch.from_numpy(weights)
  414. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  415. # Produces image weights based on class_weights and image contents
  416. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  417. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  418. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  419. return image_weights
  420. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  421. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  422. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  423. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  424. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  425. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  426. x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  427. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  428. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  429. return x
  430. def xyxy2xywh(x):
  431. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  432. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  433. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  434. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  435. y[:, 2] = x[:, 2] - x[:, 0] # width
  436. y[:, 3] = x[:, 3] - x[:, 1] # height
  437. return y
  438. def xywh2xyxy(x):
  439. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  440. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  441. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  442. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  443. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  444. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  445. return y
  446. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  447. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  448. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  449. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  450. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  451. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  452. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  453. return y
  454. def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
  455. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
  456. if clip:
  457. clip_coords(x, (h - eps, w - eps)) # warning: inplace clip
  458. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  459. y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
  460. y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
  461. y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
  462. y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
  463. return y
  464. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  465. # Convert normalized segments into pixel segments, shape (n,2)
  466. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  467. y[:, 0] = w * x[:, 0] + padw # top left x
  468. y[:, 1] = h * x[:, 1] + padh # top left y
  469. return y
  470. def segment2box(segment, width=640, height=640):
  471. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  472. x, y = segment.T # segment xy
  473. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  474. x, y, = x[inside], y[inside]
  475. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  476. def segments2boxes(segments):
  477. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  478. boxes = []
  479. for s in segments:
  480. x, y = s.T # segment xy
  481. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  482. return xyxy2xywh(np.array(boxes)) # cls, xywh
  483. def resample_segments(segments, n=1000):
  484. # Up-sample an (n,2) segment
  485. for i, s in enumerate(segments):
  486. x = np.linspace(0, len(s) - 1, n)
  487. xp = np.arange(len(s))
  488. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  489. return segments
  490. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  491. # Rescale coords (xyxy) from img1_shape to img0_shape
  492. if ratio_pad is None: # calculate from img0_shape
  493. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  494. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  495. else:
  496. gain = ratio_pad[0][0]
  497. pad = ratio_pad[1]
  498. coords[:, [0, 2]] -= pad[0] # x padding
  499. coords[:, [1, 3]] -= pad[1] # y padding
  500. coords[:, :4] /= gain
  501. clip_coords(coords, img0_shape)
  502. return coords
  503. def clip_coords(boxes, shape):
  504. # Clip bounding xyxy bounding boxes to image shape (height, width)
  505. if isinstance(boxes, torch.Tensor): # faster individually
  506. boxes[:, 0].clamp_(0, shape[1]) # x1
  507. boxes[:, 1].clamp_(0, shape[0]) # y1
  508. boxes[:, 2].clamp_(0, shape[1]) # x2
  509. boxes[:, 3].clamp_(0, shape[0]) # y2
  510. else: # np.array (faster grouped)
  511. boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
  512. boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Args:
        prediction: model output indexed as prediction[image, candidate, field], with fields
            [x_center, y_center, w, h, obj_conf, cls_conf_0 ... cls_conf_{nc-1}];
            nc is inferred as prediction.shape[2] - 5.
        conf_thres: minimum obj_conf, and then minimum obj_conf * cls_conf, for a candidate to be kept.
        iou_thres: IoU threshold passed to torchvision.ops.nms.
        classes: optional list of class indices to keep.
        agnostic: if True, boxes of different classes may suppress each other.
        multi_label: allow several class labels per box (only honoured when nc > 1).
        labels: optional per-image apriori labels [cls, box...] appended as conf-1.0 candidates
            (autolabelling); the box presumably uses the same center-xywh format as prediction.
        max_det: maximum number of detections kept per image.

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height (min_wh only used in commented code)
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    # the same empty tensor object is repeated, which is safe because entries are only ever replaced, never mutated
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence (boolean indexing copies, so later in-place edits do not touch prediction)

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            # one row per (box, class) pair whose combined confidence exceeds conf_thres
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        # shifting each class's boxes by cls * max_wh pixels keeps classes from overlapping,
        # so one nms() call behaves like per-class NMS (offset is 0 when agnostic)
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded; images not yet processed keep their empty (0, 6) result

    return output
  589. def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
  590. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  591. x = torch.load(f, map_location=torch.device('cpu'))
  592. if x.get('ema'):
  593. x['model'] = x['ema'] # replace model with ema
  594. for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates': # keys
  595. x[k] = None
  596. x['epoch'] = -1
  597. x['model'].half() # to FP16
  598. for p in x['model'].parameters():
  599. p.requires_grad = False
  600. torch.save(x, s or f)
  601. mb = os.path.getsize(s or f) / 1E6 # filesize
  602. print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
  603. def print_mutation(results, hyp, save_dir, bucket):
  604. evolve_csv, results_csv, evolve_yaml = save_dir / 'evolve.csv', save_dir / 'results.csv', save_dir / 'hyp_evolve.yaml'
  605. keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
  606. 'val/box_loss', 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys()) # [results + hyps]
  607. keys = tuple(x.strip() for x in keys)
  608. vals = results + tuple(hyp.values())
  609. n = len(keys)
  610. # Download (optional)
  611. if bucket:
  612. url = f'gs://{bucket}/evolve.csv'
  613. if gsutil_getsize(url) > (os.path.getsize(evolve_csv) if os.path.exists(evolve_csv) else 0):
  614. os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local
  615. # Log to evolve.csv
  616. s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header
  617. with open(evolve_csv, 'a') as f:
  618. f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')
  619. # Print to screen
  620. print(colorstr('evolve: ') + ', '.join(f'{x.strip():>20s}' for x in keys))
  621. print(colorstr('evolve: ') + ', '.join(f'{x:20.5g}' for x in vals), end='\n\n\n')
  622. # Save yaml
  623. with open(evolve_yaml, 'w') as f:
  624. data = pd.read_csv(evolve_csv)
  625. data = data.rename(columns=lambda x: x.strip()) # strip keys
  626. i = np.argmax(fitness(data.values[:, :7])) #
  627. f.write('# YOLOv5 Hyperparameter Evolution Results\n' +
  628. f'# Best generation: {i}\n' +
  629. f'# Last generation: {len(data)}\n' +
  630. '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) + '\n' +
  631. '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
  632. yaml.safe_dump(hyp, f, sort_keys=False)
  633. if bucket:
  634. os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload
  635. def apply_classifier(x, model, img, im0):
  636. # Apply a second stage classifier to YOLO outputs
  637. # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
  638. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  639. for i, d in enumerate(x): # per image
  640. if d is not None and len(d):
  641. d = d.clone()
  642. # Reshape and pad cutouts
  643. b = xyxy2xywh(d[:, :4]) # boxes
  644. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  645. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  646. d[:, :4] = xywh2xyxy(b).long()
  647. # Rescale boxes from img_size to im0 size
  648. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  649. # Classes
  650. pred_cls1 = d[:, 5].long()
  651. ims = []
  652. for j, a in enumerate(d): # per item
  653. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  654. im = cv2.resize(cutout, (224, 224)) # BGR
  655. # cv2.imwrite('example%i.jpg' % j, cutout)
  656. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  657. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  658. im /= 255 # 0 - 255 to 0.0 - 1.0
  659. ims.append(im)
  660. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  661. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  662. return x
  663. def increment_path(path, exist_ok=False, sep='', mkdir=False):
  664. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  665. path = Path(path) # os-agnostic
  666. if path.exists() and not exist_ok:
  667. path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
  668. dirs = glob.glob(f"{path}{sep}*") # similar paths
  669. matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
  670. i = [int(m.groups()[0]) for m in matches if m] # indices
  671. n = max(i) + 1 if i else 2 # increment number
  672. path = Path(f"{path}{sep}{n}{suffix}") # increment path
  673. if mkdir:
  674. path.mkdir(parents=True, exist_ok=True) # make directory
  675. return path
  676. # Variables
  677. NCOLS = 0 if is_docker() else shutil.get_terminal_size().columns # terminal window size