You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

966 lines
39KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. General utils
  4. """
  5. import contextlib
  6. import glob
  7. import inspect
  8. import logging
  9. import math
  10. import os
  11. import platform
  12. import random
  13. import re
  14. import shutil
  15. import signal
  16. import time
  17. import urllib
  18. from datetime import datetime
  19. from itertools import repeat
  20. from multiprocessing.pool import ThreadPool
  21. from pathlib import Path
  22. from subprocess import check_output
  23. from typing import Optional
  24. from zipfile import ZipFile
  25. import cv2
  26. import numpy as np
  27. import pandas as pd
  28. import pkg_resources as pkg
  29. import torch
  30. import torchvision
  31. import yaml
  32. from utils.downloads import gsutil_getsize
  33. from utils.metrics import box_iou, fitness
  34. # Settings
  35. FILE = Path(__file__).resolve()
  36. ROOT = FILE.parents[1] # YOLOv5 root directory
  37. DATASETS_DIR = ROOT.parent / 'datasets' # YOLOv5 datasets directory
  38. NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
  39. VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global verbose mode
  40. FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
  41. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  42. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  43. pd.options.display.max_columns = 10
  44. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  45. os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
  46. os.environ['OMP_NUM_THREADS'] = str(NUM_THREADS) # OpenMP max threads (PyTorch and SciPy)
  47. def is_kaggle():
  48. # Is environment a Kaggle Notebook?
  49. try:
  50. assert os.environ.get('PWD') == '/kaggle/working'
  51. assert os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
  52. return True
  53. except AssertionError:
  54. return False
  55. def is_writeable(dir, test=False):
  56. # Return True if directory has write permissions, test opening a file with write permissions if test=True
  57. if test: # method 1
  58. file = Path(dir) / 'tmp.txt'
  59. try:
  60. with open(file, 'w'): # open file with write permissions
  61. pass
  62. file.unlink() # remove file
  63. return True
  64. except OSError:
  65. return False
  66. else: # method 2
  67. return os.access(dir, os.R_OK) # possible issues on Windows
  68. def set_logging(name=None, verbose=VERBOSE):
  69. # Sets level and returns logger
  70. if is_kaggle():
  71. for h in logging.root.handlers:
  72. logging.root.removeHandler(h) # remove all handlers associated with the root logger object
  73. rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
  74. level = logging.INFO if (verbose and rank in (-1, 0)) else logging.WARNING
  75. log = logging.getLogger(name)
  76. log.setLevel(level)
  77. handler = logging.StreamHandler()
  78. handler.setFormatter(logging.Formatter("%(message)s"))
  79. handler.setLevel(level)
  80. log.addHandler(handler)
# Configure logging at import time. set_logging() is called with its default name=None,
# so the handler and level are applied to the logger returned by logging.getLogger(None).
# NOTE(review): the "yolov5" logger below gets no handler of its own here; presumably it
# relies on propagation to that logger — confirm.
set_logging()  # run before defining LOGGER
LOGGER = logging.getLogger("yolov5")  # define globally (used in train.py, val.py, detect.py, etc.)
  83. def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
  84. # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
  85. env = os.getenv(env_var)
  86. if env:
  87. path = Path(env) # use environment variable
  88. else:
  89. cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs
  90. path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir
  91. path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable
  92. path.mkdir(exist_ok=True) # make if required
  93. return path
# Resolved once at import time; user_config_dir() also creates the directory as a side effect
CONFIG_DIR = user_config_dir()  # Ultralytics settings dir
  95. class Profile(contextlib.ContextDecorator):
  96. # Usage: @Profile() decorator or 'with Profile():' context manager
  97. def __enter__(self):
  98. self.start = time.time()
  99. def __exit__(self, type, value, traceback):
  100. print(f'Profile results: {time.time() - self.start:.5f}s')
  101. class Timeout(contextlib.ContextDecorator):
  102. # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
  103. def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
  104. self.seconds = int(seconds)
  105. self.timeout_message = timeout_msg
  106. self.suppress = bool(suppress_timeout_errors)
  107. def _timeout_handler(self, signum, frame):
  108. raise TimeoutError(self.timeout_message)
  109. def __enter__(self):
  110. if platform.system() != 'Windows': # not supported on Windows
  111. signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
  112. signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
  113. def __exit__(self, exc_type, exc_val, exc_tb):
  114. if platform.system() != 'Windows':
  115. signal.alarm(0) # Cancel SIGALRM if it's scheduled
  116. if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
  117. return True
  118. class WorkingDirectory(contextlib.ContextDecorator):
  119. # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
  120. def __init__(self, new_dir):
  121. self.dir = new_dir # new dir
  122. self.cwd = Path.cwd().resolve() # current dir
  123. def __enter__(self):
  124. os.chdir(self.dir)
  125. def __exit__(self, exc_type, exc_val, exc_tb):
  126. os.chdir(self.cwd)
  127. def try_except(func):
  128. # try-except function. Usage: @try_except decorator
  129. def handler(*args, **kwargs):
  130. try:
  131. func(*args, **kwargs)
  132. except Exception as e:
  133. print(e)
  134. return handler
  135. def methods(instance):
  136. # Get class/instance methods
  137. return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
  138. def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False):
  139. # Print function arguments (optional args dict)
  140. x = inspect.currentframe().f_back # previous frame
  141. file, _, fcn, _, _ = inspect.getframeinfo(x)
  142. if args is None: # get args automatically
  143. args, _, _, frm = inspect.getargvalues(x)
  144. args = {k: v for k, v in frm.items() if k in args}
  145. s = (f'{Path(file).stem}: ' if show_file else '') + (f'{fcn}: ' if show_fcn else '')
  146. LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
  147. def init_seeds(seed=0):
  148. # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
  149. # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
  150. import torch.backends.cudnn as cudnn
  151. random.seed(seed)
  152. np.random.seed(seed)
  153. torch.manual_seed(seed)
  154. cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
  155. def intersect_dicts(da, db, exclude=()):
  156. # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
  157. return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
  158. def get_latest_run(search_dir='.'):
  159. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  160. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  161. return max(last_list, key=os.path.getctime) if last_list else ''
  162. def is_docker():
  163. # Is environment a Docker container?
  164. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  165. def is_colab():
  166. # Is environment a Google Colab instance?
  167. try:
  168. import google.colab
  169. return True
  170. except ImportError:
  171. return False
  172. def is_pip():
  173. # Is file in a pip package?
  174. return 'site-packages' in Path(__file__).resolve().parts
  175. def is_ascii(s=''):
  176. # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
  177. s = str(s) # convert list, tuple, None, etc. to str
  178. return len(s.encode().decode('ascii', 'ignore')) == len(s)
  179. def is_chinese(s='人工智能'):
  180. # Is string composed of any Chinese characters?
  181. return True if re.search('[\u4e00-\u9fff]', str(s)) else False
  182. def emojis(str=''):
  183. # Return platform-dependent emoji-safe version of string
  184. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  185. def file_age(path=__file__):
  186. # Return days since last file update
  187. dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
  188. return dt.days # + dt.seconds / 86400 # fractional days
  189. def file_update_date(path=__file__):
  190. # Return human-readable file modification date, i.e. '2021-3-26'
  191. t = datetime.fromtimestamp(Path(path).stat().st_mtime)
  192. return f'{t.year}-{t.month}-{t.day}'
  193. def file_size(path):
  194. # Return file/dir size (MB)
  195. mb = 1 << 20 # bytes to MiB (1024 ** 2)
  196. path = Path(path)
  197. if path.is_file():
  198. return path.stat().st_size / mb
  199. elif path.is_dir():
  200. return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
  201. else:
  202. return 0.0
  203. def check_online():
  204. # Check internet connectivity
  205. import socket
  206. try:
  207. socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
  208. return True
  209. except OSError:
  210. return False
  211. def git_describe(path=ROOT): # path must be a directory
  212. # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
  213. try:
  214. return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
  215. except Exception:
  216. return ''
  217. @try_except
  218. @WorkingDirectory(ROOT)
  219. def check_git_status():
  220. # Recommend 'git pull' if code is out of date
  221. msg = ', for updates see https://github.com/ultralytics/yolov5'
  222. s = colorstr('github: ') # string
  223. assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
  224. assert not is_docker(), s + 'skipping check (Docker image)' + msg
  225. assert check_online(), s + 'skipping check (offline)' + msg
  226. cmd = 'git fetch && git config --get remote.origin.url'
  227. url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
  228. branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  229. n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  230. if n > 0:
  231. s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
  232. else:
  233. s += f'up to date with {url} ✅'
  234. LOGGER.info(emojis(s)) # emoji-safe
  235. def check_python(minimum='3.7.0'):
  236. # Check current python version vs. required python version
  237. check_version(platform.python_version(), minimum, name='Python ', hard=True)
  238. def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
  239. # Check version vs. required version
  240. current, minimum = (pkg.parse_version(x) for x in (current, minimum))
  241. result = (current == minimum) if pinned else (current >= minimum) # bool
  242. s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed' # string
  243. if hard:
  244. assert result, s # assert min requirements met
  245. if verbose and not result:
  246. LOGGER.warning(s)
  247. return result
  248. @try_except
  249. def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True):
  250. # Check installed dependencies meet requirements (pass *.txt file or list of packages)
  251. prefix = colorstr('red', 'bold', 'requirements:')
  252. check_python() # check python version
  253. if isinstance(requirements, (str, Path)): # requirements.txt file
  254. file = Path(requirements)
  255. assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
  256. with file.open() as f:
  257. requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
  258. else: # list or tuple of packages
  259. requirements = [x for x in requirements if x not in exclude]
  260. n = 0 # number of packages updates
  261. for r in requirements:
  262. try:
  263. pkg.require(r)
  264. except Exception: # DistributionNotFound or VersionConflict if requirements not met
  265. s = f"{prefix} {r} not found and is required by YOLOv5"
  266. if install:
  267. LOGGER.info(f"{s}, attempting auto-update...")
  268. try:
  269. assert check_online(), f"'pip install {r}' skipped (offline)"
  270. LOGGER.info(check_output(f"pip install '{r}'", shell=True).decode())
  271. n += 1
  272. except Exception as e:
  273. LOGGER.warning(f'{prefix} {e}')
  274. else:
  275. LOGGER.info(f'{s}. Please install and rerun your command.')
  276. if n: # if packages updated
  277. source = file.resolve() if 'file' in locals() else requirements
  278. s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
  279. f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
  280. LOGGER.info(emojis(s))
  281. def check_img_size(imgsz, s=32, floor=0):
  282. # Verify image size is a multiple of stride s in each dimension
  283. if isinstance(imgsz, int): # integer i.e. img_size=640
  284. new_size = max(make_divisible(imgsz, int(s)), floor)
  285. else: # list i.e. img_size=[640, 480]
  286. imgsz = list(imgsz) # convert to list if tuple
  287. new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
  288. if new_size != imgsz:
  289. LOGGER.warning(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
  290. return new_size
  291. def check_imshow():
  292. # Check if environment supports image displays
  293. try:
  294. assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
  295. assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
  296. cv2.imshow('test', np.zeros((1, 1, 3)))
  297. cv2.waitKey(1)
  298. cv2.destroyAllWindows()
  299. cv2.waitKey(1)
  300. return True
  301. except Exception as e:
  302. LOGGER.warning(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
  303. return False
  304. def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
  305. # Check file(s) for acceptable suffix
  306. if file and suffix:
  307. if isinstance(suffix, str):
  308. suffix = [suffix]
  309. for f in file if isinstance(file, (list, tuple)) else [file]:
  310. s = Path(f).suffix.lower() # file suffix
  311. if len(s):
  312. assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
  313. def check_yaml(file, suffix=('.yaml', '.yml')):
  314. # Search/download YAML file (if necessary) and return path, checking suffix
  315. return check_file(file, suffix)
  316. def check_file(file, suffix=''):
  317. # Search/download file (if necessary) and return path
  318. check_suffix(file, suffix) # optional
  319. file = str(file) # convert to str()
  320. if Path(file).is_file() or file == '': # exists
  321. return file
  322. elif file.startswith(('http:/', 'https:/')): # download
  323. url = str(Path(file)).replace(':/', '://') # Pathlib turns :// -> :/
  324. file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
  325. if Path(file).is_file():
  326. LOGGER.info(f'Found {url} locally at {file}') # file already exists
  327. else:
  328. LOGGER.info(f'Downloading {url} to {file}...')
  329. torch.hub.download_url_to_file(url, file)
  330. assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
  331. return file
  332. else: # search
  333. files = []
  334. for d in 'data', 'models', 'utils': # search directories
  335. files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
  336. assert len(files), f'File not found: {file}' # assert file was found
  337. assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
  338. return files[0] # return file
  339. def check_font(font=FONT):
  340. # Download font to CONFIG_DIR if necessary
  341. font = Path(font)
  342. if not font.exists() and not (CONFIG_DIR / font.name).exists():
  343. url = "https://ultralytics.com/assets/" + font.name
  344. LOGGER.info(f'Downloading {url} to {CONFIG_DIR / font.name}...')
  345. torch.hub.download_url_to_file(url, str(font), progress=False)
def check_dataset(data, autodownload=True):
    # Download and/or unzip dataset if not found locally
    # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip
    #
    # Arguments:
    #   data: path to a *.zip archive, path to a dataset yaml file, or an already-loaded dict
    #   autodownload: when True and the 'val' paths are missing, run the dataset's 'download' entry
    # Returns the dataset dictionary with 'train'/'val'/'test' rewritten to absolute path strings
    # and 'names' filled in if it was missing.
    # Raises AssertionError if 'nc' is missing, and Exception if the 'val' paths do not exist
    # and no download is attempted.

    # Download (optional)
    extract_dir = ''
    if isinstance(data, (str, Path)) and str(data).endswith('.zip'):  # i.e. gs://bucket/dir/coco128.zip
        download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False, threads=1)
        # First *.yaml found under the extracted folder; next() raises StopIteration if there is none
        data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
        extract_dir, autodownload = data.parent, False  # archive was just fetched, so no second download

    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        with open(data, errors='ignore') as f:
            data = yaml.safe_load(f)  # dictionary

    # Resolve paths
    path = Path(extract_dir or data.get('path') or '')  # optional 'path' default to '.'
    if not path.is_absolute():
        path = (ROOT / path).resolve()  # relative paths are taken relative to ROOT
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]

    # Parse yaml
    assert 'nc' in data, "Dataset 'nc' key missing."
    if 'names' not in data:
        data['names'] = [f'class{i}' for i in range(data['nc'])]  # assign class names if missing
    train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
    if val:  # only the 'val' paths are checked for existence
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            LOGGER.info(emojis('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()]))
            if s and autodownload:  # download script
                t = time.time()
                root = path.parent if 'path' in data else '..'  # unzip directory i.e. '../'
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    LOGGER.info(f'Downloading {s} to {f}...')
                    torch.hub.download_url_to_file(s, f)
                    Path(root).mkdir(parents=True, exist_ok=True)  # create root
                    ZipFile(f).extractall(path=root)  # unzip
                    Path(f).unlink()  # remove zip
                    r = None  # success
                elif s.startswith('bash '):  # bash script
                    LOGGER.info(f'Running {s} ...')
                    r = os.system(s)  # r is the exit status reported by os.system
                else:  # python script
                    # NOTE(review): exec() runs arbitrary Python taken from the dataset yaml (and the
                    # bash branch above runs arbitrary shell) — only use dataset files from trusted sources
                    r = exec(s, {'yaml': data})  # return None
                dt = f'({round(time.time() - t, 1)}s)'
                s = f"success ✅ {dt}, saved to {colorstr('bold', root)}" if r in (0, None) else f"failure {dt} ❌"
                LOGGER.info(emojis(f"Dataset download {s}"))
            else:
                raise Exception(emojis('Dataset not found ❌'))

    return data  # dictionary
  397. def url2file(url):
  398. # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
  399. url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
  400. file = Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
  401. return file
  402. def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
  403. # Multi-threaded file download and unzip function, used in data.yaml for autodownload
  404. def download_one(url, dir):
  405. # Download 1 file
  406. success = True
  407. f = dir / Path(url).name # filename
  408. if Path(url).is_file(): # exists in current path
  409. Path(url).rename(f) # move to dir
  410. elif not f.exists():
  411. LOGGER.info(f'Downloading {url} to {f}...')
  412. for i in range(retry + 1):
  413. if curl:
  414. s = 'sS' if threads > 1 else '' # silent
  415. r = os.system(f"curl -{s}L '{url}' -o '{f}' --retry 9 -C -") # curl download
  416. success = r == 0
  417. else:
  418. torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
  419. success = f.is_file()
  420. if success:
  421. break
  422. elif i < retry:
  423. LOGGER.warning(f'Download failure, retrying {i + 1}/{retry} {url}...')
  424. else:
  425. LOGGER.warning(f'Failed to download {url}...')
  426. if unzip and success and f.suffix in ('.zip', '.gz'):
  427. LOGGER.info(f'Unzipping {f}...')
  428. if f.suffix == '.zip':
  429. ZipFile(f).extractall(path=dir) # unzip
  430. elif f.suffix == '.gz':
  431. os.system(f'tar xfz {f} --directory {f.parent}') # unzip
  432. if delete:
  433. f.unlink() # remove zip
  434. dir = Path(dir)
  435. dir.mkdir(parents=True, exist_ok=True) # make directory
  436. if threads > 1:
  437. pool = ThreadPool(threads)
  438. pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
  439. pool.close()
  440. pool.join()
  441. else:
  442. for u in [url] if isinstance(url, (str, Path)) else url:
  443. download_one(u, dir)
  444. def make_divisible(x, divisor):
  445. # Returns nearest x divisible by divisor
  446. if isinstance(divisor, torch.Tensor):
  447. divisor = int(divisor.max()) # to int
  448. return math.ceil(x / divisor) * divisor
  449. def clean_str(s):
  450. # Cleans a string by replacing special characters with underscore _
  451. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  452. def one_cycle(y1=0.0, y2=1.0, steps=100):
  453. # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
  454. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  455. def colorstr(*input):
  456. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  457. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  458. colors = {
  459. 'black': '\033[30m', # basic colors
  460. 'red': '\033[31m',
  461. 'green': '\033[32m',
  462. 'yellow': '\033[33m',
  463. 'blue': '\033[34m',
  464. 'magenta': '\033[35m',
  465. 'cyan': '\033[36m',
  466. 'white': '\033[37m',
  467. 'bright_black': '\033[90m', # bright colors
  468. 'bright_red': '\033[91m',
  469. 'bright_green': '\033[92m',
  470. 'bright_yellow': '\033[93m',
  471. 'bright_blue': '\033[94m',
  472. 'bright_magenta': '\033[95m',
  473. 'bright_cyan': '\033[96m',
  474. 'bright_white': '\033[97m',
  475. 'end': '\033[0m', # misc
  476. 'bold': '\033[1m',
  477. 'underline': '\033[4m'}
  478. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  479. def labels_to_class_weights(labels, nc=80):
  480. # Get class weights (inverse frequency) from training labels
  481. if labels[0] is None: # no labels loaded
  482. return torch.Tensor()
  483. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  484. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  485. weights = np.bincount(classes, minlength=nc) # occurrences per class
  486. # Prepend gridpoint count (for uCE training)
  487. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  488. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  489. weights[weights == 0] = 1 # replace empty bins with 1
  490. weights = 1 / weights # number of targets per class
  491. weights /= weights.sum() # normalize
  492. return torch.from_numpy(weights)
  493. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  494. # Produces image weights based on class_weights and image contents
  495. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  496. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  497. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  498. return image_weights
  499. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  500. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  501. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  502. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  503. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  504. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  505. x = [
  506. 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  507. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  508. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  509. return x
  510. def xyxy2xywh(x):
  511. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  512. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  513. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  514. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  515. y[:, 2] = x[:, 2] - x[:, 0] # width
  516. y[:, 3] = x[:, 3] - x[:, 1] # height
  517. return y
  518. def xywh2xyxy(x):
  519. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  520. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  521. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  522. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  523. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  524. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  525. return y
  526. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  527. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  528. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  529. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  530. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  531. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  532. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  533. return y
  534. def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
  535. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
  536. if clip:
  537. clip_coords(x, (h - eps, w - eps)) # warning: inplace clip
  538. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  539. y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
  540. y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
  541. y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
  542. y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
  543. return y
  544. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  545. # Convert normalized segments into pixel segments, shape (n,2)
  546. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  547. y[:, 0] = w * x[:, 0] + padw # top left x
  548. y[:, 1] = h * x[:, 1] + padh # top left y
  549. return y
  550. def segment2box(segment, width=640, height=640):
  551. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  552. x, y = segment.T # segment xy
  553. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  554. x, y, = x[inside], y[inside]
  555. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  556. def segments2boxes(segments):
  557. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  558. boxes = []
  559. for s in segments:
  560. x, y = s.T # segment xy
  561. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  562. return xyxy2xywh(np.array(boxes)) # cls, xywh
  563. def resample_segments(segments, n=1000):
  564. # Up-sample an (n,2) segment
  565. for i, s in enumerate(segments):
  566. x = np.linspace(0, len(s) - 1, n)
  567. xp = np.arange(len(s))
  568. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  569. return segments
  570. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  571. # Rescale coords (xyxy) from img1_shape to img0_shape
  572. if ratio_pad is None: # calculate from img0_shape
  573. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  574. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  575. else:
  576. gain = ratio_pad[0][0]
  577. pad = ratio_pad[1]
  578. coords[:, [0, 2]] -= pad[0] # x padding
  579. coords[:, [1, 3]] -= pad[1] # y padding
  580. coords[:, :4] /= gain
  581. clip_coords(coords, img0_shape)
  582. return coords
  583. def clip_coords(boxes, shape):
  584. # Clip bounding xyxy bounding boxes to image shape (height, width)
  585. if isinstance(boxes, torch.Tensor): # faster individually
  586. boxes[:, 0].clamp_(0, shape[1]) # x1
  587. boxes[:, 1].clamp_(0, shape[0]) # y1
  588. boxes[:, 2].clamp_(0, shape[1]) # x2
  589. boxes[:, 3].clamp_(0, shape[0]) # y2
  590. else: # np.array (faster grouped)
  591. boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
  592. boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
  593. def non_max_suppression(prediction,
  594. conf_thres=0.25,
  595. iou_thres=0.45,
  596. classes=None,
  597. agnostic=False,
  598. multi_label=False,
  599. labels=(),
  600. max_det=300):
  601. """Non-Maximum Suppression (NMS) on inference results to reject overlapping bounding boxes
  602. Returns:
  603. list of detections, on (n,6) tensor per image [xyxy, conf, cls]
  604. """
  605. bs = prediction.shape[0] # batch size
  606. nc = prediction.shape[2] - 5 # number of classes
  607. xc = prediction[..., 4] > conf_thres # candidates
  608. # Checks
  609. assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
  610. assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
  611. # Settings
  612. # min_wh = 2 # (pixels) minimum box width and height
  613. max_wh = 7680 # (pixels) maximum box width and height
  614. max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
  615. time_limit = 0.1 + 0.03 * bs # seconds to quit after
  616. redundant = True # require redundant detections
  617. multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
  618. merge = False # use merge-NMS
  619. t = time.time()
  620. output = [torch.zeros((0, 6), device=prediction.device)] * bs
  621. for xi, x in enumerate(prediction): # image index, image inference
  622. # Apply constraints
  623. # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
  624. x = x[xc[xi]] # confidence
  625. # Cat apriori labels if autolabelling
  626. if labels and len(labels[xi]):
  627. lb = labels[xi]
  628. v = torch.zeros((len(lb), nc + 5), device=x.device)
  629. v[:, :4] = lb[:, 1:5] # box
  630. v[:, 4] = 1.0 # conf
  631. v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls
  632. x = torch.cat((x, v), 0)
  633. # If none remain process next image
  634. if not x.shape[0]:
  635. continue
  636. # Compute conf
  637. x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
  638. # Box (center x, center y, width, height) to (x1, y1, x2, y2)
  639. box = xywh2xyxy(x[:, :4])
  640. # Detections matrix nx6 (xyxy, conf, cls)
  641. if multi_label:
  642. i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
  643. x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
  644. else: # best class only
  645. conf, j = x[:, 5:].max(1, keepdim=True)
  646. x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
  647. # Filter by class
  648. if classes is not None:
  649. x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
  650. # Apply finite constraint
  651. # if not torch.isfinite(x).all():
  652. # x = x[torch.isfinite(x).all(1)]
  653. # Check shape
  654. n = x.shape[0] # number of boxes
  655. if not n: # no boxes
  656. continue
  657. elif n > max_nms: # excess boxes
  658. x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
  659. # Batched NMS
  660. c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
  661. boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
  662. i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
  663. if i.shape[0] > max_det: # limit detections
  664. i = i[:max_det]
  665. if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
  666. # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
  667. iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
  668. weights = iou * scores[None] # box weights
  669. x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
  670. if redundant:
  671. i = i[iou.sum(1) > 1] # require redundancy
  672. output[xi] = x[i]
  673. if (time.time() - t) > time_limit:
  674. LOGGER.warning(f'WARNING: NMS time limit {time_limit:.3f}s exceeded')
  675. break # time limit exceeded
  676. return output
  677. def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
  678. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  679. x = torch.load(f, map_location=torch.device('cpu'))
  680. if x.get('ema'):
  681. x['model'] = x['ema'] # replace model with ema
  682. for k in 'optimizer', 'best_fitness', 'wandb_id', 'ema', 'updates': # keys
  683. x[k] = None
  684. x['epoch'] = -1
  685. x['model'].half() # to FP16
  686. for p in x['model'].parameters():
  687. p.requires_grad = False
  688. torch.save(x, s or f)
  689. mb = os.path.getsize(s or f) / 1E6 # filesize
  690. LOGGER.info(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
  691. def print_mutation(results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
  692. evolve_csv = save_dir / 'evolve.csv'
  693. evolve_yaml = save_dir / 'hyp_evolve.yaml'
  694. keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', 'val/box_loss',
  695. 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys()) # [results + hyps]
  696. keys = tuple(x.strip() for x in keys)
  697. vals = results + tuple(hyp.values())
  698. n = len(keys)
  699. # Download (optional)
  700. if bucket:
  701. url = f'gs://{bucket}/evolve.csv'
  702. if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
  703. os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local
  704. # Log to evolve.csv
  705. s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header
  706. with open(evolve_csv, 'a') as f:
  707. f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')
  708. # Save yaml
  709. with open(evolve_yaml, 'w') as f:
  710. data = pd.read_csv(evolve_csv)
  711. data = data.rename(columns=lambda x: x.strip()) # strip keys
  712. i = np.argmax(fitness(data.values[:, :4])) #
  713. generations = len(data)
  714. f.write('# YOLOv5 Hyperparameter Evolution Results\n' + f'# Best generation: {i}\n' +
  715. f'# Last generation: {generations - 1}\n' + '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) +
  716. '\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
  717. yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)
  718. # Print to screen
  719. LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix +
  720. ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' + prefix + ', '.join(f'{x:20.5g}'
  721. for x in vals) + '\n\n')
  722. if bucket:
  723. os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload
  724. def apply_classifier(x, model, img, im0):
  725. # Apply a second stage classifier to YOLO outputs
  726. # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
  727. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  728. for i, d in enumerate(x): # per image
  729. if d is not None and len(d):
  730. d = d.clone()
  731. # Reshape and pad cutouts
  732. b = xyxy2xywh(d[:, :4]) # boxes
  733. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  734. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  735. d[:, :4] = xywh2xyxy(b).long()
  736. # Rescale boxes from img_size to im0 size
  737. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  738. # Classes
  739. pred_cls1 = d[:, 5].long()
  740. ims = []
  741. for j, a in enumerate(d): # per item
  742. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  743. im = cv2.resize(cutout, (224, 224)) # BGR
  744. # cv2.imwrite('example%i.jpg' % j, cutout)
  745. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  746. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  747. im /= 255 # 0 - 255 to 0.0 - 1.0
  748. ims.append(im)
  749. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  750. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  751. return x
  752. def increment_path(path, exist_ok=False, sep='', mkdir=False):
  753. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  754. path = Path(path) # os-agnostic
  755. if path.exists() and not exist_ok:
  756. path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
  757. dirs = glob.glob(f"{path}{sep}*") # similar paths
  758. matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
  759. i = [int(m.groups()[0]) for m in matches if m] # indices
  760. n = max(i) + 1 if i else 2 # increment number
  761. path = Path(f"{path}{sep}{n}{suffix}") # increment path
  762. if mkdir:
  763. path.mkdir(parents=True, exist_ok=True) # make directory
  764. return path
# OpenCV Chinese-friendly functions ------------------------------------------------------------------------------------
# Keep a reference to the original cv2.imshow before cv2.imshow is rebound to the imshow() wrapper
# defined below; the wrapper calls this reference, which avoids infinite recursion.
imshow_ = cv2.imshow  # copy to avoid recursion errors
  767. def imread(path, flags=cv2.IMREAD_COLOR):
  768. return cv2.imdecode(np.fromfile(path, np.uint8), flags)
  769. def imwrite(path, im):
  770. try:
  771. cv2.imencode(Path(path).suffix, im)[1].tofile(path)
  772. return True
  773. except Exception:
  774. return False
  775. def imshow(path, im):
  776. imshow_(path.encode('unicode_escape').decode(), im)
  777. cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow # redefine
  778. # Variables ------------------------------------------------------------------------------------------------------------
  779. NCOLS = 0 if is_docker() else shutil.get_terminal_size().columns # terminal window size for tqdm