# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
General utils
"""

import contextlib
import glob
import logging
import math
import os
import platform
import random
import re
import shutil
import signal
import time
import urllib
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from subprocess import check_output
from zipfile import ZipFile

import cv2
import numpy as np
import pandas as pd
import pkg_resources as pkg
import torch
import torchvision
import yaml

from utils.downloads import gsutil_getsize
from utils.metrics import box_iou, fitness

# Settings
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLOv5 multiprocessing threads

torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS)  # NumExpr max threads

def is_kaggle():
    # Is environment a Kaggle Notebook?
    try:
        assert os.environ.get('PWD') == '/kaggle/working'
        assert os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
        return True
    except AssertionError:
        return False


def set_logging(name=None, verbose=True):
    # Sets level and returns logger
    if is_kaggle():
        for h in logging.root.handlers:
            logging.root.removeHandler(h)  # remove all handlers associated with the root logger object
    rank = int(os.getenv('RANK', -1))  # rank in world for Multi-GPU trainings
    logging.basicConfig(format="%(message)s", level=logging.INFO if (verbose and rank in (-1, 0)) else logging.WARNING)
    return logging.getLogger(name)


LOGGER = set_logging(__name__)  # define globally (used in train.py, val.py, detect.py, etc.)

class Profile(contextlib.ContextDecorator):
    # Usage: @Profile() decorator or 'with Profile():' context manager
    def __enter__(self):
        self.start = time.time()

    def __exit__(self, type, value, traceback):
        print(f'Profile results: {time.time() - self.start:.5f}s')


class Timeout(contextlib.ContextDecorator):
    # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
    def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
        self.seconds = int(seconds)
        self.timeout_message = timeout_msg
        self.suppress = bool(suppress_timeout_errors)

    def _timeout_handler(self, signum, frame):
        raise TimeoutError(self.timeout_message)

    def __enter__(self):
        signal.signal(signal.SIGALRM, self._timeout_handler)  # Set handler for SIGALRM
        signal.alarm(self.seconds)  # start countdown for SIGALRM to be raised

    def __exit__(self, exc_type, exc_val, exc_tb):
        signal.alarm(0)  # Cancel SIGALRM if it's scheduled
        if self.suppress and exc_type is TimeoutError:  # Suppress TimeoutError
            return True
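

# Example (minimal sketch): abandoning a potentially slow call after 5 seconds. SIGALRM is Unix-only,
# and `slow_download()` below is a hypothetical placeholder, not a function from this module.
#
#   with Timeout(5, timeout_msg='download timed out'):
#       slow_download()  # TimeoutError raised after 5s is suppressed by default
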
class WorkingDirectory(contextlib.ContextDecorator):
    # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
    def __init__(self, new_dir):
        self.dir = new_dir  # new dir
        self.cwd = Path.cwd().resolve()  # current dir

    def __enter__(self):
        os.chdir(self.dir)

    def __exit__(self, exc_type, exc_val, exc_tb):
        os.chdir(self.cwd)


def try_except(func):
    # try-except function. Usage: @try_except decorator
    def handler(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception as e:
            print(e)

    return handler


def methods(instance):
    # Get class/instance methods
    return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]


def print_args(name, opt):
    # Print argparser arguments
    LOGGER.info(colorstr(f'{name}: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))


def init_seeds(seed=0):
    # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
    # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
    import torch.backends.cudnn as cudnn
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)


def intersect_dicts(da, db, exclude=()):
    # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
    return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
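

# Example (minimal sketch): keeping only checkpoint weights whose shapes match the current model when
# transferring pretrained weights. `model` is a placeholder for an instantiated nn.Module.
#
#   csd = torch.load('yolov5s.pt', map_location='cpu')['model'].float().state_dict()  # checkpoint state_dict
#   csd = intersect_dicts(csd, model.state_dict(), exclude=['anchor'])  # drop mismatched or excluded keys
#   model.load_state_dict(csd, strict=False)
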
def get_latest_run(search_dir='.'):
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''


def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
    # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
    env = os.getenv(env_var)
    if env:
        path = Path(env)  # use environment variable
    else:
        cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'}  # 3 OS dirs
        path = Path.home() / cfg.get(platform.system(), '')  # OS-specific config dir
    path = (path if is_writeable(path) else Path('/tmp')) / dir  # GCP and AWS lambda fix, only /tmp is writeable
    path.mkdir(exist_ok=True)  # make if required
    return path


def is_writeable(dir, test=False):
    # Return True if directory has write permissions, test opening a file with write permissions if test=True
    if test:  # method 1
        file = Path(dir) / 'tmp.txt'
        try:
            with open(file, 'w'):  # open file with write permissions
                pass
            file.unlink()  # remove file
            return True
        except OSError:
            return False
    else:  # method 2
        return os.access(dir, os.R_OK)  # possible issues on Windows


def is_docker():
    # Is environment a Docker container?
    return Path('/workspace').exists()  # or Path('/.dockerenv').exists()


def is_colab():
    # Is environment a Google Colab instance?
    try:
        import google.colab
        return True
    except ImportError:
        return False


def is_pip():
    # Is file in a pip package?
    return 'site-packages' in Path(__file__).resolve().parts


def is_ascii(s=''):
    # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
    s = str(s)  # convert list, tuple, None, etc. to str
    return len(s.encode().decode('ascii', 'ignore')) == len(s)


def is_chinese(s='人工智能'):
    # Is string composed of any Chinese characters?
    return re.search('[\u4e00-\u9fff]', s)


def emojis(str=''):
    # Return platform-dependent emoji-safe version of string
    return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str


def file_size(path):
    # Return file/dir size (MB)
    path = Path(path)
    if path.is_file():
        return path.stat().st_size / 1E6
    elif path.is_dir():
        return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / 1E6
    else:
        return 0.0

def check_online():
    # Check internet connectivity
    import socket
    try:
        socket.create_connection(("1.1.1.1", 443), 5)  # check host accessibility
        return True
    except OSError:
        return False


@try_except
@WorkingDirectory(ROOT)
def check_git_status():
    # Recommend 'git pull' if code is out of date
    msg = ', for updates see https://github.com/ultralytics/yolov5'
    print(colorstr('github: '), end='')
    assert Path('.git').exists(), 'skipping check (not a git repository)' + msg
    assert not is_docker(), 'skipping check (Docker image)' + msg
    assert check_online(), 'skipping check (offline)' + msg

    cmd = 'git fetch && git config --get remote.origin.url'
    url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git')  # git fetch
    branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
    n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True))  # commits behind
    if n > 0:
        s = f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
    else:
        s = f'up to date with {url} ✅'
    print(emojis(s))  # emoji-safe


def check_python(minimum='3.6.2'):
    # Check current python version vs. required python version
    check_version(platform.python_version(), minimum, name='Python ', hard=True)


def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
    # Check version vs. required version
    current, minimum = (pkg.parse_version(x) for x in (current, minimum))
    result = (current == minimum) if pinned else (current >= minimum)  # bool
    s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'  # string
    if hard:
        assert result, s  # assert min requirements met
    if verbose and not result:
        LOGGER.warning(s)
    return result
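

# Example (minimal sketch): gating an optional code path on the installed torch version.
#
#   if check_version(torch.__version__, '1.10.0'):  # True if torch >= 1.10.0
#       ...  # use a feature that needs torch 1.10+
#   check_version(torch.__version__, '1.7.0', name='torch ', hard=True)  # AssertionError if requirement unmet
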
@try_except
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True):
    # Check installed dependencies meet requirements (pass *.txt file or list of packages)
    prefix = colorstr('red', 'bold', 'requirements:')
    check_python()  # check python version
    if isinstance(requirements, (str, Path)):  # requirements.txt file
        file = Path(requirements)
        assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
        with file.open() as f:
            requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
    else:  # list or tuple of packages
        requirements = [x for x in requirements if x not in exclude]

    n = 0  # number of packages updates
    for r in requirements:
        try:
            pkg.require(r)
        except Exception as e:  # DistributionNotFound or VersionConflict if requirements not met
            s = f"{prefix} {r} not found and is required by YOLOv5"
            if install:
                print(f"{s}, attempting auto-update...")
                try:
                    assert check_online(), f"'pip install {r}' skipped (offline)"
                    print(check_output(f"pip install '{r}'", shell=True).decode())
                    n += 1
                except Exception as e:
                    print(f'{prefix} {e}')
            else:
                print(f'{s}. Please install and rerun your command.')

    if n:  # if packages updated
        source = file.resolve() if 'file' in locals() else requirements
        s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
        print(emojis(s))


def check_img_size(imgsz, s=32, floor=0):
    # Verify image size is a multiple of stride s in each dimension
    if isinstance(imgsz, int):  # integer i.e. img_size=640
        new_size = max(make_divisible(imgsz, int(s)), floor)
    else:  # list i.e. img_size=[640, 480]
        new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
    if new_size != imgsz:
        print(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
    return new_size


def check_imshow():
    # Check if environment supports image displays
    try:
        assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
        assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
        return False

def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
    # Check file(s) for acceptable suffix
    if file and suffix:
        if isinstance(suffix, str):
            suffix = [suffix]
        for f in file if isinstance(file, (list, tuple)) else [file]:
            s = Path(f).suffix.lower()  # file suffix
            if len(s):
                assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"


def check_yaml(file, suffix=('.yaml', '.yml')):
    # Search/download YAML file (if necessary) and return path, checking suffix
    return check_file(file, suffix)


def check_file(file, suffix=''):
    # Search/download file (if necessary) and return path
    check_suffix(file, suffix)  # optional
    file = str(file)  # convert to str()
    if Path(file).is_file() or file == '':  # exists
        return file
    elif file.startswith(('http:/', 'https:/')):  # download
        url = str(Path(file)).replace(':/', '://')  # Pathlib turns :// -> :/
        file = Path(urllib.parse.unquote(file).split('?')[0]).name  # '%2F' to '/', split https://url.com/file.txt?auth
        if Path(file).is_file():
            print(f'Found {url} locally at {file}')  # file already exists
        else:
            print(f'Downloading {url} to {file}...')
            torch.hub.download_url_to_file(url, file)
            assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}'  # check
        return file
    else:  # search
        files = []
        for d in 'data', 'models', 'utils':  # search directories
            files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True))  # find file
        assert len(files), f'File not found: {file}'  # assert file was found
        assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique
        return files[0]  # return file

def check_dataset(data, autodownload=True):
    # Download and/or unzip dataset if not found locally
    # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip

    # Download (optional)
    extract_dir = ''
    if isinstance(data, (str, Path)) and str(data).endswith('.zip'):  # i.e. gs://bucket/dir/coco128.zip
        download(data, dir='../datasets', unzip=True, delete=False, curl=False, threads=1)
        data = next((Path('../datasets') / Path(data).stem).rglob('*.yaml'))
        extract_dir, autodownload = data.parent, False

    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        with open(data, errors='ignore') as f:
            data = yaml.safe_load(f)  # dictionary

    # Parse yaml
    path = extract_dir or Path(data.get('path') or '')  # optional 'path' default to '.'
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]

    assert 'nc' in data, "Dataset 'nc' key missing."
    if 'names' not in data:
        data['names'] = [f'class{i}' for i in range(data['nc'])]  # assign class names if missing
    train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and autodownload:  # download script
                root = path.parent if 'path' in data else '..'  # unzip directory i.e. '../'
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    print(f'Downloading {s} to {f}...')
                    torch.hub.download_url_to_file(s, f)
                    Path(root).mkdir(parents=True, exist_ok=True)  # create root
                    ZipFile(f).extractall(path=root)  # unzip
                    Path(f).unlink()  # remove zip
                    r = None  # success
                elif s.startswith('bash '):  # bash script
                    print(f'Running {s} ...')
                    r = os.system(s)
                else:  # python script
                    r = exec(s, {'yaml': data})  # return None
                print(f"Dataset autodownload {f'success, saved to {root}' if r in (0, None) else 'failure'}\n")
            else:
                raise Exception('Dataset not found.')

    return data  # dictionary

def url2file(url):
    # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
    url = str(Path(url)).replace(':/', '://')  # Pathlib turns :// -> :/
    file = Path(urllib.parse.unquote(url)).name.split('?')[0]  # '%2F' to '/', split https://url.com/file.txt?auth
    return file


def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
    # Multi-threaded file download and unzip function, used in data.yaml for autodownload
    def download_one(url, dir):
        # Download 1 file
        f = dir / Path(url).name  # filename
        if Path(url).is_file():  # exists in current path
            Path(url).rename(f)  # move to dir
        elif not f.exists():
            print(f'Downloading {url} to {f}...')
            if curl:
                os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -")  # curl download, retry and resume on fail
            else:
                torch.hub.download_url_to_file(url, f, progress=True)  # torch download
        if unzip and f.suffix in ('.zip', '.gz'):
            print(f'Unzipping {f}...')
            if f.suffix == '.zip':
                ZipFile(f).extractall(path=dir)  # unzip
            elif f.suffix == '.gz':
                os.system(f'tar xfz {f} --directory {f.parent}')  # unzip
            if delete:
                f.unlink()  # remove zip

    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    if threads > 1:
        pool = ThreadPool(threads)
        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multi-threaded
        pool.close()
        pool.join()
    else:
        for u in [url] if isinstance(url, (str, Path)) else url:
            download_one(u, dir)
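

# Example (minimal sketch): fetching several archives in parallel and unzipping them into ../datasets.
# The URLs below are placeholders, not real release assets.
#
#   urls = ['https://example.com/part1.zip', 'https://example.com/part2.zip']
#   download(urls, dir='../datasets', unzip=True, delete=True, threads=2)
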
def make_divisible(x, divisor):
    # Return x rounded up to the nearest multiple of divisor
    if isinstance(divisor, torch.Tensor):
        divisor = int(divisor.max())  # to int
    return math.ceil(x / divisor) * divisor
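

# Example: rounding a requested image size up to the nearest multiple of the model stride.
#
#   make_divisible(641, 32)  # -> 672
#   make_divisible(640, 32)  # -> 640 (already divisible)
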
def clean_str(s):
    # Cleans a string by replacing special characters with underscore _
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)


def one_cycle(y1=0.0, y2=1.0, steps=100):
    # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
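

# Example (minimal sketch): using the cosine ramp as a per-epoch learning-rate lambda with a standard
# PyTorch scheduler. `optimizer`, `epochs` and `hyp['lrf']` are placeholders following train.py's naming.
#
#   lf = one_cycle(1, hyp['lrf'], epochs)  # ramps from 1.0 down to lrf over `epochs` steps
#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
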
def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {'black': '\033[30m',  # basic colors
              'red': '\033[31m',
              'green': '\033[32m',
              'yellow': '\033[33m',
              'blue': '\033[34m',
              'magenta': '\033[35m',
              'cyan': '\033[36m',
              'white': '\033[37m',
              'bright_black': '\033[90m',  # bright colors
              'bright_red': '\033[91m',
              'bright_green': '\033[92m',
              'bright_yellow': '\033[93m',
              'bright_blue': '\033[94m',
              'bright_magenta': '\033[95m',
              'bright_cyan': '\033[96m',
              'bright_white': '\033[97m',
              'end': '\033[0m',  # misc
              'bold': '\033[1m',
              'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
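

# Example: the last positional argument is the string; any preceding arguments are colour/style names.
#
#   print(colorstr('hello'))                    # 'hello' in blue bold (default styling)
#   print(colorstr('red', 'bold', 'warning:'))  # 'warning:' in red bold
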
def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(np.int)  # labels = [class xywh]
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # number of targets per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)


def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class_weights and image contents
    class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
    return image_weights


def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
         64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
    return x

def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y
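

# Example: the two conversions are inverses of each other on an (n, 4) array.
#
#   b = np.array([[10., 20., 50., 80.]])  # xyxy: top-left (10, 20), bottom-right (50, 80)
#   xyxy2xywh(b)                          # -> [[30., 50., 40., 60.]] (center x, center y, width, height)
#   xywh2xyxy(xyxy2xywh(b))               # -> [[10., 20., 50., 80.]] (round trip)
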
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # top left x
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh  # top left y
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw  # bottom right x
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh  # bottom right y
    return y


def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
    if clip:
        clip_coords(x, (h - eps, w - eps))  # warning: inplace clip
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w  # x center
    y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h  # y center
    y[:, 2] = (x[:, 2] - x[:, 0]) / w  # width
    y[:, 3] = (x[:, 3] - x[:, 1]) / h  # height
    return y


def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    # Convert normalized segments into pixel segments, shape (n,2)
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * x[:, 0] + padw  # top left x
    y[:, 1] = h * x[:, 1] + padh  # top left y
    return y


def segment2box(segment, width=640, height=640):
    # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y, = x[inside], y[inside]
    return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4))  # xyxy


def segments2boxes(segments):
    # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
    return xyxy2xywh(np.array(boxes))  # cls, xywh


def resample_segments(segments, n=1000):
    # Up-sample an (n,2) segment
    for i, s in enumerate(segments):
        x = np.linspace(0, len(s) - 1, n)
        xp = np.arange(len(s))
        segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T  # segment xy
    return segments

def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords


def clip_coords(boxes, shape):
    # Clip xyxy bounding boxes to image shape (height, width)
    if isinstance(boxes, torch.Tensor):  # faster individually
        boxes[:, 0].clamp_(0, shape[1])  # x1
        boxes[:, 1].clamp_(0, shape[0])  # y1
        boxes[:, 2].clamp_(0, shape[1])  # x2
        boxes[:, 3].clamp_(0, shape[0])  # y2
    else:  # np.array (faster grouped)
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1])  # x1, x2
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0])  # y1, y2

def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 7680  # (pixels) minimum and maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
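

# Example (minimal sketch): running NMS on raw model output at inference time. `model` and `img` are
# placeholders for a loaded YOLOv5 model and a preprocessed (1, 3, H, W) float tensor.
#
#   pred = model(img)[0]                                   # raw predictions, shape (1, num_boxes, 5 + nc)
#   pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, max_det=300)
#   for det in pred:                                       # one (n, 6) tensor per image: xyxy, conf, cls
#       ...
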
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'best_fitness', 'wandb_id', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")

def print_mutation(results, hyp, save_dir, bucket):
    evolve_csv, results_csv, evolve_yaml = save_dir / 'evolve.csv', save_dir / 'results.csv', save_dir / 'hyp_evolve.yaml'
    keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
            'val/box_loss', 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys())  # [results + hyps]
    keys = tuple(x.strip() for x in keys)
    vals = results + tuple(hyp.values())
    n = len(keys)

    # Download (optional)
    if bucket:
        url = f'gs://{bucket}/evolve.csv'
        if gsutil_getsize(url) > (os.path.getsize(evolve_csv) if os.path.exists(evolve_csv) else 0):
            os.system(f'gsutil cp {url} {save_dir}')  # download evolve.csv if larger than local

    # Log to evolve.csv
    s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n')  # add header
    with open(evolve_csv, 'a') as f:
        f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')

    # Print to screen
    print(colorstr('evolve: ') + ', '.join(f'{x.strip():>20s}' for x in keys))
    print(colorstr('evolve: ') + ', '.join(f'{x:20.5g}' for x in vals), end='\n\n\n')

    # Save yaml
    with open(evolve_yaml, 'w') as f:
        data = pd.read_csv(evolve_csv)
        data = data.rename(columns=lambda x: x.strip())  # strip keys
        i = np.argmax(fitness(data.values[:, :7]))  #
        f.write('# YOLOv5 Hyperparameter Evolution Results\n' +
                f'# Best generation: {i}\n' +
                f'# Last generation: {len(data) - 1}\n' +
                '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) + '\n' +
                '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
        yaml.safe_dump(hyp, f, sort_keys=False)

    if bucket:
        os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}')  # upload

def apply_classifier(x, model, img, im0):
    # Apply a second stage classifier to YOLO outputs
    # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                # cv2.imwrite('example%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x224x224
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x

def increment_path(path, exist_ok=False, sep='', mkdir=False):
    # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
        matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]  # indices
        n = max(i) + 1 if i else 2  # increment number
        path = Path(f"{path}{sep}{n}{suffix}")  # increment path
    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # make directory
    return path
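

# Example: picking a fresh run directory so successive trainings do not overwrite each other.
#
#   save_dir = increment_path(Path('runs/train') / 'exp', exist_ok=False, mkdir=True)
#   # first call -> runs/train/exp, second -> runs/train/exp2, third -> runs/train/exp3, ...
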
# Variables
NCOLS = 0 if is_docker() else shutil.get_terminal_size().columns  # terminal window size for tqdm