You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

882 lines
36KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. General utils
  4. """
  5. import contextlib
  6. import glob
  7. import logging
  8. import math
  9. import os
  10. import platform
  11. import random
  12. import re
  13. import shutil
  14. import signal
  15. import time
  16. import urllib
  17. from itertools import repeat
  18. from multiprocessing.pool import ThreadPool
  19. from pathlib import Path
  20. from subprocess import check_output
  21. from zipfile import ZipFile
  22. import cv2
  23. import numpy as np
  24. import pandas as pd
  25. import pkg_resources as pkg
  26. import torch
  27. import torchvision
  28. import yaml
  29. from utils.downloads import gsutil_getsize
  30. from utils.metrics import box_iou, fitness
  # Settings
  FILE = Path(__file__).resolve()
  ROOT = FILE.parents[1]  # YOLOv5 root directory
  DATASETS_DIR = ROOT.parent / 'datasets'  # YOLOv5 datasets directory
  # os.cpu_count() returns None when the CPU count cannot be determined; treat that as 1 CPU instead of raising TypeError
  NUM_THREADS = min(8, max(1, (os.cpu_count() or 1) - 1))  # number of YOLOv5 multiprocessing threads
  VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true'  # global verbose mode (env var overrides default True)
  FONT = 'Arial.ttf'  # https://ultralytics.com/assets/Arial.ttf
  torch.set_printoptions(linewidth=320, precision=5, profile='long')
  np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
  pd.options.display.max_columns = 10
  cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS)  # NumExpr max threads
  os.environ['OMP_NUM_THREADS'] = str(NUM_THREADS)  # OpenMP max threads (PyTorch and SciPy)
  def is_kaggle():
      # Is environment a Kaggle Notebook? (both Kaggle-specific environment variables must match)
      # Explicit comparisons instead of assert-based control flow: `python -O` strips asserts,
      # which made the previous version always return True.
      return (os.environ.get('PWD') == '/kaggle/working'
              and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com')
  def is_writeable(dir, test=False):
      # Return True if directory has write permissions, test opening a file with write permissions if test=True
      if test:  # method 1: actually create (and then remove) a probe file
          file = Path(dir) / 'tmp.txt'
          try:
              with open(file, 'w'):  # open file with write permissions
                  pass
              file.unlink()  # remove file
              return True
          except OSError:
              return False
      # method 2: permission check only (possible issues on Windows)
      # os.W_OK tests write access; the previous os.R_OK only tested read access
      return os.access(dir, os.W_OK)
  def set_logging(name=None, verbose=VERBOSE):
      # Configure root logging (INFO when verbose on rank -1/0, else WARNING) and return the named logger
      if is_kaggle():
          # Iterate over a copy: removeHandler() mutates logging.root.handlers in place,
          # so iterating the live list would skip every other handler.
          for h in list(logging.root.handlers):
              logging.root.removeHandler(h)  # remove all handlers associated with the root logger object
      rank = int(os.getenv('RANK', -1))  # rank in world for Multi-GPU trainings
      logging.basicConfig(format="%(message)s", level=logging.INFO if (verbose and rank in (-1, 0)) else logging.WARNING)
      return logging.getLogger(name)


  LOGGER = set_logging('yolov5')  # define globally (used in train.py, val.py, detect.py, etc.)
  def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
      # Return path of user configuration directory, creating it if required.
      # If env_var is set, its value is used as-is (dir is not appended); otherwise an OS-specific
      # config dir (or /tmp when that is not writeable, e.g. GCP / AWS lambda) joined with dir.
      env = os.getenv(env_var)
      if env:
          path = Path(env)  # use environment variable
      else:
          cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'}  # 3 OS dirs
          path = Path.home() / cfg.get(platform.system(), '')  # OS-specific config dir
          path = (path if is_writeable(path) else Path('/tmp')) / dir  # GCP and AWS lambda fix, only /tmp is writeable
      path.mkdir(parents=True, exist_ok=True)  # make if required, including missing parents (e.g. a nested env path)
      return path
  # Ultralytics settings dir, resolved once at import time (the YOLOV5_CONFIG_DIR env var takes precedence if set)
  CONFIG_DIR = user_config_dir()
  class Profile(contextlib.ContextDecorator):
      """Print the wall-clock duration of a code block.

      Usage: @Profile() decorator or 'with Profile():' context manager.
      """

      def __enter__(self):
          self.start = time.time()  # reference timestamp taken on entry

      def __exit__(self, type, value, traceback):
          elapsed = time.time() - self.start
          print(f'Profile results: {elapsed:.5f}s')
  class Timeout(contextlib.ContextDecorator):
      # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
      # Built on signal.SIGALRM / signal.alarm, so POSIX-only and usable only from the main thread.
      def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
          self.seconds = int(seconds)
          self.timeout_message = timeout_msg
          self.suppress = bool(suppress_timeout_errors)
          self._old_handler = None  # SIGALRM handler that was installed before __enter__

      def _timeout_handler(self, signum, frame):
          raise TimeoutError(self.timeout_message)

      def __enter__(self):
          # Remember the previous handler so __exit__ can restore it instead of leaving ours installed
          self._old_handler = signal.signal(signal.SIGALRM, self._timeout_handler)  # Set handler for SIGALRM
          signal.alarm(self.seconds)  # start countdown for SIGALRM to be raised

      def __exit__(self, exc_type, exc_val, exc_tb):
          signal.alarm(0)  # Cancel SIGALRM if it's scheduled
          if self._old_handler is not None:  # None means the previous handler was not installed from Python
              signal.signal(signal.SIGALRM, self._old_handler)
              self._old_handler = None
          if self.suppress and exc_type is TimeoutError:  # Suppress TimeoutError
              return True
  class WorkingDirectory(contextlib.ContextDecorator):
      # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
      # Changes into new_dir on entry and back to the directory that was current at entry on exit.
      def __init__(self, new_dir):
          self.dir = new_dir  # new dir
          self.cwd = Path.cwd().resolve()  # current dir (refreshed on every __enter__)

      def __enter__(self):
          # Capture cwd at entry, not construction: as a decorator the instance is built once (at import),
          # and restoring that stale directory would silently move callers that changed directory since.
          self.cwd = Path.cwd().resolve()
          os.chdir(self.dir)

      def __exit__(self, exc_type, exc_val, exc_tb):
          os.chdir(self.cwd)
  def try_except(func):
      # try-except function. Usage: @try_except decorator
      # Any Exception raised by func is printed and swallowed (the wrapper then returns None);
      # otherwise func's return value is passed through.
      import functools

      @functools.wraps(func)  # preserve func's __name__/__doc__ for logging and introspection
      def handler(*args, **kwargs):
          try:
              return func(*args, **kwargs)
          except Exception as e:
              print(e)

      return handler
  def methods(instance):
      """Return the names of callable, non-dunder attributes of a class or instance, in dir() order."""
      names = []
      for attr in dir(instance):
          if callable(getattr(instance, attr)) and not attr.startswith("__"):
              names.append(attr)
      return names
  def print_args(name, opt):
      # Log every attribute of an argparse Namespace on one line, prefixed by a colored `name: `
      pairs = (f'{key}={value}' for key, value in vars(opt).items())
      LOGGER.info(colorstr(f'{name}: ') + ', '.join(pairs))
  def init_seeds(seed=0):
      # Seed the Python, NumPy and PyTorch RNGs https://pytorch.org/docs/stable/notes/randomness.html
      # seed == 0 selects deterministic cuDNN (slower, reproducible); any other seed enables benchmark mode (faster)
      import torch.backends.cudnn as cudnn
      random.seed(seed)
      np.random.seed(seed)
      torch.manual_seed(seed)
      reproducible = seed == 0
      cudnn.benchmark = not reproducible
      cudnn.deterministic = reproducible
  def intersect_dicts(da, db, exclude=()):
      # Keep da's entries whose key is also in db with an equal .shape, dropping keys that contain any 'exclude' substring
      def keep(key, value):
          if key not in db:
              return False
          if any(token in key for token in exclude):
              return False
          return value.shape == db[key].shape

      return {key: value for key, value in da.items() if keep(key, value)}
  def get_latest_run(search_dir='.'):
      # Return the most recently created 'last*.pt' under search_dir (i.e. to --resume from), or '' if none exists
      candidates = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
      if not candidates:
          return ''
      return max(candidates, key=os.path.getctime)
  def is_docker():
      # Heuristic Docker check: treats an existing /workspace path as a container (alternative: Path('/.dockerenv'))
      workspace = Path('/workspace')
      return workspace.exists()
  def is_colab():
      # Google Colab check: True only if the google.colab module can be imported
      try:
          import google.colab  # noqa: F401
      except ImportError:
          return False
      return True
  def is_pip():
      # True when this source file resolves to a path containing a 'site-packages' component (pip install)
      parts = Path(__file__).resolve().parts
      return 'site-packages' in parts
  def is_ascii(s=''):
      # True if str(s) has only ASCII characters (kept portable: str.isascii() needs Python 3.7+)
      text = str(s)  # lists, tuples, None, etc. are checked via their str() form
      ascii_only = text.encode().decode('ascii', 'ignore')
      return len(ascii_only) == len(text)
  def is_chinese(s='人工智能'):
      # True if str(s) contains at least one CJK Unified Ideograph (U+4E00..U+9FFF)
      return re.search('[\u4e00-\u9fff]', str(s)) is not None
  def emojis(str=''):
      # On Windows, drop non-ASCII characters (e.g. emojis) that consoles may not render; elsewhere return unchanged
      if platform.system() != 'Windows':
          return str
      return str.encode().decode('ascii', 'ignore')
  def file_size(path):
      # Size in MB (1E6 bytes) of a file, or of all files under a directory; 0.0 if path is neither
      p = Path(path)
      if p.is_file():
          total = p.stat().st_size
      elif p.is_dir():
          total = sum(f.stat().st_size for f in p.glob('**/*') if f.is_file())
      else:
          return 0.0
      return total / 1E6
  def check_online():
      # Check internet connectivity by opening a TCP connection to 1.1.1.1:443 (5 s timeout)
      import socket
      try:
          # Use the socket as a context manager so the probe connection is closed immediately
          # rather than leaked until garbage collection
          with socket.create_connection(("1.1.1.1", 443), 5):  # check host accessibility
              return True
      except OSError:
          return False
  @try_except
  @WorkingDirectory(ROOT)
  def check_git_status():
      # Recommend 'git pull' if code is out of date. Runs inside ROOT; the skip conditions below are asserts,
      # and @try_except prints their messages (or any other error) instead of raising.
      msg = ', for updates see https://github.com/ultralytics/yolov5'
      s = colorstr('github: ')  # string
      assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
      assert not is_docker(), s + 'skipping check (Docker image)' + msg
      assert check_online(), s + 'skipping check (offline)' + msg

      cmd = 'git fetch && git config --get remote.origin.url'
      url = check_output(cmd, shell=True, timeout=5).decode().strip()  # git fetch
      if url.endswith('.git'):
          # Remove only the '.git' suffix; str.rstrip('.git') strips any trailing '.', 'g', 'i', 't' characters
          url = url[:-len('.git')]
      branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
      n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True))  # commits behind
      if n > 0:
          s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
      else:
          s += f'up to date with {url} ✅'
      LOGGER.info(emojis(s))  # emoji-safe
  def check_python(minimum='3.6.2'):
      # Hard requirement: the running interpreter version must be >= minimum (delegates to check_version)
      running = platform.python_version()
      check_version(running, minimum, name='Python ', hard=True)
  def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
      # Check version vs. required version: pinned=True requires equality, otherwise current >= minimum.
      # hard=True raises AssertionError on failure; verbose=True logs a warning on failure. Returns the bool result.
      current, minimum = (pkg.parse_version(x) for x in (current, minimum))
      result = (current == minimum) if pinned else (current >= minimum)  # bool
      s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'  # string
      if hard and not result:
          # Explicit raise instead of `assert`, which `python -O` strips (silently disabling hard requirements);
          # the exception type is unchanged for callers that catch AssertionError.
          raise AssertionError(s)
      if verbose and not result:
          LOGGER.warning(s)
      return result
  @try_except
  def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True):
      # Check installed dependencies meet requirements: a requirements.txt path (str/Path) or an iterable of
      # specifier strings. Missing/conflicting packages are pip-installed when install=True (and online).
      # Wrapped by @try_except, so failures (e.g. missing file) are printed rather than raised.
      prefix = colorstr('red', 'bold', 'requirements:')
      check_python()  # check python version
      req_file = None  # set only when requirements were read from a file
      if isinstance(requirements, (str, Path)):  # requirements.txt file
          req_file = Path(requirements)
          assert req_file.exists(), f"{prefix} {req_file.resolve()} not found, check failed."
          with req_file.open() as fh:
              requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(fh) if x.name not in exclude]
      else:  # list or tuple of packages
          requirements = [x for x in requirements if x not in exclude]

      updated = 0  # number of packages updates
      for req in requirements:
          try:
              pkg.require(req)
          except Exception:  # DistributionNotFound or VersionConflict if requirements not met
              missing = f"{prefix} {req} not found and is required by YOLOv5"
              if not install:
                  LOGGER.info(f'{missing}. Please install and rerun your command.')
                  continue
              LOGGER.info(f"{missing}, attempting auto-update...")
              try:
                  assert check_online(), f"'pip install {req}' skipped (offline)"
                  LOGGER.info(check_output(f"pip install '{req}'", shell=True).decode())
                  updated += 1
              except Exception as err:
                  LOGGER.warning(f'{prefix} {err}')

      if updated:  # if packages updated
          source = req_file.resolve() if req_file is not None else requirements
          s = f"{prefix} {updated} package{'s' * (updated > 1)} updated per {source}\n" \
              f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
          LOGGER.info(emojis(s))
  244. if n: # if packages updated
  245. source = file.resolve() if 'file' in locals() else requirements
  246. s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
  247. f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
  248. LOGGER.info(emojis(s))
  249. def check_img_size(imgsz, s=32, floor=0):
  250. # Verify image size is a multiple of stride s in each dimension
  251. if isinstance(imgsz, int): # integer i.e. img_size=640
  252. new_size = max(make_divisible(imgsz, int(s)), floor)
  253. else: # list i.e. img_size=[640, 480]
  254. new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
  255. if new_size != imgsz:
  256. LOGGER.warning(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
  257. return new_size
  258. def check_imshow():
  259. # Check if environment supports image displays
  260. try:
  261. assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
  262. assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
  263. cv2.imshow('test', np.zeros((1, 1, 3)))
  264. cv2.waitKey(1)
  265. cv2.destroyAllWindows()
  266. cv2.waitKey(1)
  267. return True
  268. except Exception as e:
  269. LOGGER.warning(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
  270. return False
  271. def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
  272. # Check file(s) for acceptable suffix
  273. if file and suffix:
  274. if isinstance(suffix, str):
  275. suffix = [suffix]
  276. for f in file if isinstance(file, (list, tuple)) else [file]:
  277. s = Path(f).suffix.lower() # file suffix
  278. if len(s):
  279. assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
  280. def check_yaml(file, suffix=('.yaml', '.yml')):
  281. # Search/download YAML file (if necessary) and return path, checking suffix
  282. return check_file(file, suffix)
  283. def check_file(file, suffix=''):
  284. # Search/download file (if necessary) and return path
  285. check_suffix(file, suffix) # optional
  286. file = str(file) # convert to str()
  287. if Path(file).is_file() or file == '': # exists
  288. return file
  289. elif file.startswith(('http:/', 'https:/')): # download
  290. url = str(Path(file)).replace(':/', '://') # Pathlib turns :// -> :/
  291. file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
  292. if Path(file).is_file():
  293. LOGGER.info(f'Found {url} locally at {file}') # file already exists
  294. else:
  295. LOGGER.info(f'Downloading {url} to {file}...')
  296. torch.hub.download_url_to_file(url, file)
  297. assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
  298. return file
  299. else: # search
  300. files = []
  301. for d in 'data', 'models', 'utils': # search directories
  302. files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
  303. assert len(files), f'File not found: {file}' # assert file was found
  304. assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
  305. return files[0] # return file
  def check_font(font=FONT):
      # Download font into CONFIG_DIR unless it already exists at `font` or in CONFIG_DIR
      font = Path(font)
      file = CONFIG_DIR / font.name
      if not font.exists() and not file.exists():
          url = "https://ultralytics.com/assets/" + font.name
          LOGGER.info(f'Downloading {url} to {file}...')
          # Save to the same CONFIG_DIR path that is checked above and reported in the log message
          torch.hub.download_url_to_file(url, str(file), progress=False)
  def check_dataset(data, autodownload=True):
      # Download and/or unzip dataset if not found locally; return the dataset dict with train/val/test paths resolved.
      # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip
      #   data: dataset YAML path, a '.zip' archive (downloaded/extracted into DATASETS_DIR), or an already-loaded dict
      #   autodownload: run the dict's 'download' entry when any val path is missing
      # Raises: AssertionError if 'nc' is missing; Exception('Dataset not found.') if val paths are missing and
      #   no download is possible.

      # Download (optional)
      extract_dir = ''
      if isinstance(data, (str, Path)) and str(data).endswith('.zip'):  # i.e. gs://bucket/dir/coco128.zip
          download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False, threads=1)
          # use the first *.yaml found anywhere inside the extracted directory
          data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
          extract_dir, autodownload = data.parent, False

      # Read yaml (optional)
      if isinstance(data, (str, Path)):
          with open(data, errors='ignore') as f:
              data = yaml.safe_load(f)  # dictionary

      # Resolve paths
      path = Path(extract_dir or data.get('path') or '')  # optional 'path' default to '.'
      if not path.is_absolute():
          path = (ROOT / path).resolve()  # relative dataset paths are resolved against ROOT, not the cwd
      for k in 'train', 'val', 'test':
          if data.get(k):  # prepend path
              data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]

      # Parse yaml
      assert 'nc' in data, "Dataset 'nc' key missing."
      if 'names' not in data:
          data['names'] = [f'class{i}' for i in range(data['nc'])]  # assign class names if missing
      train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))  # only val and s are used below
      if val:
          val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
          if not all(x.exists() for x in val):
              LOGGER.info('\nDataset not found, missing paths: %s' % [str(x) for x in val if not x.exists()])
              if s and autodownload:  # download script
                  # NOTE(review): '..' is relative to the current working directory, not ROOT -- confirm intended
                  root = path.parent if 'path' in data else '..'  # unzip directory i.e. '../'
                  if s.startswith('http') and s.endswith('.zip'):  # URL
                      f = Path(s).name  # filename
                      LOGGER.info(f'Downloading {s} to {f}...')
                      torch.hub.download_url_to_file(s, f)
                      Path(root).mkdir(parents=True, exist_ok=True)  # create root
                      ZipFile(f).extractall(path=root)  # unzip
                      Path(f).unlink()  # remove zip
                      r = None  # success
                  elif s.startswith('bash '):  # bash script
                      LOGGER.info(f'Running {s} ...')
                      # SECURITY NOTE(review): runs a shell command taken verbatim from the dataset YAML
                      r = os.system(s)
                  else:  # python script
                      # SECURITY NOTE(review): exec() runs arbitrary Python from the dataset YAML (with the dataset
                      # dict exposed as global `yaml`); only load trusted dataset YAML files
                      r = exec(s, {'yaml': data})  # return None
                  # success means r is 0 (shell exit status) or None (URL / python paths)
                  LOGGER.info(f"Dataset autodownload {f'success, saved to {root}' if r in (0, None) else 'failure'}\n")
              else:
                  raise Exception('Dataset not found.')

      return data  # dictionary
  356. r = exec(s, {'yaml': data}) # return None
  357. LOGGER.info(f"Dataset autodownload {f'success, saved to {root}' if r in (0, None) else 'failure'}\n")
  358. else:
  359. raise Exception('Dataset not found.')
  360. return data # dictionary
  361. def url2file(url):
  362. # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
  363. url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
  364. file = Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
  365. return file
  366. def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
  367. # Multi-threaded file download and unzip function, used in data.yaml for autodownload
  368. def download_one(url, dir):
  369. # Download 1 file
  370. f = dir / Path(url).name # filename
  371. if Path(url).is_file(): # exists in current path
  372. Path(url).rename(f) # move to dir
  373. elif not f.exists():
  374. LOGGER.info(f'Downloading {url} to {f}...')
  375. if curl:
  376. os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -") # curl download, retry and resume on fail
  377. else:
  378. torch.hub.download_url_to_file(url, f, progress=True) # torch download
  379. if unzip and f.suffix in ('.zip', '.gz'):
  380. LOGGER.info(f'Unzipping {f}...')
  381. if f.suffix == '.zip':
  382. ZipFile(f).extractall(path=dir) # unzip
  383. elif f.suffix == '.gz':
  384. os.system(f'tar xfz {f} --directory {f.parent}') # unzip
  385. if delete:
  386. f.unlink() # remove zip
  387. dir = Path(dir)
  388. dir.mkdir(parents=True, exist_ok=True) # make directory
  389. if threads > 1:
  390. pool = ThreadPool(threads)
  391. pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
  392. pool.close()
  393. pool.join()
  394. else:
  395. for u in [url] if isinstance(url, (str, Path)) else url:
  396. download_one(u, dir)
  397. def make_divisible(x, divisor):
  398. # Returns nearest x divisible by divisor
  399. if isinstance(divisor, torch.Tensor):
  400. divisor = int(divisor.max()) # to int
  401. return math.ceil(x / divisor) * divisor
  402. def clean_str(s):
  403. # Cleans a string by replacing special characters with underscore _
  404. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  405. def one_cycle(y1=0.0, y2=1.0, steps=100):
  406. # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
  407. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  408. def colorstr(*input):
  409. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  410. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  411. colors = {'black': '\033[30m', # basic colors
  412. 'red': '\033[31m',
  413. 'green': '\033[32m',
  414. 'yellow': '\033[33m',
  415. 'blue': '\033[34m',
  416. 'magenta': '\033[35m',
  417. 'cyan': '\033[36m',
  418. 'white': '\033[37m',
  419. 'bright_black': '\033[90m', # bright colors
  420. 'bright_red': '\033[91m',
  421. 'bright_green': '\033[92m',
  422. 'bright_yellow': '\033[93m',
  423. 'bright_blue': '\033[94m',
  424. 'bright_magenta': '\033[95m',
  425. 'bright_cyan': '\033[96m',
  426. 'bright_white': '\033[97m',
  427. 'end': '\033[0m', # misc
  428. 'bold': '\033[1m',
  429. 'underline': '\033[4m'}
  430. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  def labels_to_class_weights(labels, nc=80):
      # Get class weights (inverse frequency) from training labels.
      # labels: sequence of per-image arrays whose first column is the class index (i.e. [class xywh]).
      # Returns a float64 torch tensor of normalized weights (sums to 1), one per class (at least nc entries);
      # returns an empty tensor when no labels are loaded.
      if len(labels) == 0 or labels[0] is None:  # no labels loaded
          return torch.Tensor()
      labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
      classes = labels[:, 0].astype(int)  # builtin int: the np.int alias was removed in NumPy 1.24
      weights = np.bincount(classes, minlength=nc)  # occurrences per class
      weights[weights == 0] = 1  # replace empty bins with 1 (avoids division by zero)
      weights = 1 / weights  # inverse of the number of targets per class
      weights /= weights.sum()  # normalize
      return torch.from_numpy(weights)
  def labels_to_image_weights(labels, nc=80, class_weights=None):
      # Produces image weights based on class_weights and image contents.
      # Each image's weight is the sum over its labels of class_weights[class] (class index = first column).
      # class_weights defaults to all-ones of length nc (previously a fixed np.ones(80), which broke for nc != 80).
      if class_weights is None:
          class_weights = np.ones(nc)
      # builtin int: the np.int alias was removed in NumPy 1.24
      class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
      image_weights = (np.asarray(class_weights).reshape(1, nc) * class_counts).sum(1)
      # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
      return image_weights
  def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
      # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
      # Entry i is the paper category id (1..90) of 80-index class i; these paper ids have no 80-index class:
      unused = {12, 26, 29, 30, 45, 66, 68, 69, 71, 83}
      return [cat_id for cat_id in range(1, 91) if cat_id not in unused]
  def xyxy2xywh(x):
      # Convert nx4 boxes [x1, y1, x2, y2] (top-left, bottom-right) to [x_center, y_center, width, height].
      # Works on torch tensors or numpy arrays; returns a new object of the same type and dtype.
      y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
      left, top, right, bottom = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
      y[:, 0] = (left + right) / 2  # x center
      y[:, 1] = (top + bottom) / 2  # y center
      y[:, 2] = right - left  # width
      y[:, 3] = bottom - top  # height
      return y
  def xywh2xyxy(x):
      # Convert nx4 boxes [x_center, y_center, width, height] to [x1, y1, x2, y2] (top-left, bottom-right).
      # Works on torch tensors or numpy arrays; returns a new object of the same type and dtype.
      y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
      cx, cy, bw, bh = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
      y[:, 0] = cx - bw / 2  # top left x
      y[:, 1] = cy - bh / 2  # top left y
      y[:, 2] = cx + bw / 2  # bottom right x
      y[:, 3] = cy + bh / 2  # bottom right y
      return y
  def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
      # Convert nx4 normalized [x_center, y_center, width, height] boxes to pixel [x1, y1, x2, y2],
      # scaling by image size (w, h) and shifting by padding (padw, padh)
      y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
      cx, cy, bw, bh = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
      y[:, 0] = w * (cx - bw / 2) + padw  # top left x
      y[:, 1] = h * (cy - bh / 2) + padh  # top left y
      y[:, 2] = w * (cx + bw / 2) + padw  # bottom right x
      y[:, 3] = h * (cy + bh / 2) + padh  # bottom right y
      return y
  def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
      # Convert nx4 [x1, y1, x2, y2] boxes to [x_center, y_center, width, height] normalized by image size (w, h).
      # clip=True first clamps x IN PLACE to the image (shrunk by eps) via clip_coords.
      if clip:
          clip_coords(x, (h - eps, w - eps))  # warning: inplace clip
      y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
      left, top, right, bottom = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
      y[:, 0] = ((left + right) / 2) / w  # x center
      y[:, 1] = ((top + bottom) / 2) / h  # y center
      y[:, 2] = (right - left) / w  # width
      y[:, 3] = (bottom - top) / h  # height
      return y
  def xyn2xy(x, w=640, h=640, padw=0, padh=0):
      # Convert normalized (n,2) segment points to pixel coordinates: scale by (w, h), then shift by (padw, padh)
      y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
      y[:, 0] = x[:, 0] * w + padw  # x
      y[:, 1] = x[:, 1] * h + padh  # y
      return y
  def segment2box(segment, width=640, height=640):
      # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy).
      # Points outside [0, width] x [0, height] are ignored; returns np.zeros((1, 4)) if no point remains.
      x, y = segment.T  # segment xy
      inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
      x, y = x[inside], y[inside]
      # Test for remaining points with .size: any(x) was False whenever every remaining x equals 0
      # (points on the left image edge), which wrongly discarded a valid box.
      return np.array([x.min(), y.min(), x.max(), y.max()]) if x.size else np.zeros((1, 4))  # xyxy
  def segments2boxes(segments):
      # Convert (n,2) segment polygons to xywh boxes, one per segment, each enclosing all of its points.
      # Output columns are [x_center, y_center, width, height] only; no class column is read or written.
      def bounds(seg):
          xs, ys = seg.T  # segment xy
          return [xs.min(), ys.min(), xs.max(), ys.max()]  # xyxy

      return xyxy2xywh(np.array([bounds(seg) for seg in segments]))  # xywh
  def resample_segments(segments, n=1000):
      # Resample each (m,2) segment to n points by linear interpolation over vertex index.
      # The list is updated in place and also returned.
      for idx, seg in enumerate(segments):
          query = np.linspace(0, len(seg) - 1, n)
          knots = np.arange(len(seg))
          coords = [np.interp(query, knots, seg[:, axis]) for axis in (0, 1)]
          segments[idx] = np.concatenate(coords).reshape(2, -1).T  # segment xy
      return segments
  def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
      # Rescale xyxy coords IN PLACE from img1_shape (resized + padded image, (h, w, ...)) back to img0_shape,
      # then clip them to img0_shape. ratio_pad, if given, supplies ((gain, ...), (padw, padh)) directly.
      if ratio_pad is not None:
          gain, pad = ratio_pad[0][0], ratio_pad[1]
      else:  # calculate from img0_shape
          gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
          pad_w = (img1_shape[1] - img0_shape[1] * gain) / 2
          pad_h = (img1_shape[0] - img0_shape[0] * gain) / 2
          pad = pad_w, pad_h  # wh padding
      coords[:, [0, 2]] -= pad[0]  # x padding
      coords[:, [1, 3]] -= pad[1]  # y padding
      coords[:, :4] /= gain
      clip_coords(coords, img0_shape)
      return coords
  def clip_coords(boxes, shape):
      # Clip xyxy boxes IN PLACE to image shape (height, width, ...); returns None
      height, width = shape[0], shape[1]
      if isinstance(boxes, torch.Tensor):  # per-column clamp_ is faster for tensors
          for col, upper in ((0, width), (1, height), (2, width), (3, height)):
              boxes[:, col].clamp_(0, upper)  # x1, y1, x2, y2
      else:  # grouped clip is faster for np.array
          boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, width)  # x1, x2
          boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, height)  # y1, y2
  def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                          labels=(), max_det=300):
      """Runs Non-Maximum Suppression (NMS) on inference results

      Args:
          prediction: model output; the indexing below assumes shape (batch, num_boxes, 5 + nc) with columns
              [x_center, y_center, w, h, obj_conf, cls_conf_0..cls_conf_{nc-1}] -- TODO confirm against the model head.
              NOTE: boxes with out-of-range width/height get obj_conf set to 0 IN PLACE in this tensor.
          conf_thres: minimum obj_conf for candidates, and minimum obj_conf * cls_conf for kept detections
          iou_thres: IoU threshold passed to torchvision.ops.nms
          classes: optional iterable of class indices to keep
          agnostic: if True, NMS is class-agnostic (boxes are not offset by class)
          multi_label: allow several labels per box (only takes effect when nc > 1)
          labels: optional per-image apriori labels [cls, x, y, w, h] appended as conf-1.0 detections (autolabelling)
          max_det: maximum number of detections kept per image

      Returns:
           list of detections, on (n,6) tensor per image [xyxy, conf, cls]
      """

      nc = prediction.shape[2] - 5  # number of classes
      xc = prediction[..., 4] > conf_thres  # candidates (computed before the width-height constraint below)

      # Checks
      assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
      assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

      # Settings
      min_wh, max_wh = 2, 7680  # (pixels) minimum and maximum box width and height
      max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
      time_limit = 10.0  # seconds to quit after
      redundant = True  # require redundant detections
      multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
      merge = False  # use merge-NMS

      t = time.time()
      # every entry starts as the same empty (0, 6) tensor; entries are replaced (not mutated) below
      output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
      for xi, x in enumerate(prediction):  # image index, image inference
          # Apply constraints
          # x is a view of prediction[xi], so this write also modifies the caller's tensor. Rows zeroed here still
          # pass the precomputed xc mask but are dropped later by the conf > conf_thres filter (conf becomes 0).
          x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
          x = x[xc[xi]]  # confidence

          # Cat apriori labels if autolabelling
          if labels and len(labels[xi]):
              lb = labels[xi]
              v = torch.zeros((len(lb), nc + 5), device=x.device)
              v[:, :4] = lb[:, 1:5]  # box
              v[:, 4] = 1.0  # conf
              v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
              x = torch.cat((x, v), 0)

          # If none remain process next image
          if not x.shape[0]:
              continue

          # Compute conf
          x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

          # Box (center x, center y, width, height) to (x1, y1, x2, y2)
          box = xywh2xyxy(x[:, :4])

          # Detections matrix nx6 (xyxy, conf, cls)
          if multi_label:
              # one row per (box, class) pair whose combined conf exceeds conf_thres
              i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
              x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
          else:  # best class only
              conf, j = x[:, 5:].max(1, keepdim=True)
              x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

          # Filter by class
          if classes is not None:
              x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

          # Apply finite constraint
          # if not torch.isfinite(x).all():
          #     x = x[torch.isfinite(x).all(1)]

          # Check shape
          n = x.shape[0]  # number of boxes
          if not n:  # no boxes
              continue
          elif n > max_nms:  # excess boxes
              x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

          # Batched NMS
          # offsetting boxes by class * max_wh keeps boxes of different classes from overlapping, so a single
          # nms() call suppresses only within a class (unless agnostic)
          c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
          boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
          i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
          if i.shape[0] > max_det:  # limit detections
              i = i[:max_det]
          if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
              # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
              iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
              weights = iou * scores[None]  # box weights
              x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
              if redundant:
                  i = i[iou.sum(1) > 1]  # require redundancy

          output[xi] = x[i]
          if (time.time() - t) > time_limit:
              # remaining images keep their empty (0, 6) output
              LOGGER.warning(f'WARNING: NMS time limit {time_limit}s exceeded')
              break  # time limit exceeded

      return output
  def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
      # Finalize a training checkpoint: prefer the EMA model, null out training-only keys, convert to FP16,
      # freeze parameters and save to 's' (or overwrite 'f' when s is empty)
      ckpt = torch.load(f, map_location=torch.device('cpu'))
      if ckpt.get('ema'):
          ckpt['model'] = ckpt['ema']  # replace model with ema
      ckpt.update(dict.fromkeys(('optimizer', 'best_fitness', 'wandb_id', 'ema', 'updates')))  # set keys to None
      ckpt['epoch'] = -1
      model = ckpt['model']
      model.half()  # to FP16
      for param in model.parameters():
          param.requires_grad = False
      target = s or f
      torch.save(ckpt, target)
      mb = os.path.getsize(target) / 1E6  # filesize
      LOGGER.info(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
  634. def print_mutation(results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
  635. evolve_csv = save_dir / 'evolve.csv'
  636. evolve_yaml = save_dir / 'hyp_evolve.yaml'
  637. keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
  638. 'val/box_loss', 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys()) # [results + hyps]
  639. keys = tuple(x.strip() for x in keys)
  640. vals = results + tuple(hyp.values())
  641. n = len(keys)
  642. # Download (optional)
  643. if bucket:
  644. url = f'gs://{bucket}/evolve.csv'
  645. if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
  646. os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local
  647. # Log to evolve.csv
  648. s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header
  649. with open(evolve_csv, 'a') as f:
  650. f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')
  651. # Save yaml
  652. with open(evolve_yaml, 'w') as f:
  653. data = pd.read_csv(evolve_csv)
  654. data = data.rename(columns=lambda x: x.strip()) # strip keys
  655. i = np.argmax(fitness(data.values[:, :4])) #
  656. generations = len(data)
  657. f.write('# YOLOv5 Hyperparameter Evolution Results\n' +
  658. f'# Best generation: {i}\n' +
  659. f'# Last generation: {generations - 1}\n' +
  660. '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) + '\n' +
  661. '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
  662. yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)
  663. # Print to screen
  664. LOGGER.info(prefix + f'{generations} generations finished, current result:\n' +
  665. prefix + ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' +
  666. prefix + ', '.join(f'{x:20.5g}' for x in vals) + '\n\n')
  667. if bucket:
  668. os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload
  669. def apply_classifier(x, model, img, im0):
  670. # Apply a second stage classifier to YOLO outputs
  671. # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
  672. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  673. for i, d in enumerate(x): # per image
  674. if d is not None and len(d):
  675. d = d.clone()
  676. # Reshape and pad cutouts
  677. b = xyxy2xywh(d[:, :4]) # boxes
  678. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  679. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  680. d[:, :4] = xywh2xyxy(b).long()
  681. # Rescale boxes from img_size to im0 size
  682. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  683. # Classes
  684. pred_cls1 = d[:, 5].long()
  685. ims = []
  686. for j, a in enumerate(d): # per item
  687. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  688. im = cv2.resize(cutout, (224, 224)) # BGR
  689. # cv2.imwrite('example%i.jpg' % j, cutout)
  690. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  691. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  692. im /= 255 # 0 - 255 to 0.0 - 1.0
  693. ims.append(im)
  694. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  695. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  696. return x
  697. def increment_path(path, exist_ok=False, sep='', mkdir=False):
  698. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  699. path = Path(path) # os-agnostic
  700. if path.exists() and not exist_ok:
  701. path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
  702. dirs = glob.glob(f"{path}{sep}*") # similar paths
  703. matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
  704. i = [int(m.groups()[0]) for m in matches if m] # indices
  705. n = max(i) + 1 if i else 2 # increment number
  706. path = Path(f"{path}{sep}{n}{suffix}") # increment path
  707. if mkdir:
  708. path.mkdir(parents=True, exist_ok=True) # make directory
  709. return path
  710. # Variables
  711. NCOLS = 0 if is_docker() else shutil.get_terminal_size().columns # terminal window size for tqdm