You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

967 lines
39KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. General utils
  4. """
  5. import contextlib
  6. import glob
  7. import inspect
  8. import logging
  9. import math
  10. import os
  11. import platform
  12. import random
  13. import re
  14. import shutil
  15. import signal
  16. import time
  17. import urllib
  18. from datetime import datetime
  19. from itertools import repeat
  20. from multiprocessing.pool import ThreadPool
  21. from pathlib import Path
  22. from subprocess import check_output
  23. from typing import Optional
  24. from zipfile import ZipFile
  25. import cv2
  26. import numpy as np
  27. import pandas as pd
  28. import pkg_resources as pkg
  29. import torch
  30. import torchvision
  31. import yaml
  32. from utils.downloads import gsutil_getsize
  33. from utils.metrics import box_iou, fitness
  34. # Settings
  35. FILE = Path(__file__).resolve()
  36. ROOT = FILE.parents[1] # YOLOv5 root directory
  37. DATASETS_DIR = ROOT.parent / 'datasets' # YOLOv5 datasets directory
  38. NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
  39. VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global verbose mode
  40. FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
  41. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  42. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  43. pd.options.display.max_columns = 10
  44. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  45. os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
  46. os.environ['OMP_NUM_THREADS'] = str(NUM_THREADS) # OpenMP max threads (PyTorch and SciPy)
  47. def is_kaggle():
  48. # Is environment a Kaggle Notebook?
  49. try:
  50. assert os.environ.get('PWD') == '/kaggle/working'
  51. assert os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
  52. return True
  53. except AssertionError:
  54. return False
  55. def is_writeable(dir, test=False):
  56. # Return True if directory has write permissions, test opening a file with write permissions if test=True
  57. if test: # method 1
  58. file = Path(dir) / 'tmp.txt'
  59. try:
  60. with open(file, 'w'): # open file with write permissions
  61. pass
  62. file.unlink() # remove file
  63. return True
  64. except OSError:
  65. return False
  66. else: # method 2
  67. return os.access(dir, os.R_OK) # possible issues on Windows
  68. def set_logging(name=None, verbose=VERBOSE):
  69. # Sets level and returns logger
  70. if is_kaggle():
  71. for h in logging.root.handlers:
  72. logging.root.removeHandler(h) # remove all handlers associated with the root logger object
  73. rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
  74. level = logging.INFO if (verbose and rank in (-1, 0)) else logging.WARNING
  75. log = logging.getLogger(name)
  76. log.setLevel(level)
  77. handler = logging.StreamHandler()
  78. handler.setFormatter(logging.Formatter("%(message)s"))
  79. handler.setLevel(level)
  80. log.addHandler(handler)
  81. set_logging() # run before defining LOGGER
  82. LOGGER = logging.getLogger("yolov5") # define globally (used in train.py, val.py, detect.py, etc.)
  83. def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
  84. # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
  85. env = os.getenv(env_var)
  86. if env:
  87. path = Path(env) # use environment variable
  88. else:
  89. cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs
  90. path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir
  91. path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable
  92. path.mkdir(exist_ok=True) # make if required
  93. return path
  94. CONFIG_DIR = user_config_dir() # Ultralytics settings dir
  95. class Profile(contextlib.ContextDecorator):
  96. # Usage: @Profile() decorator or 'with Profile():' context manager
  97. def __enter__(self):
  98. self.start = time.time()
  99. def __exit__(self, type, value, traceback):
  100. print(f'Profile results: {time.time() - self.start:.5f}s')
  101. class Timeout(contextlib.ContextDecorator):
  102. # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
  103. def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
  104. self.seconds = int(seconds)
  105. self.timeout_message = timeout_msg
  106. self.suppress = bool(suppress_timeout_errors)
  107. def _timeout_handler(self, signum, frame):
  108. raise TimeoutError(self.timeout_message)
  109. def __enter__(self):
  110. if platform.system() != 'Windows': # not supported on Windows
  111. signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
  112. signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
  113. def __exit__(self, exc_type, exc_val, exc_tb):
  114. if platform.system() != 'Windows':
  115. signal.alarm(0) # Cancel SIGALRM if it's scheduled
  116. if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
  117. return True
  118. class WorkingDirectory(contextlib.ContextDecorator):
  119. # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
  120. def __init__(self, new_dir):
  121. self.dir = new_dir # new dir
  122. self.cwd = Path.cwd().resolve() # current dir
  123. def __enter__(self):
  124. os.chdir(self.dir)
  125. def __exit__(self, exc_type, exc_val, exc_tb):
  126. os.chdir(self.cwd)
  127. def try_except(func):
  128. # try-except function. Usage: @try_except decorator
  129. def handler(*args, **kwargs):
  130. try:
  131. func(*args, **kwargs)
  132. except Exception as e:
  133. print(e)
  134. return handler
  135. def methods(instance):
  136. # Get class/instance methods
  137. return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
  138. def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False):
  139. # Print function arguments (optional args dict)
  140. x = inspect.currentframe().f_back # previous frame
  141. file, _, fcn, _, _ = inspect.getframeinfo(x)
  142. if args is None: # get args automatically
  143. args, _, _, frm = inspect.getargvalues(x)
  144. args = {k: v for k, v in frm.items() if k in args}
  145. s = (f'{Path(file).stem}: ' if show_file else '') + (f'{fcn}: ' if show_fcn else '')
  146. LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
  147. def init_seeds(seed=0):
  148. # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
  149. # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
  150. import torch.backends.cudnn as cudnn
  151. random.seed(seed)
  152. np.random.seed(seed)
  153. torch.manual_seed(seed)
  154. cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
  155. def intersect_dicts(da, db, exclude=()):
  156. # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
  157. return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
  158. def get_latest_run(search_dir='.'):
  159. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  160. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  161. return max(last_list, key=os.path.getctime) if last_list else ''
  162. def is_docker():
  163. # Is environment a Docker container?
  164. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  165. def is_colab():
  166. # Is environment a Google Colab instance?
  167. try:
  168. import google.colab
  169. return True
  170. except ImportError:
  171. return False
  172. def is_pip():
  173. # Is file in a pip package?
  174. return 'site-packages' in Path(__file__).resolve().parts
  175. def is_ascii(s=''):
  176. # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
  177. s = str(s) # convert list, tuple, None, etc. to str
  178. return len(s.encode().decode('ascii', 'ignore')) == len(s)
  179. def is_chinese(s='人工智能'):
  180. # Is string composed of any Chinese characters?
  181. return True if re.search('[\u4e00-\u9fff]', str(s)) else False
  182. def emojis(str=''):
  183. # Return platform-dependent emoji-safe version of string
  184. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  185. def file_age(path=__file__):
  186. # Return days since last file update
  187. dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
  188. return dt.days # + dt.seconds / 86400 # fractional days
  189. def file_update_date(path=__file__):
  190. # Return human-readable file modification date, i.e. '2021-3-26'
  191. t = datetime.fromtimestamp(Path(path).stat().st_mtime)
  192. return f'{t.year}-{t.month}-{t.day}'
  193. def file_size(path):
  194. # Return file/dir size (MB)
  195. mb = 1 << 20 # bytes to MiB (1024 ** 2)
  196. path = Path(path)
  197. if path.is_file():
  198. return path.stat().st_size / mb
  199. elif path.is_dir():
  200. return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
  201. else:
  202. return 0.0
  203. def check_online():
  204. # Check internet connectivity
  205. import socket
  206. try:
  207. socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
  208. return True
  209. except OSError:
  210. return False
  211. def git_describe(path=ROOT): # path must be a directory
  212. # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
  213. try:
  214. assert (Path(path) / '.git').is_dir()
  215. return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
  216. except Exception:
  217. return ''
  218. @try_except
  219. @WorkingDirectory(ROOT)
  220. def check_git_status():
  221. # Recommend 'git pull' if code is out of date
  222. msg = ', for updates see https://github.com/ultralytics/yolov5'
  223. s = colorstr('github: ') # string
  224. assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
  225. assert not is_docker(), s + 'skipping check (Docker image)' + msg
  226. assert check_online(), s + 'skipping check (offline)' + msg
  227. cmd = 'git fetch && git config --get remote.origin.url'
  228. url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
  229. branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  230. n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  231. if n > 0:
  232. s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
  233. else:
  234. s += f'up to date with {url} ✅'
  235. LOGGER.info(emojis(s)) # emoji-safe
  236. def check_python(minimum='3.7.0'):
  237. # Check current python version vs. required python version
  238. check_version(platform.python_version(), minimum, name='Python ', hard=True)
  239. def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
  240. # Check version vs. required version
  241. current, minimum = (pkg.parse_version(x) for x in (current, minimum))
  242. result = (current == minimum) if pinned else (current >= minimum) # bool
  243. s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed' # string
  244. if hard:
  245. assert result, s # assert min requirements met
  246. if verbose and not result:
  247. LOGGER.warning(s)
  248. return result
  249. @try_except
  250. def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True):
  251. # Check installed dependencies meet requirements (pass *.txt file or list of packages)
  252. prefix = colorstr('red', 'bold', 'requirements:')
  253. check_python() # check python version
  254. if isinstance(requirements, (str, Path)): # requirements.txt file
  255. file = Path(requirements)
  256. assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
  257. with file.open() as f:
  258. requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
  259. else: # list or tuple of packages
  260. requirements = [x for x in requirements if x not in exclude]
  261. n = 0 # number of packages updates
  262. for r in requirements:
  263. try:
  264. pkg.require(r)
  265. except Exception: # DistributionNotFound or VersionConflict if requirements not met
  266. s = f"{prefix} {r} not found and is required by YOLOv5"
  267. if install:
  268. LOGGER.info(f"{s}, attempting auto-update...")
  269. try:
  270. assert check_online(), f"'pip install {r}' skipped (offline)"
  271. LOGGER.info(check_output(f"pip install '{r}'", shell=True).decode())
  272. n += 1
  273. except Exception as e:
  274. LOGGER.warning(f'{prefix} {e}')
  275. else:
  276. LOGGER.info(f'{s}. Please install and rerun your command.')
  277. if n: # if packages updated
  278. source = file.resolve() if 'file' in locals() else requirements
  279. s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
  280. f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
  281. LOGGER.info(emojis(s))
  282. def check_img_size(imgsz, s=32, floor=0):
  283. # Verify image size is a multiple of stride s in each dimension
  284. if isinstance(imgsz, int): # integer i.e. img_size=640
  285. new_size = max(make_divisible(imgsz, int(s)), floor)
  286. else: # list i.e. img_size=[640, 480]
  287. imgsz = list(imgsz) # convert to list if tuple
  288. new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
  289. if new_size != imgsz:
  290. LOGGER.warning(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
  291. return new_size
  292. def check_imshow():
  293. # Check if environment supports image displays
  294. try:
  295. assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
  296. assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
  297. cv2.imshow('test', np.zeros((1, 1, 3)))
  298. cv2.waitKey(1)
  299. cv2.destroyAllWindows()
  300. cv2.waitKey(1)
  301. return True
  302. except Exception as e:
  303. LOGGER.warning(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
  304. return False
  305. def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
  306. # Check file(s) for acceptable suffix
  307. if file and suffix:
  308. if isinstance(suffix, str):
  309. suffix = [suffix]
  310. for f in file if isinstance(file, (list, tuple)) else [file]:
  311. s = Path(f).suffix.lower() # file suffix
  312. if len(s):
  313. assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
  314. def check_yaml(file, suffix=('.yaml', '.yml')):
  315. # Search/download YAML file (if necessary) and return path, checking suffix
  316. return check_file(file, suffix)
  317. def check_file(file, suffix=''):
  318. # Search/download file (if necessary) and return path
  319. check_suffix(file, suffix) # optional
  320. file = str(file) # convert to str()
  321. if Path(file).is_file() or file == '': # exists
  322. return file
  323. elif file.startswith(('http:/', 'https:/')): # download
  324. url = str(Path(file)).replace(':/', '://') # Pathlib turns :// -> :/
  325. file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
  326. if Path(file).is_file():
  327. LOGGER.info(f'Found {url} locally at {file}') # file already exists
  328. else:
  329. LOGGER.info(f'Downloading {url} to {file}...')
  330. torch.hub.download_url_to_file(url, file)
  331. assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
  332. return file
  333. else: # search
  334. files = []
  335. for d in 'data', 'models', 'utils': # search directories
  336. files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
  337. assert len(files), f'File not found: {file}' # assert file was found
  338. assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
  339. return files[0] # return file
  340. def check_font(font=FONT):
  341. # Download font to CONFIG_DIR if necessary
  342. font = Path(font)
  343. if not font.exists() and not (CONFIG_DIR / font.name).exists():
  344. url = "https://ultralytics.com/assets/" + font.name
  345. LOGGER.info(f'Downloading {url} to {CONFIG_DIR / font.name}...')
  346. torch.hub.download_url_to_file(url, str(font), progress=False)
def check_dataset(data, autodownload=True):
    """Load a dataset description, resolve its paths, and download the dataset if it is missing.

    Args:
        data: a dataset description — a path to a YAML file, a path/URL ending in '.zip'
            (downloaded into DATASETS_DIR, after which the first '*.yaml' found under
            DATASETS_DIR/<zip stem> is used), or an already-loaded dict.
        autodownload: run the YAML's 'download' entry when validation paths are missing
            (forced off after a '.zip' input has been extracted).

    Returns:
        dict: the dataset dict with 'train'/'val'/'test' entries prefixed by the resolved
        root path and 'names' filled with 'class{i}' defaults when absent.

    Raises:
        AssertionError: if the dict has no 'nc' key.
        Exception: if validation paths are missing and no download can be attempted.

    NOTE(review): the 'download' value is executed — via os.system() when it starts with
    'bash ', otherwise via exec() as Python code. Only use dataset YAMLs from trusted sources.
    """
    # Download and/or unzip dataset if not found locally
    # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip

    # Download (optional)
    extract_dir = ''
    if isinstance(data, (str, Path)) and str(data).endswith('.zip'):  # i.e. gs://bucket/dir/coco128.zip
        download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False, threads=1)
        # next() without a default: raises StopIteration if the archive contains no YAML
        data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
        extract_dir, autodownload = data.parent, False

    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        with open(data, errors='ignore') as f:
            data = yaml.safe_load(f)  # dictionary

    # Resolve paths
    path = Path(extract_dir or data.get('path') or '')  # optional 'path' default to '.'
    if not path.is_absolute():
        path = (ROOT / path).resolve()  # relative roots are resolved against the repository ROOT
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]

    # Parse yaml
    assert 'nc' in data, "Dataset 'nc' key missing."
    if 'names' not in data:
        data['names'] = [f'class{i}' for i in range(data['nc'])]  # assign class names if missing
    # train and test are unpacked here but not used below; s is the optional 'download' entry
    train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            LOGGER.info(emojis('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()]))
            if s and autodownload:  # download script
                t = time.time()
                root = path.parent if 'path' in data else '..'  # unzip directory i.e. '../'
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    LOGGER.info(f'Downloading {s} to {f}...')
                    torch.hub.download_url_to_file(s, f)
                    Path(root).mkdir(parents=True, exist_ok=True)  # create root
                    # NOTE(review): ZipFile is not used as a context manager; closing relies on garbage collection
                    ZipFile(f).extractall(path=root)  # unzip
                    Path(f).unlink()  # remove zip
                    r = None  # success
                elif s.startswith('bash '):  # bash script
                    LOGGER.info(f'Running {s} ...')
                    r = os.system(s)  # shell exit status; 0 means success
                else:  # python script
                    r = exec(s, {'yaml': data})  # return None
                dt = f'({round(time.time() - t, 1)}s)'
                # 0 (shell success) and None (zip / exec path) both count as success
                s = f"success ✅ {dt}, saved to {colorstr('bold', root)}" if r in (0, None) else f"failure {dt} ❌"
                LOGGER.info(emojis(f"Dataset download {s}"))
            else:
                raise Exception(emojis('Dataset not found ❌'))

    return data  # dictionary
  398. def url2file(url):
  399. # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
  400. url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
  401. file = Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
  402. return file
  403. def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
  404. # Multi-threaded file download and unzip function, used in data.yaml for autodownload
  405. def download_one(url, dir):
  406. # Download 1 file
  407. success = True
  408. f = dir / Path(url).name # filename
  409. if Path(url).is_file(): # exists in current path
  410. Path(url).rename(f) # move to dir
  411. elif not f.exists():
  412. LOGGER.info(f'Downloading {url} to {f}...')
  413. for i in range(retry + 1):
  414. if curl:
  415. s = 'sS' if threads > 1 else '' # silent
  416. r = os.system(f"curl -{s}L '{url}' -o '{f}' --retry 9 -C -") # curl download
  417. success = r == 0
  418. else:
  419. torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
  420. success = f.is_file()
  421. if success:
  422. break
  423. elif i < retry:
  424. LOGGER.warning(f'Download failure, retrying {i + 1}/{retry} {url}...')
  425. else:
  426. LOGGER.warning(f'Failed to download {url}...')
  427. if unzip and success and f.suffix in ('.zip', '.gz'):
  428. LOGGER.info(f'Unzipping {f}...')
  429. if f.suffix == '.zip':
  430. ZipFile(f).extractall(path=dir) # unzip
  431. elif f.suffix == '.gz':
  432. os.system(f'tar xfz {f} --directory {f.parent}') # unzip
  433. if delete:
  434. f.unlink() # remove zip
  435. dir = Path(dir)
  436. dir.mkdir(parents=True, exist_ok=True) # make directory
  437. if threads > 1:
  438. pool = ThreadPool(threads)
  439. pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
  440. pool.close()
  441. pool.join()
  442. else:
  443. for u in [url] if isinstance(url, (str, Path)) else url:
  444. download_one(u, dir)
  445. def make_divisible(x, divisor):
  446. # Returns nearest x divisible by divisor
  447. if isinstance(divisor, torch.Tensor):
  448. divisor = int(divisor.max()) # to int
  449. return math.ceil(x / divisor) * divisor
  450. def clean_str(s):
  451. # Cleans a string by replacing special characters with underscore _
  452. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  453. def one_cycle(y1=0.0, y2=1.0, steps=100):
  454. # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
  455. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  456. def colorstr(*input):
  457. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  458. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  459. colors = {
  460. 'black': '\033[30m', # basic colors
  461. 'red': '\033[31m',
  462. 'green': '\033[32m',
  463. 'yellow': '\033[33m',
  464. 'blue': '\033[34m',
  465. 'magenta': '\033[35m',
  466. 'cyan': '\033[36m',
  467. 'white': '\033[37m',
  468. 'bright_black': '\033[90m', # bright colors
  469. 'bright_red': '\033[91m',
  470. 'bright_green': '\033[92m',
  471. 'bright_yellow': '\033[93m',
  472. 'bright_blue': '\033[94m',
  473. 'bright_magenta': '\033[95m',
  474. 'bright_cyan': '\033[96m',
  475. 'bright_white': '\033[97m',
  476. 'end': '\033[0m', # misc
  477. 'bold': '\033[1m',
  478. 'underline': '\033[4m'}
  479. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  480. def labels_to_class_weights(labels, nc=80):
  481. # Get class weights (inverse frequency) from training labels
  482. if labels[0] is None: # no labels loaded
  483. return torch.Tensor()
  484. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  485. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  486. weights = np.bincount(classes, minlength=nc) # occurrences per class
  487. # Prepend gridpoint count (for uCE training)
  488. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  489. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  490. weights[weights == 0] = 1 # replace empty bins with 1
  491. weights = 1 / weights # number of targets per class
  492. weights /= weights.sum() # normalize
  493. return torch.from_numpy(weights)
  494. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  495. # Produces image weights based on class_weights and image contents
  496. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  497. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  498. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  499. return image_weights
  500. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  501. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  502. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  503. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  504. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  505. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  506. x = [
  507. 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  508. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  509. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  510. return x
  511. def xyxy2xywh(x):
  512. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  513. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  514. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  515. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  516. y[:, 2] = x[:, 2] - x[:, 0] # width
  517. y[:, 3] = x[:, 3] - x[:, 1] # height
  518. return y
  519. def xywh2xyxy(x):
  520. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  521. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  522. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  523. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  524. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  525. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  526. return y
  527. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  528. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  529. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  530. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  531. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  532. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  533. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  534. return y
  535. def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
  536. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
  537. if clip:
  538. clip_coords(x, (h - eps, w - eps)) # warning: inplace clip
  539. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  540. y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
  541. y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
  542. y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
  543. y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
  544. return y
  545. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  546. # Convert normalized segments into pixel segments, shape (n,2)
  547. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  548. y[:, 0] = w * x[:, 0] + padw # top left x
  549. y[:, 1] = h * x[:, 1] + padh # top left y
  550. return y
  551. def segment2box(segment, width=640, height=640):
  552. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  553. x, y = segment.T # segment xy
  554. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  555. x, y, = x[inside], y[inside]
  556. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  557. def segments2boxes(segments):
  558. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  559. boxes = []
  560. for s in segments:
  561. x, y = s.T # segment xy
  562. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  563. return xyxy2xywh(np.array(boxes)) # cls, xywh
  564. def resample_segments(segments, n=1000):
  565. # Up-sample an (n,2) segment
  566. for i, s in enumerate(segments):
  567. x = np.linspace(0, len(s) - 1, n)
  568. xp = np.arange(len(s))
  569. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  570. return segments
  571. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  572. # Rescale coords (xyxy) from img1_shape to img0_shape
  573. if ratio_pad is None: # calculate from img0_shape
  574. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  575. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  576. else:
  577. gain = ratio_pad[0][0]
  578. pad = ratio_pad[1]
  579. coords[:, [0, 2]] -= pad[0] # x padding
  580. coords[:, [1, 3]] -= pad[1] # y padding
  581. coords[:, :4] /= gain
  582. clip_coords(coords, img0_shape)
  583. return coords
  584. def clip_coords(boxes, shape):
  585. # Clip bounding xyxy bounding boxes to image shape (height, width)
  586. if isinstance(boxes, torch.Tensor): # faster individually
  587. boxes[:, 0].clamp_(0, shape[1]) # x1
  588. boxes[:, 1].clamp_(0, shape[0]) # y1
  589. boxes[:, 2].clamp_(0, shape[1]) # x2
  590. boxes[:, 3].clamp_(0, shape[0]) # y2
  591. else: # np.array (faster grouped)
  592. boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
  593. boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
  594. def non_max_suppression(prediction,
  595. conf_thres=0.25,
  596. iou_thres=0.45,
  597. classes=None,
  598. agnostic=False,
  599. multi_label=False,
  600. labels=(),
  601. max_det=300):
  602. """Non-Maximum Suppression (NMS) on inference results to reject overlapping bounding boxes
  603. Returns:
  604. list of detections, on (n,6) tensor per image [xyxy, conf, cls]
  605. """
  606. bs = prediction.shape[0] # batch size
  607. nc = prediction.shape[2] - 5 # number of classes
  608. xc = prediction[..., 4] > conf_thres # candidates
  609. # Checks
  610. assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
  611. assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
  612. # Settings
  613. # min_wh = 2 # (pixels) minimum box width and height
  614. max_wh = 7680 # (pixels) maximum box width and height
  615. max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
  616. time_limit = 0.1 + 0.03 * bs # seconds to quit after
  617. redundant = True # require redundant detections
  618. multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
  619. merge = False # use merge-NMS
  620. t = time.time()
  621. output = [torch.zeros((0, 6), device=prediction.device)] * bs
  622. for xi, x in enumerate(prediction): # image index, image inference
  623. # Apply constraints
  624. # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
  625. x = x[xc[xi]] # confidence
  626. # Cat apriori labels if autolabelling
  627. if labels and len(labels[xi]):
  628. lb = labels[xi]
  629. v = torch.zeros((len(lb), nc + 5), device=x.device)
  630. v[:, :4] = lb[:, 1:5] # box
  631. v[:, 4] = 1.0 # conf
  632. v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls
  633. x = torch.cat((x, v), 0)
  634. # If none remain process next image
  635. if not x.shape[0]:
  636. continue
  637. # Compute conf
  638. x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
  639. # Box (center x, center y, width, height) to (x1, y1, x2, y2)
  640. box = xywh2xyxy(x[:, :4])
  641. # Detections matrix nx6 (xyxy, conf, cls)
  642. if multi_label:
  643. i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
  644. x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
  645. else: # best class only
  646. conf, j = x[:, 5:].max(1, keepdim=True)
  647. x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
  648. # Filter by class
  649. if classes is not None:
  650. x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
  651. # Apply finite constraint
  652. # if not torch.isfinite(x).all():
  653. # x = x[torch.isfinite(x).all(1)]
  654. # Check shape
  655. n = x.shape[0] # number of boxes
  656. if not n: # no boxes
  657. continue
  658. elif n > max_nms: # excess boxes
  659. x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
  660. # Batched NMS
  661. c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
  662. boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
  663. i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
  664. if i.shape[0] > max_det: # limit detections
  665. i = i[:max_det]
  666. if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
  667. # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
  668. iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
  669. weights = iou * scores[None] # box weights
  670. x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
  671. if redundant:
  672. i = i[iou.sum(1) > 1] # require redundancy
  673. output[xi] = x[i]
  674. if (time.time() - t) > time_limit:
  675. LOGGER.warning(f'WARNING: NMS time limit {time_limit:.3f}s exceeded')
  676. break # time limit exceeded
  677. return output
  678. def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
  679. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  680. x = torch.load(f, map_location=torch.device('cpu'))
  681. if x.get('ema'):
  682. x['model'] = x['ema'] # replace model with ema
  683. for k in 'optimizer', 'best_fitness', 'wandb_id', 'ema', 'updates': # keys
  684. x[k] = None
  685. x['epoch'] = -1
  686. x['model'].half() # to FP16
  687. for p in x['model'].parameters():
  688. p.requires_grad = False
  689. torch.save(x, s or f)
  690. mb = os.path.getsize(s or f) / 1E6 # filesize
  691. LOGGER.info(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
  692. def print_mutation(results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
  693. evolve_csv = save_dir / 'evolve.csv'
  694. evolve_yaml = save_dir / 'hyp_evolve.yaml'
  695. keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', 'val/box_loss',
  696. 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys()) # [results + hyps]
  697. keys = tuple(x.strip() for x in keys)
  698. vals = results + tuple(hyp.values())
  699. n = len(keys)
  700. # Download (optional)
  701. if bucket:
  702. url = f'gs://{bucket}/evolve.csv'
  703. if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
  704. os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local
  705. # Log to evolve.csv
  706. s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header
  707. with open(evolve_csv, 'a') as f:
  708. f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')
  709. # Save yaml
  710. with open(evolve_yaml, 'w') as f:
  711. data = pd.read_csv(evolve_csv)
  712. data = data.rename(columns=lambda x: x.strip()) # strip keys
  713. i = np.argmax(fitness(data.values[:, :4])) #
  714. generations = len(data)
  715. f.write('# YOLOv5 Hyperparameter Evolution Results\n' + f'# Best generation: {i}\n' +
  716. f'# Last generation: {generations - 1}\n' + '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) +
  717. '\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
  718. yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)
  719. # Print to screen
  720. LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix +
  721. ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' + prefix + ', '.join(f'{x:20.5g}'
  722. for x in vals) + '\n\n')
  723. if bucket:
  724. os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload
  725. def apply_classifier(x, model, img, im0):
  726. # Apply a second stage classifier to YOLO outputs
  727. # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
  728. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  729. for i, d in enumerate(x): # per image
  730. if d is not None and len(d):
  731. d = d.clone()
  732. # Reshape and pad cutouts
  733. b = xyxy2xywh(d[:, :4]) # boxes
  734. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  735. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  736. d[:, :4] = xywh2xyxy(b).long()
  737. # Rescale boxes from img_size to im0 size
  738. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  739. # Classes
  740. pred_cls1 = d[:, 5].long()
  741. ims = []
  742. for j, a in enumerate(d): # per item
  743. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  744. im = cv2.resize(cutout, (224, 224)) # BGR
  745. # cv2.imwrite('example%i.jpg' % j, cutout)
  746. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  747. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  748. im /= 255 # 0 - 255 to 0.0 - 1.0
  749. ims.append(im)
  750. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  751. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  752. return x
  753. def increment_path(path, exist_ok=False, sep='', mkdir=False):
  754. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  755. path = Path(path) # os-agnostic
  756. if path.exists() and not exist_ok:
  757. path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
  758. dirs = glob.glob(f"{path}{sep}*") # similar paths
  759. matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
  760. i = [int(m.groups()[0]) for m in matches if m] # indices
  761. n = max(i) + 1 if i else 2 # increment number
  762. path = Path(f"{path}{sep}{n}{suffix}") # increment path
  763. if mkdir:
  764. path.mkdir(parents=True, exist_ok=True) # make directory
  765. return path
  766. # OpenCV Chinese-friendly functions ------------------------------------------------------------------------------------
  767. imshow_ = cv2.imshow # copy to avoid recursion errors
  768. def imread(path, flags=cv2.IMREAD_COLOR):
  769. return cv2.imdecode(np.fromfile(path, np.uint8), flags)
  770. def imwrite(path, im):
  771. try:
  772. cv2.imencode(Path(path).suffix, im)[1].tofile(path)
  773. return True
  774. except Exception:
  775. return False
  776. def imshow(path, im):
  777. imshow_(path.encode('unicode_escape').decode(), im)
  778. cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow # redefine
  779. # Variables ------------------------------------------------------------------------------------------------------------
  780. NCOLS = 0 if is_docker() else shutil.get_terminal_size().columns # terminal window size for tqdm