You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

981 lines
39KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. General utils
  4. """
  5. import contextlib
  6. import glob
  7. import inspect
  8. import logging
  9. import math
  10. import os
  11. import platform
  12. import random
  13. import re
  14. import shutil
  15. import signal
  16. import time
  17. import urllib
  18. from datetime import datetime
  19. from itertools import repeat
  20. from multiprocessing.pool import ThreadPool
  21. from pathlib import Path
  22. from subprocess import check_output
  23. from typing import Optional
  24. from zipfile import ZipFile
  25. import cv2
  26. import numpy as np
  27. import pandas as pd
  28. import pkg_resources as pkg
  29. import torch
  30. import torchvision
  31. import yaml
  32. from utils.downloads import gsutil_getsize
  33. from utils.metrics import box_iou, fitness
  34. # Settings
  35. FILE = Path(__file__).resolve()
  36. ROOT = FILE.parents[1] # YOLOv5 root directory
  37. DATASETS_DIR = ROOT.parent / 'datasets' # YOLOv5 datasets directory
  38. NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
  39. AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
  40. VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global verbose mode
  41. FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
  42. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  43. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  44. pd.options.display.max_columns = 10
  45. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  46. os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
  47. os.environ['OMP_NUM_THREADS'] = str(NUM_THREADS) # OpenMP max threads (PyTorch and SciPy)
  48. def is_kaggle():
  49. # Is environment a Kaggle Notebook?
  50. try:
  51. assert os.environ.get('PWD') == '/kaggle/working'
  52. assert os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
  53. return True
  54. except AssertionError:
  55. return False
  56. def is_writeable(dir, test=False):
  57. # Return True if directory has write permissions, test opening a file with write permissions if test=True
  58. if test: # method 1
  59. file = Path(dir) / 'tmp.txt'
  60. try:
  61. with open(file, 'w'): # open file with write permissions
  62. pass
  63. file.unlink() # remove file
  64. return True
  65. except OSError:
  66. return False
  67. else: # method 2
  68. return os.access(dir, os.R_OK) # possible issues on Windows
  69. def set_logging(name=None, verbose=VERBOSE):
  70. # Sets level and returns logger
  71. if is_kaggle():
  72. for h in logging.root.handlers:
  73. logging.root.removeHandler(h) # remove all handlers associated with the root logger object
  74. rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
  75. level = logging.INFO if (verbose and rank in (-1, 0)) else logging.WARNING
  76. log = logging.getLogger(name)
  77. log.setLevel(level)
  78. handler = logging.StreamHandler()
  79. handler.setFormatter(logging.Formatter("%(message)s"))
  80. handler.setLevel(level)
  81. log.addHandler(handler)
  82. set_logging() # run before defining LOGGER
  83. LOGGER = logging.getLogger("yolov5") # define globally (used in train.py, val.py, detect.py, etc.)
  84. def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
  85. # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
  86. env = os.getenv(env_var)
  87. if env:
  88. path = Path(env) # use environment variable
  89. else:
  90. cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs
  91. path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir
  92. path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable
  93. path.mkdir(exist_ok=True) # make if required
  94. return path
  95. CONFIG_DIR = user_config_dir() # Ultralytics settings dir
  96. class Profile(contextlib.ContextDecorator):
  97. # Usage: @Profile() decorator or 'with Profile():' context manager
  98. def __enter__(self):
  99. self.start = time.time()
  100. def __exit__(self, type, value, traceback):
  101. print(f'Profile results: {time.time() - self.start:.5f}s')
  102. class Timeout(contextlib.ContextDecorator):
  103. # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
  104. def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
  105. self.seconds = int(seconds)
  106. self.timeout_message = timeout_msg
  107. self.suppress = bool(suppress_timeout_errors)
  108. def _timeout_handler(self, signum, frame):
  109. raise TimeoutError(self.timeout_message)
  110. def __enter__(self):
  111. if platform.system() != 'Windows': # not supported on Windows
  112. signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
  113. signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
  114. def __exit__(self, exc_type, exc_val, exc_tb):
  115. if platform.system() != 'Windows':
  116. signal.alarm(0) # Cancel SIGALRM if it's scheduled
  117. if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
  118. return True
  119. class WorkingDirectory(contextlib.ContextDecorator):
  120. # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
  121. def __init__(self, new_dir):
  122. self.dir = new_dir # new dir
  123. self.cwd = Path.cwd().resolve() # current dir
  124. def __enter__(self):
  125. os.chdir(self.dir)
  126. def __exit__(self, exc_type, exc_val, exc_tb):
  127. os.chdir(self.cwd)
  128. def try_except(func):
  129. # try-except function. Usage: @try_except decorator
  130. def handler(*args, **kwargs):
  131. try:
  132. func(*args, **kwargs)
  133. except Exception as e:
  134. print(e)
  135. return handler
  136. def methods(instance):
  137. # Get class/instance methods
  138. return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
  139. def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False):
  140. # Print function arguments (optional args dict)
  141. x = inspect.currentframe().f_back # previous frame
  142. file, _, fcn, _, _ = inspect.getframeinfo(x)
  143. if args is None: # get args automatically
  144. args, _, _, frm = inspect.getargvalues(x)
  145. args = {k: v for k, v in frm.items() if k in args}
  146. s = (f'{Path(file).stem}: ' if show_file else '') + (f'{fcn}: ' if show_fcn else '')
  147. LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
  148. def init_seeds(seed=0):
  149. # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
  150. # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
  151. import torch.backends.cudnn as cudnn
  152. random.seed(seed)
  153. np.random.seed(seed)
  154. torch.manual_seed(seed)
  155. cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
  156. def intersect_dicts(da, db, exclude=()):
  157. # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
  158. return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
  159. def get_latest_run(search_dir='.'):
  160. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  161. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  162. return max(last_list, key=os.path.getctime) if last_list else ''
  163. def is_docker():
  164. # Is environment a Docker container?
  165. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  166. def is_colab():
  167. # Is environment a Google Colab instance?
  168. try:
  169. import google.colab
  170. return True
  171. except ImportError:
  172. return False
  173. def is_pip():
  174. # Is file in a pip package?
  175. return 'site-packages' in Path(__file__).resolve().parts
  176. def is_ascii(s=''):
  177. # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
  178. s = str(s) # convert list, tuple, None, etc. to str
  179. return len(s.encode().decode('ascii', 'ignore')) == len(s)
  180. def is_chinese(s='人工智能'):
  181. # Is string composed of any Chinese characters?
  182. return True if re.search('[\u4e00-\u9fff]', str(s)) else False
  183. def emojis(str=''):
  184. # Return platform-dependent emoji-safe version of string
  185. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  186. def file_age(path=__file__):
  187. # Return days since last file update
  188. dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
  189. return dt.days # + dt.seconds / 86400 # fractional days
  190. def file_update_date(path=__file__):
  191. # Return human-readable file modification date, i.e. '2021-3-26'
  192. t = datetime.fromtimestamp(Path(path).stat().st_mtime)
  193. return f'{t.year}-{t.month}-{t.day}'
  194. def file_size(path):
  195. # Return file/dir size (MB)
  196. mb = 1 << 20 # bytes to MiB (1024 ** 2)
  197. path = Path(path)
  198. if path.is_file():
  199. return path.stat().st_size / mb
  200. elif path.is_dir():
  201. return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
  202. else:
  203. return 0.0
  204. def check_online():
  205. # Check internet connectivity
  206. import socket
  207. try:
  208. socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
  209. return True
  210. except OSError:
  211. return False
  212. def git_describe(path=ROOT): # path must be a directory
  213. # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
  214. try:
  215. assert (Path(path) / '.git').is_dir()
  216. return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
  217. except Exception:
  218. return ''
  219. @try_except
  220. @WorkingDirectory(ROOT)
  221. def check_git_status():
  222. # Recommend 'git pull' if code is out of date
  223. msg = ', for updates see https://github.com/ultralytics/yolov5'
  224. s = colorstr('github: ') # string
  225. assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
  226. assert not is_docker(), s + 'skipping check (Docker image)' + msg
  227. assert check_online(), s + 'skipping check (offline)' + msg
  228. cmd = 'git fetch && git config --get remote.origin.url'
  229. url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
  230. branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  231. n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  232. if n > 0:
  233. s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
  234. else:
  235. s += f'up to date with {url} ✅'
  236. LOGGER.info(emojis(s)) # emoji-safe
  237. def check_python(minimum='3.7.0'):
  238. # Check current python version vs. required python version
  239. check_version(platform.python_version(), minimum, name='Python ', hard=True)
  240. def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
  241. # Check version vs. required version
  242. current, minimum = (pkg.parse_version(x) for x in (current, minimum))
  243. result = (current == minimum) if pinned else (current >= minimum) # bool
  244. s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed' # string
  245. if hard:
  246. assert result, s # assert min requirements met
  247. if verbose and not result:
  248. LOGGER.warning(s)
  249. return result
  250. @try_except
  251. def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
  252. # Check installed dependencies meet requirements (pass *.txt file or list of packages)
  253. prefix = colorstr('red', 'bold', 'requirements:')
  254. check_python() # check python version
  255. if isinstance(requirements, (str, Path)): # requirements.txt file
  256. file = Path(requirements)
  257. assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
  258. with file.open() as f:
  259. requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
  260. else: # list or tuple of packages
  261. requirements = [x for x in requirements if x not in exclude]
  262. n = 0 # number of packages updates
  263. for i, r in enumerate(requirements):
  264. try:
  265. pkg.require(r)
  266. except Exception: # DistributionNotFound or VersionConflict if requirements not met
  267. s = f"{prefix} {r} not found and is required by YOLOv5"
  268. if install and AUTOINSTALL: # check environment variable
  269. LOGGER.info(f"{s}, attempting auto-update...")
  270. try:
  271. assert check_online(), f"'pip install {r}' skipped (offline)"
  272. LOGGER.info(check_output(f"pip install '{r}' {cmds[i] if cmds else ''}", shell=True).decode())
  273. n += 1
  274. except Exception as e:
  275. LOGGER.warning(f'{prefix} {e}')
  276. else:
  277. LOGGER.info(f'{s}. Please install and rerun your command.')
  278. if n: # if packages updated
  279. source = file.resolve() if 'file' in locals() else requirements
  280. s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
  281. f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
  282. LOGGER.info(emojis(s))
  283. def check_img_size(imgsz, s=32, floor=0):
  284. # Verify image size is a multiple of stride s in each dimension
  285. if isinstance(imgsz, int): # integer i.e. img_size=640
  286. new_size = max(make_divisible(imgsz, int(s)), floor)
  287. else: # list i.e. img_size=[640, 480]
  288. imgsz = list(imgsz) # convert to list if tuple
  289. new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
  290. if new_size != imgsz:
  291. LOGGER.warning(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
  292. return new_size
  293. def check_imshow():
  294. # Check if environment supports image displays
  295. try:
  296. assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
  297. assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
  298. cv2.imshow('test', np.zeros((1, 1, 3)))
  299. cv2.waitKey(1)
  300. cv2.destroyAllWindows()
  301. cv2.waitKey(1)
  302. return True
  303. except Exception as e:
  304. LOGGER.warning(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
  305. return False
  306. def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
  307. # Check file(s) for acceptable suffix
  308. if file and suffix:
  309. if isinstance(suffix, str):
  310. suffix = [suffix]
  311. for f in file if isinstance(file, (list, tuple)) else [file]:
  312. s = Path(f).suffix.lower() # file suffix
  313. if len(s):
  314. assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
  315. def check_yaml(file, suffix=('.yaml', '.yml')):
  316. # Search/download YAML file (if necessary) and return path, checking suffix
  317. return check_file(file, suffix)
  318. def check_file(file, suffix=''):
  319. # Search/download file (if necessary) and return path
  320. check_suffix(file, suffix) # optional
  321. file = str(file) # convert to str()
  322. if Path(file).is_file() or file == '': # exists
  323. return file
  324. elif file.startswith(('http:/', 'https:/')): # download
  325. url = str(Path(file)).replace(':/', '://') # Pathlib turns :// -> :/
  326. file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
  327. if Path(file).is_file():
  328. LOGGER.info(f'Found {url} locally at {file}') # file already exists
  329. else:
  330. LOGGER.info(f'Downloading {url} to {file}...')
  331. torch.hub.download_url_to_file(url, file)
  332. assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
  333. return file
  334. else: # search
  335. files = []
  336. for d in 'data', 'models', 'utils': # search directories
  337. files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
  338. assert len(files), f'File not found: {file}' # assert file was found
  339. assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
  340. return files[0] # return file
  341. def check_font(font=FONT, progress=False):
  342. # Download font to CONFIG_DIR if necessary
  343. font = Path(font)
  344. file = CONFIG_DIR / font.name
  345. if not font.exists() and not file.exists():
  346. url = "https://ultralytics.com/assets/" + font.name
  347. LOGGER.info(f'Downloading {url} to {file}...')
  348. torch.hub.download_url_to_file(url, str(file), progress=progress)
  349. def check_dataset(data, autodownload=True):
  350. # Download and/or unzip dataset if not found locally
  351. # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip
  352. # Download (optional)
  353. extract_dir = ''
  354. if isinstance(data, (str, Path)) and str(data).endswith('.zip'): # i.e. gs://bucket/dir/coco128.zip
  355. download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False, threads=1)
  356. data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
  357. extract_dir, autodownload = data.parent, False
  358. # Read yaml (optional)
  359. if isinstance(data, (str, Path)):
  360. with open(data, errors='ignore') as f:
  361. data = yaml.safe_load(f) # dictionary
  362. # Resolve paths
  363. path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.'
  364. if not path.is_absolute():
  365. path = (ROOT / path).resolve()
  366. for k in 'train', 'val', 'test':
  367. if data.get(k): # prepend path
  368. data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]
  369. # Parse yaml
  370. assert 'nc' in data, "Dataset 'nc' key missing."
  371. if 'names' not in data:
  372. data['names'] = [f'class{i}' for i in range(data['nc'])] # assign class names if missing
  373. train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
  374. if val:
  375. val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
  376. if not all(x.exists() for x in val):
  377. LOGGER.info(emojis('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()]))
  378. if s and autodownload: # download script
  379. t = time.time()
  380. root = path.parent if 'path' in data else '..' # unzip directory i.e. '../'
  381. if s.startswith('http') and s.endswith('.zip'): # URL
  382. f = Path(s).name # filename
  383. LOGGER.info(f'Downloading {s} to {f}...')
  384. torch.hub.download_url_to_file(s, f)
  385. Path(root).mkdir(parents=True, exist_ok=True) # create root
  386. ZipFile(f).extractall(path=root) # unzip
  387. Path(f).unlink() # remove zip
  388. r = None # success
  389. elif s.startswith('bash '): # bash script
  390. LOGGER.info(f'Running {s} ...')
  391. r = os.system(s)
  392. else: # python script
  393. r = exec(s, {'yaml': data}) # return None
  394. dt = f'({round(time.time() - t, 1)}s)'
  395. s = f"success ✅ {dt}, saved to {colorstr('bold', root)}" if r in (0, None) else f"failure {dt} ❌"
  396. LOGGER.info(emojis(f"Dataset download {s}"))
  397. else:
  398. raise Exception(emojis('Dataset not found ❌'))
  399. check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts
  400. return data # dictionary
  401. def url2file(url):
  402. # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
  403. url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
  404. file = Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
  405. return file
  406. def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
  407. # Multi-threaded file download and unzip function, used in data.yaml for autodownload
  408. def download_one(url, dir):
  409. # Download 1 file
  410. success = True
  411. f = dir / Path(url).name # filename
  412. if Path(url).is_file(): # exists in current path
  413. Path(url).rename(f) # move to dir
  414. elif not f.exists():
  415. LOGGER.info(f'Downloading {url} to {f}...')
  416. for i in range(retry + 1):
  417. if curl:
  418. s = 'sS' if threads > 1 else '' # silent
  419. r = os.system(f"curl -{s}L '{url}' -o '{f}' --retry 9 -C -") # curl download
  420. success = r == 0
  421. else:
  422. torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
  423. success = f.is_file()
  424. if success:
  425. break
  426. elif i < retry:
  427. LOGGER.warning(f'Download failure, retrying {i + 1}/{retry} {url}...')
  428. else:
  429. LOGGER.warning(f'Failed to download {url}...')
  430. if unzip and success and f.suffix in ('.zip', '.gz'):
  431. LOGGER.info(f'Unzipping {f}...')
  432. if f.suffix == '.zip':
  433. ZipFile(f).extractall(path=dir) # unzip
  434. elif f.suffix == '.gz':
  435. os.system(f'tar xfz {f} --directory {f.parent}') # unzip
  436. if delete:
  437. f.unlink() # remove zip
  438. dir = Path(dir)
  439. dir.mkdir(parents=True, exist_ok=True) # make directory
  440. if threads > 1:
  441. pool = ThreadPool(threads)
  442. pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
  443. pool.close()
  444. pool.join()
  445. else:
  446. for u in [url] if isinstance(url, (str, Path)) else url:
  447. download_one(u, dir)
  448. def make_divisible(x, divisor):
  449. # Returns nearest x divisible by divisor
  450. if isinstance(divisor, torch.Tensor):
  451. divisor = int(divisor.max()) # to int
  452. return math.ceil(x / divisor) * divisor
  453. def clean_str(s):
  454. # Cleans a string by replacing special characters with underscore _
  455. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  456. def one_cycle(y1=0.0, y2=1.0, steps=100):
  457. # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
  458. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  459. def colorstr(*input):
  460. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  461. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  462. colors = {
  463. 'black': '\033[30m', # basic colors
  464. 'red': '\033[31m',
  465. 'green': '\033[32m',
  466. 'yellow': '\033[33m',
  467. 'blue': '\033[34m',
  468. 'magenta': '\033[35m',
  469. 'cyan': '\033[36m',
  470. 'white': '\033[37m',
  471. 'bright_black': '\033[90m', # bright colors
  472. 'bright_red': '\033[91m',
  473. 'bright_green': '\033[92m',
  474. 'bright_yellow': '\033[93m',
  475. 'bright_blue': '\033[94m',
  476. 'bright_magenta': '\033[95m',
  477. 'bright_cyan': '\033[96m',
  478. 'bright_white': '\033[97m',
  479. 'end': '\033[0m', # misc
  480. 'bold': '\033[1m',
  481. 'underline': '\033[4m'}
  482. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  483. def labels_to_class_weights(labels, nc=80):
  484. # Get class weights (inverse frequency) from training labels
  485. if labels[0] is None: # no labels loaded
  486. return torch.Tensor()
  487. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  488. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  489. weights = np.bincount(classes, minlength=nc) # occurrences per class
  490. # Prepend gridpoint count (for uCE training)
  491. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  492. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  493. weights[weights == 0] = 1 # replace empty bins with 1
  494. weights = 1 / weights # number of targets per class
  495. weights /= weights.sum() # normalize
  496. return torch.from_numpy(weights)
  497. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  498. # Produces image weights based on class_weights and image contents
  499. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  500. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  501. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  502. return image_weights
  503. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  504. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  505. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  506. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  507. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  508. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  509. x = [
  510. 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  511. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  512. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  513. return x
  514. def xyxy2xywh(x):
  515. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  516. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  517. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  518. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  519. y[:, 2] = x[:, 2] - x[:, 0] # width
  520. y[:, 3] = x[:, 3] - x[:, 1] # height
  521. return y
  522. def xywh2xyxy(x):
  523. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  524. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  525. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  526. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  527. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  528. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  529. return y
  530. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  531. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  532. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  533. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  534. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  535. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  536. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  537. return y
  538. def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
  539. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
  540. if clip:
  541. clip_coords(x, (h - eps, w - eps)) # warning: inplace clip
  542. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  543. y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
  544. y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
  545. y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
  546. y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
  547. return y
  548. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  549. # Convert normalized segments into pixel segments, shape (n,2)
  550. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  551. y[:, 0] = w * x[:, 0] + padw # top left x
  552. y[:, 1] = h * x[:, 1] + padh # top left y
  553. return y
  554. def segment2box(segment, width=640, height=640):
  555. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  556. x, y = segment.T # segment xy
  557. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  558. x, y, = x[inside], y[inside]
  559. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  560. def segments2boxes(segments):
  561. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  562. boxes = []
  563. for s in segments:
  564. x, y = s.T # segment xy
  565. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  566. return xyxy2xywh(np.array(boxes)) # cls, xywh
  567. def resample_segments(segments, n=1000):
  568. # Up-sample an (n,2) segment
  569. for i, s in enumerate(segments):
  570. x = np.linspace(0, len(s) - 1, n)
  571. xp = np.arange(len(s))
  572. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  573. return segments
  574. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  575. # Rescale coords (xyxy) from img1_shape to img0_shape
  576. if ratio_pad is None: # calculate from img0_shape
  577. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  578. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  579. else:
  580. gain = ratio_pad[0][0]
  581. pad = ratio_pad[1]
  582. coords[:, [0, 2]] -= pad[0] # x padding
  583. coords[:, [1, 3]] -= pad[1] # y padding
  584. coords[:, :4] /= gain
  585. clip_coords(coords, img0_shape)
  586. return coords
  587. def clip_coords(boxes, shape):
  588. # Clip bounding xyxy bounding boxes to image shape (height, width)
  589. if isinstance(boxes, torch.Tensor): # faster individually
  590. boxes[:, 0].clamp_(0, shape[1]) # x1
  591. boxes[:, 1].clamp_(0, shape[0]) # y1
  592. boxes[:, 2].clamp_(0, shape[1]) # x2
  593. boxes[:, 3].clamp_(0, shape[0]) # y2
  594. else: # np.array (faster grouped)
  595. boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
  596. boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
  def non_max_suppression(prediction,
                          conf_thres=0.25,
                          iou_thres=0.45,
                          classes=None,
                          agnostic=False,
                          multi_label=False,
                          labels=(),
                          max_det=300):
      """Non-Maximum Suppression (NMS) on inference results to reject overlapping bounding boxes

      Args:
          prediction: tensor indexed as (image, box, 5 + nc); each box row is read as
              (center x, center y, width, height, objectness, nc class scores)
          conf_thres: threshold applied to objectness and then to objectness * class score, in [0, 1]
          iou_thres: IoU threshold handed to torchvision.ops.nms(), in [0, 1]
          classes: optional collection of class indices to keep; None keeps every class
          agnostic: if True no per-class offset is applied, so boxes of different classes suppress each other
          multi_label: if True (and nc > 1) a box yields one detection per class scoring above conf_thres
          labels: optional per-image apriori labels for autolabelling; rows are read as (class, box[4])
              and appended to the candidates with objectness 1.0
          max_det: maximum number of detections kept per image

      Returns:
          list of detections, one (n,6) tensor per image [xyxy, conf, cls]. Images that are not reached
          before the time limit expires keep their own empty (0,6) tensor.
      """
      # Checks (run first so invalid thresholds are rejected before they are used)
      assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
      assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

      bs = prediction.shape[0]  # batch size
      nc = prediction.shape[2] - 5  # number of classes
      xc = prediction[..., 4] > conf_thres  # candidates

      # Settings
      # min_wh = 2  # (pixels) minimum box width and height
      max_wh = 7680  # (pixels) maximum box width and height
      max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
      time_limit = 0.1 + 0.03 * bs  # seconds to quit after
      redundant = True  # require redundant detections
      multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
      merge = False  # use merge-NMS

      t = time.time()
      # One distinct empty tensor per image; `[tensor] * bs` would put the same object in every slot
      output = [torch.zeros((0, 6), device=prediction.device) for _ in range(bs)]
      for xi, x in enumerate(prediction):  # image index, image inference
          # Apply constraints
          # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
          x = x[xc[xi]]  # confidence (boolean indexing copies, so `prediction` is not modified below)

          # Cat apriori labels if autolabelling
          if labels and len(labels[xi]):
              lb = labels[xi]
              v = torch.zeros((len(lb), nc + 5), device=x.device)
              v[:, :4] = lb[:, 1:5]  # box
              v[:, 4] = 1.0  # conf
              v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
              x = torch.cat((x, v), 0)

          # If none remain process next image
          if not x.shape[0]:
              continue

          # Compute conf
          x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

          # Box (center x, center y, width, height) to (x1, y1, x2, y2)
          box = xywh2xyxy(x[:, :4])

          # Detections matrix nx6 (xyxy, conf, cls)
          if multi_label:
              i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
              x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
          else:  # best class only
              conf, j = x[:, 5:].max(1, keepdim=True)
              x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

          # Filter by class
          if classes is not None:
              x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

          # Apply finite constraint
          # if not torch.isfinite(x).all():
          #     x = x[torch.isfinite(x).all(1)]

          # Check shape
          n = x.shape[0]  # number of boxes
          if not n:  # no boxes
              continue
          elif n > max_nms:  # excess boxes
              x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

          # Batched NMS: shifting each box by class * max_wh keeps different classes from overlapping
          c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
          boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
          i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
          if i.shape[0] > max_det:  # limit detections
              i = i[:max_det]
          if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
              # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
              iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
              weights = iou * scores[None]  # box weights
              x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
              if redundant:
                  i = i[iou.sum(1) > 1]  # require redundancy

          output[xi] = x[i]
          if (time.time() - t) > time_limit:
              LOGGER.warning(f'WARNING: NMS time limit {time_limit:.3f}s exceeded')
              break  # time limit exceeded

      return output
  def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
      """Strip optimizer from 'f' to finalize training, optionally save as 's'.

      The checkpoint is loaded on CPU, its 'model' entry is replaced by the 'ema' entry when one is present,
      training-only entries are set to None, 'epoch' is set to -1, and the model is converted to FP16 with
      gradients disabled. The result overwrites 'f' unless 's' is given.
      """
      ckpt = torch.load(f, map_location=torch.device('cpu'))
      if ckpt.get('ema'):
          ckpt['model'] = ckpt['ema']  # replace model with ema
      for key in ('optimizer', 'best_fitness', 'wandb_id', 'ema', 'updates'):  # keys
          ckpt[key] = None
      ckpt['epoch'] = -1

      model = ckpt['model']
      model.half()  # to FP16
      for param in model.parameters():
          param.requires_grad = False

      target = s or f
      torch.save(ckpt, target)
      mb = os.path.getsize(target) / 1E6  # filesize
      LOGGER.info(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
  def print_mutation(results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
      """Record one evolution generation: append it to evolve.csv, rewrite hyp_evolve.yaml and log it.

      Args:
          results: tuple of metric values; it is concatenated in front of the hyperparameter values and
              must line up with the seven metric column names below
          hyp: dict of hyperparameters (name -> value)
          save_dir: directory (path object supporting '/') that holds evolve.csv and hyp_evolve.yaml
          bucket: optional GCS bucket name; when truthy, evolve.csv is downloaded first if the remote copy
              is larger than the local one, and both files are uploaded at the end
          prefix: string prepended to every logged line
      """
      evolve_csv = save_dir / 'evolve.csv'
      evolve_yaml = save_dir / 'hyp_evolve.yaml'
      metric_keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', 'val/box_loss',
                     'val/obj_loss', 'val/cls_loss')
      keys = tuple(x.strip() for x in metric_keys + tuple(hyp.keys()))  # [results + hyps]
      vals = results + tuple(hyp.values())
      n = len(keys)

      # Download (optional)
      if bucket:
          url = f'gs://{bucket}/evolve.csv'
          if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
              os.system(f'gsutil cp {url} {save_dir}')  # download evolve.csv if larger than local

      # Log to evolve.csv (the header row is only written when the file does not exist yet)
      s = (('%20s,' * n % keys).rstrip(',') + '\n') if not evolve_csv.exists() else ''
      with open(evolve_csv, 'a') as f:
          row = ('%20.5g,' * n % vals).rstrip(',') + '\n'
          f.write(s + row)

      # Save yaml
      with open(evolve_yaml, 'w') as f:
          data = pd.read_csv(evolve_csv).rename(columns=lambda x: x.strip())  # strip keys
          i = np.argmax(fitness(data.values[:, :4]))  # row with the highest fitness() value
          generations = len(data)
          metric_names = ', '.join(f'{x.strip():>20s}' for x in keys[:7])
          metric_values = ', '.join(f'{x:>20.5g}' for x in data.values[i, :7])
          comment_lines = [
              '# YOLOv5 Hyperparameter Evolution Results\n',
              f'# Best generation: {i}\n',
              f'# Last generation: {generations - 1}\n',
              '# ' + metric_names + '\n',
              '# ' + metric_values + '\n\n']
          f.write(''.join(comment_lines))
          yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)  # columns after the 7 metrics

      # Print to screen
      key_line = ', '.join(f'{x.strip():>20s}' for x in keys)
      val_line = ', '.join(f'{x:20.5g}' for x in vals)
      LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix + key_line + '\n' +
                  prefix + val_line + '\n\n')

      if bucket:
          os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}')  # upload
  def apply_classifier(x, model, img, im0):
      """Apply a second stage classifier to YOLO outputs, keeping detections on which both stages agree.

      Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()

      Args:
          x: list of per-image detection tensors (entries may be None or empty); columns 0:4 are read as
              xyxy boxes and column 5 as the predicted class
          model: classifier called on a float batch of 224x224 RGB crops scaled to 0.0 - 1.0
          img: network input batch; img.shape[2:] is passed to scale_coords() as the source size
          im0: original image as np.ndarray, or a list with one original image per entry of x

      Returns:
          x, with each non-empty entry reduced to the rows whose class equals the classifier's argmax.
      """
      if isinstance(im0, np.ndarray):
          im0 = [im0]
      for i, d in enumerate(x):  # per image
          if d is None or not len(d):
              continue
          d = d.clone()

          # Reshape and pad cutouts
          b = xyxy2xywh(d[:, :4])  # boxes
          b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
          b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
          d[:, :4] = xywh2xyxy(b).long()

          # Rescale boxes from img_size to im0 size
          frame = im0[i]
          scale_coords(img.shape[2:], d[:, :4], frame.shape)

          # Classes
          pred_cls1 = d[:, 5].long()
          ims = []
          for a in d:  # per item
              x1, y1, x2, y2 = (int(v) for v in a[:4])
              cutout = frame[y1:y2, x1:x2]
              im = cv2.resize(cutout, (224, 224))  # BGR
              im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x224x224
              im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
              im /= 255  # 0 - 255 to 0.0 - 1.0
              ims.append(im)

          pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
          x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

      return x
  def increment_path(path, exist_ok=False, sep='', mkdir=False):
      """Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.

      Args:
          path: file or directory path (str or Path)
          exist_ok: if True, an existing path is returned unchanged instead of being incremented
          sep: separator inserted between the original name and the number
          mkdir: if True, create the returned path as a directory (including parents)

      Returns:
          Path: `path` itself when it does not exist (or exist_ok is True), otherwise the first
          numbered variant, counting up from 2, that does not exist yet. For an existing file the
          number is inserted before the suffix.
      """
      path = Path(path)  # os-agnostic
      if path.exists() and not exist_ok:
          path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')

          # Search without an upper bound so the result is never a path that already exists
          n = 2
          while os.path.exists(f'{path}{sep}{n}{suffix}'):
              n += 1
          path = Path(f'{path}{sep}{n}{suffix}')  # increment path

      if mkdir:
          path.mkdir(parents=True, exist_ok=True)  # make directory

      return path
  # OpenCV Chinese-friendly functions ------------------------------------------------------------------------------------
  # Keep a handle on the original cv2.imshow: cv2.imshow is rebound further down to the imshow() wrapper,
  # which forwards to this alias instead of calling itself.
  imshow_ = cv2.imshow  # copy to avoid recursion errors
  def imread(path, flags=cv2.IMREAD_COLOR):
      """Read an image by loading the file's bytes with NumPy and decoding them with cv2.imdecode()."""
      raw = np.fromfile(path, np.uint8)
      return cv2.imdecode(raw, flags)
  def imwrite(path, im):
      """Encode `im` in the format implied by the suffix of `path` and write the bytes with NumPy.

      Returns:
          bool: True if the image was encoded and written, False if encoding reported failure or
          any exception was raised (best-effort, mirroring the boolean result of cv2.imwrite).
      """
      try:
          ok, buf = cv2.imencode(Path(path).suffix, im)
          if not ok:  # encoder reported failure: do not write the buffer and claim success
              return False
          buf.tofile(path)
          return True
      except Exception:
          return False
  def imshow(path, im):
      """Display `im` via the saved original cv2.imshow, using `path` with non-ASCII characters escaped as the name."""
      window_name = path.encode('unicode_escape').decode()
      imshow_(window_name, im)
  # redefine: route cv2's image I/O entry points through the wrappers defined in this module
  cv2.imread = imread
  cv2.imwrite = imwrite
  cv2.imshow = imshow
  # Variables ------------------------------------------------------------------------------------------------------------
  # terminal window size for tqdm (0 when running inside Docker)
  if is_docker():
      NCOLS = 0
  else:
      NCOLS = shutil.get_terminal_size().columns