# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
General utils
"""

import contextlib
import glob
import inspect
import logging
import math
import os
import platform
import random
import re
import shutil
import signal
import time
import urllib
from datetime import datetime
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from subprocess import check_output
from typing import Optional
from zipfile import ZipFile

import cv2
import numpy as np
import pandas as pd
import pkg_resources as pkg
import torch
import torchvision
import yaml

from utils.downloads import gsutil_getsize
from utils.metrics import box_iou, fitness

# Settings
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
DATASETS_DIR = ROOT.parent / 'datasets'  # YOLOv5 datasets directory
NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLOv5 multiprocessing threads
VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true'  # global verbose mode
FONT = 'Arial.ttf'  # https://ultralytics.com/assets/Arial.ttf

torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS)  # NumExpr max threads
os.environ['OMP_NUM_THREADS'] = str(NUM_THREADS)  # OpenMP max threads (PyTorch and SciPy)


def is_kaggle():
    # Is environment a Kaggle Notebook?
    try:
        assert os.environ.get('PWD') == '/kaggle/working'
        assert os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
        return True
    except AssertionError:
        return False


def is_writeable(dir, test=False):
    # Return True if directory has write permissions, test opening a file with write permissions if test=True
    if test:  # method 1
        file = Path(dir) / 'tmp.txt'
        try:
            with open(file, 'w'):  # open file with write permissions
                pass
            file.unlink()  # remove file
            return True
        except OSError:
            return False
    else:  # method 2
        return os.access(dir, os.W_OK)  # check write access (possible issues on Windows)


def set_logging(name=None, verbose=VERBOSE):
    # Sets level and returns logger
    if is_kaggle():
        for h in logging.root.handlers:
            logging.root.removeHandler(h)  # remove all handlers associated with the root logger object
    rank = int(os.getenv('RANK', -1))  # rank in world for Multi-GPU trainings
    logging.basicConfig(format="%(message)s", level=logging.INFO if (verbose and rank in (-1, 0)) else logging.WARNING)
    return logging.getLogger(name)


LOGGER = set_logging('yolov5')  # define globally (used in train.py, val.py, detect.py, etc.)


def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
    # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
    env = os.getenv(env_var)
    if env:
        path = Path(env)  # use environment variable
    else:
        cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'}  # 3 OS dirs
        path = Path.home() / cfg.get(platform.system(), '')  # OS-specific config dir
        path = (path if is_writeable(path) else Path('/tmp')) / dir  # GCP and AWS lambda fix, only /tmp is writeable
    path.mkdir(exist_ok=True)  # make if required
    return path


CONFIG_DIR = user_config_dir()  # Ultralytics settings dir


class Profile(contextlib.ContextDecorator):
    # Usage: @Profile() decorator or 'with Profile():' context manager
    def __enter__(self):
        self.start = time.time()

    def __exit__(self, type, value, traceback):
        print(f'Profile results: {time.time() - self.start:.5f}s')


class Timeout(contextlib.ContextDecorator):
    # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
    def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
        self.seconds = int(seconds)
        self.timeout_message = timeout_msg
        self.suppress = bool(suppress_timeout_errors)

    def _timeout_handler(self, signum, frame):
        raise TimeoutError(self.timeout_message)

    def __enter__(self):
        if platform.system() != 'Windows':  # not supported on Windows
            signal.signal(signal.SIGALRM, self._timeout_handler)  # Set handler for SIGALRM
            signal.alarm(self.seconds)  # start countdown for SIGALRM to be raised

    def __exit__(self, exc_type, exc_val, exc_tb):
        if platform.system() != 'Windows':
            signal.alarm(0)  # Cancel SIGALRM if it's scheduled
            if self.suppress and exc_type is TimeoutError:  # Suppress TimeoutError
                return True


class WorkingDirectory(contextlib.ContextDecorator):
    # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
    def __init__(self, new_dir):
        self.dir = new_dir  # new dir
        self.cwd = Path.cwd().resolve()  # current dir

    def __enter__(self):
        os.chdir(self.dir)

    def __exit__(self, exc_type, exc_val, exc_tb):
        os.chdir(self.cwd)


def try_except(func):
    # try-except function. Usage: @try_except decorator
    def handler(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception as e:
            print(e)

    return handler


def methods(instance):
    # Get class/instance methods
    return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]


def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False):
    # Print function arguments (optional args dict)
    x = inspect.currentframe().f_back  # previous frame
    file, _, fcn, _, _ = inspect.getframeinfo(x)
    if args is None:  # get args automatically
        args, _, _, frm = inspect.getargvalues(x)
        args = {k: v for k, v in frm.items() if k in args}
    s = (f'{Path(file).stem}: ' if show_file else '') + (f'{fcn}: ' if show_fcn else '')
    LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))


def init_seeds(seed=0):
    # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
    # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
    import torch.backends.cudnn as cudnn
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)


def intersect_dicts(da, db, exclude=()):
    # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
    return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
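
# Example (illustrative sketch): transfer-learning style checkpoint loading, keeping only layers
# whose names and shapes match. 'ckpt' and 'model' are assumed to exist and are not defined here.
#   csd = intersect_dicts(ckpt['model'].float().state_dict(), model.state_dict(), exclude=['anchor'])
#   model.load_state_dict(csd, strict=False)  # load only the intersecting weights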


def get_latest_run(search_dir='.'):
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''


def is_docker():
    # Is environment a Docker container?
    return Path('/workspace').exists()  # or Path('/.dockerenv').exists()


def is_colab():
    # Is environment a Google Colab instance?
    try:
        import google.colab
        return True
    except ImportError:
        return False


def is_pip():
    # Is file in a pip package?
    return 'site-packages' in Path(__file__).resolve().parts


def is_ascii(s=''):
    # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
    s = str(s)  # convert list, tuple, None, etc. to str
    return len(s.encode().decode('ascii', 'ignore')) == len(s)


def is_chinese(s='人工智能'):
    # Is string composed of any Chinese characters?
    return bool(re.search('[\u4e00-\u9fff]', str(s)))


def emojis(s=''):
    # Return platform-dependent emoji-safe version of string
    return s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s


def file_age(path=__file__):
    # Return days since last file update
    dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime))  # delta
    return dt.days  # + dt.seconds / 86400  # fractional days


def file_update_date(path=__file__):
    # Return human-readable file modification date, i.e. '2021-3-26'
    t = datetime.fromtimestamp(Path(path).stat().st_mtime)
    return f'{t.year}-{t.month}-{t.day}'


def file_size(path):
    # Return file/dir size (MB)
    mb = 1 << 20  # bytes to MiB (1024 ** 2)
    path = Path(path)
    if path.is_file():
        return path.stat().st_size / mb
    elif path.is_dir():
        return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
    else:
        return 0.0


def check_online():
    # Check internet connectivity
    import socket
    try:
        socket.create_connection(("1.1.1.1", 443), timeout=5)  # check host accessibility
        return True
    except OSError:
        return False


def git_describe(path=ROOT):  # path must be a directory
    # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
    try:
        return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
    except Exception:
        return ''


@try_except
@WorkingDirectory(ROOT)
def check_git_status():
    # Recommend 'git pull' if code is out of date
    msg = ', for updates see https://github.com/ultralytics/yolov5'
    s = colorstr('github: ')  # string
    assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
    assert not is_docker(), s + 'skipping check (Docker image)' + msg
    assert check_online(), s + 'skipping check (offline)' + msg

    cmd = 'git fetch && git config --get remote.origin.url'
    url = check_output(cmd, shell=True, timeout=5).decode().strip().replace('.git', '')  # git fetch, remove '.git'
    branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
    n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True))  # commits behind
    if n > 0:
        s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
    else:
        s += f'up to date with {url} ✅'
    LOGGER.info(emojis(s))  # emoji-safe


def check_python(minimum='3.7.0'):
    # Check current python version vs. required python version
    check_version(platform.python_version(), minimum, name='Python ', hard=True)


def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
    # Check version vs. required version
    current, minimum = (pkg.parse_version(x) for x in (current, minimum))
    result = (current == minimum) if pinned else (current >= minimum)  # bool
    s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'  # string
    if hard:
        assert result, s  # assert min requirements met
    if verbose and not result:
        LOGGER.warning(s)
    return result


@try_except
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True):
    # Check installed dependencies meet requirements (pass *.txt file or list of packages)
    prefix = colorstr('red', 'bold', 'requirements:')
    check_python()  # check python version
    if isinstance(requirements, (str, Path)):  # requirements.txt file
        file = Path(requirements)
        assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
        with file.open() as f:
            requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
    else:  # list or tuple of packages
        requirements = [x for x in requirements if x not in exclude]

    n = 0  # number of package updates
    for r in requirements:
        try:
            pkg.require(r)
        except Exception:  # DistributionNotFound or VersionConflict if requirements not met
            s = f"{prefix} {r} not found and is required by YOLOv5"
            if install:
                LOGGER.info(f"{s}, attempting auto-update...")
                try:
                    assert check_online(), f"'pip install {r}' skipped (offline)"
                    LOGGER.info(check_output(f"pip install '{r}'", shell=True).decode())
                    n += 1
                except Exception as e:
                    LOGGER.warning(f'{prefix} {e}')
            else:
                LOGGER.info(f'{s}. Please install and rerun your command.')

    if n:  # if packages updated
        source = file.resolve() if 'file' in locals() else requirements
        s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
        LOGGER.info(emojis(s))


def check_img_size(imgsz, s=32, floor=0):
    # Verify image size is a multiple of stride s in each dimension
    if isinstance(imgsz, int):  # integer i.e. img_size=640
        new_size = max(make_divisible(imgsz, int(s)), floor)
    else:  # list i.e. img_size=[640, 480]
        imgsz = list(imgsz)  # convert to list if tuple
        new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
    if new_size != imgsz:
        LOGGER.warning(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
    return new_size
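
# Example (illustrative): a stride-32 model needs image sizes divisible by 32, so a requested
# size is rounded up to the nearest valid multiple.
#   check_img_size(640, s=32)  # -> 640 (already a multiple of 32)
#   check_img_size(641, s=32)  # -> 672, with a warning that the size was updated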


def check_imshow():
    # Check if environment supports image displays
    try:
        assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
        assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        LOGGER.warning(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
        return False


def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
    # Check file(s) for acceptable suffix
    if file and suffix:
        if isinstance(suffix, str):
            suffix = [suffix]
        for f in file if isinstance(file, (list, tuple)) else [file]:
            s = Path(f).suffix.lower()  # file suffix
            if len(s):
                assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"


def check_yaml(file, suffix=('.yaml', '.yml')):
    # Search/download YAML file (if necessary) and return path, checking suffix
    return check_file(file, suffix)


def check_file(file, suffix=''):
    # Search/download file (if necessary) and return path
    check_suffix(file, suffix)  # optional
    file = str(file)  # convert to str()
    if Path(file).is_file() or file == '':  # exists
        return file
    elif file.startswith(('http:/', 'https:/')):  # download
        url = str(Path(file)).replace(':/', '://')  # Pathlib turns :// -> :/
        file = Path(urllib.parse.unquote(file).split('?')[0]).name  # '%2F' to '/', split https://url.com/file.txt?auth
        if Path(file).is_file():
            LOGGER.info(f'Found {url} locally at {file}')  # file already exists
        else:
            LOGGER.info(f'Downloading {url} to {file}...')
            torch.hub.download_url_to_file(url, file)
            assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}'  # check
        return file
    else:  # search
        files = []
        for d in 'data', 'models', 'utils':  # search directories
            files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True))  # find file
        assert len(files), f'File not found: {file}'  # assert file was found
        assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique
        return files[0]  # return file


def check_font(font=FONT):
    # Download font to CONFIG_DIR if necessary
    font = Path(font)
    if not font.exists() and not (CONFIG_DIR / font.name).exists():
        url = "https://ultralytics.com/assets/" + font.name
        LOGGER.info(f'Downloading {url} to {CONFIG_DIR / font.name}...')
        torch.hub.download_url_to_file(url, str(CONFIG_DIR / font.name), progress=False)  # save to CONFIG_DIR


def check_dataset(data, autodownload=True):
    # Download and/or unzip dataset if not found locally
    # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip

    # Download (optional)
    extract_dir = ''
    if isinstance(data, (str, Path)) and str(data).endswith('.zip'):  # i.e. gs://bucket/dir/coco128.zip
        download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False, threads=1)
        data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
        extract_dir, autodownload = data.parent, False

    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        with open(data, errors='ignore') as f:
            data = yaml.safe_load(f)  # dictionary

    # Resolve paths
    path = Path(extract_dir or data.get('path') or '')  # optional 'path' default to '.'
    if not path.is_absolute():
        path = (ROOT / path).resolve()
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]

    # Parse yaml
    assert 'nc' in data, "Dataset 'nc' key missing."
    if 'names' not in data:
        data['names'] = [f'class{i}' for i in range(data['nc'])]  # assign class names if missing
    train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            LOGGER.info(emojis('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()]))
            if s and autodownload:  # download script
                t = time.time()
                root = path.parent if 'path' in data else '..'  # unzip directory i.e. '../'
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    LOGGER.info(f'Downloading {s} to {f}...')
                    torch.hub.download_url_to_file(s, f)
                    Path(root).mkdir(parents=True, exist_ok=True)  # create root
                    ZipFile(f).extractall(path=root)  # unzip
                    Path(f).unlink()  # remove zip
                    r = None  # success
                elif s.startswith('bash '):  # bash script
                    LOGGER.info(f'Running {s} ...')
                    r = os.system(s)
                else:  # python script
                    r = exec(s, {'yaml': data})  # return None
                dt = f'({round(time.time() - t, 1)}s)'
                s = f"success ✅ {dt}, saved to {colorstr('bold', root)}" if r in (0, None) else f"failure {dt} ❌"
                LOGGER.info(emojis(f"Dataset download {s}"))
            else:
                raise Exception(emojis('Dataset not found ❌'))

    return data  # dictionary


def url2file(url):
    # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
    url = str(Path(url)).replace(':/', '://')  # Pathlib turns :// -> :/
    file = Path(urllib.parse.unquote(url)).name.split('?')[0]  # '%2F' to '/', split https://url.com/file.txt?auth
    return file


def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
    # Multi-threaded file download and unzip function, used in data.yaml for autodownload
    def download_one(url, dir):
        # Download 1 file
        f = dir / Path(url).name  # filename
        if Path(url).is_file():  # exists in current path
            Path(url).rename(f)  # move to dir
        elif not f.exists():
            LOGGER.info(f'Downloading {url} to {f}...')
            if curl:
                os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -")  # curl download, retry and resume on fail
            else:
                torch.hub.download_url_to_file(url, f, progress=threads == 1)  # torch download
        if unzip and f.suffix in ('.zip', '.gz'):
            LOGGER.info(f'Unzipping {f}...')
            if f.suffix == '.zip':
                ZipFile(f).extractall(path=dir)  # unzip
            elif f.suffix == '.gz':
                os.system(f'tar xfz {f} --directory {f.parent}')  # unzip
            if delete:
                f.unlink()  # remove zip

    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    if threads > 1:
        pool = ThreadPool(threads)
        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multi-threaded
        pool.close()
        pool.join()
    else:
        for u in [url] if isinstance(url, (str, Path)) else url:
            download_one(u, dir)
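
# Example (illustrative): download a file into 'tmp/', using the Ultralytics asset URL noted in
# the FONT setting above; any URL or list of URLs works, with threads > 1 enabling parallelism.
#   download('https://ultralytics.com/assets/Arial.ttf', dir='tmp')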


def make_divisible(x, divisor):
    # Returns x rounded up to the nearest multiple of divisor
    if isinstance(divisor, torch.Tensor):
        divisor = int(divisor.max())  # to int
    return math.ceil(x / divisor) * divisor
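
# Example (illustrative): round a channel count or image size up to the nearest multiple.
#   make_divisible(97, 32)  # -> 128
#   make_divisible(64, 32)  # -> 64 (already divisible)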


def clean_str(s):
    # Cleans a string by replacing special characters with underscore _
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)


def one_cycle(y1=0.0, y2=1.0, steps=100):
    # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
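
# Example (illustrative sketch): use as a LambdaLR schedule, as in YOLOv5 training; 'optimizer'
# is assumed to exist, with 300 epochs and a final LR fraction of 0.1.
#   lf = one_cycle(1, 0.1, 300)  # cosine ramp 1.0 -> 0.1 over 300 epochs
#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)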


def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {
        'black': '\033[30m',  # basic colors
        'red': '\033[31m',
        'green': '\033[32m',
        'yellow': '\033[33m',
        'blue': '\033[34m',
        'magenta': '\033[35m',
        'cyan': '\033[36m',
        'white': '\033[37m',
        'bright_black': '\033[90m',  # bright colors
        'bright_red': '\033[91m',
        'bright_green': '\033[92m',
        'bright_yellow': '\033[93m',
        'bright_blue': '\033[94m',
        'bright_magenta': '\033[95m',
        'bright_cyan': '\033[96m',
        'bright_white': '\033[97m',
        'end': '\033[0m',  # misc
        'bold': '\033[1m',
        'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
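
# Example (illustrative):
#   colorstr('blue', 'bold', 'hello world')  # blue bold text, terminated with a reset code
#   colorstr('hello world')  # a single argument defaults to blue bold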


def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(int)  # labels = [class xywh]; int, not deprecated np.int
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # number of targets per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)


def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class_weights and image contents
    class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
    return image_weights


def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    x = [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
        35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
        64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
    return x


def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y
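
# Example (illustrative): the two conversions above are inverses of each other.
#   b = np.array([[10., 20., 50., 60.]])  # one box as x1, y1, x2, y2
#   xyxy2xywh(b)  # -> [[30., 40., 40., 40.]] (center x, center y, width, height)
#   xywh2xyxy(xyxy2xywh(b))  # -> [[10., 20., 50., 60.]] (round trip)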


def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # top left x
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh  # top left y
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw  # bottom right x
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh  # bottom right y
    return y


def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
    if clip:
        clip_coords(x, (h - eps, w - eps))  # warning: inplace clip
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w  # x center
    y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h  # y center
    y[:, 2] = (x[:, 2] - x[:, 0]) / w  # width
    y[:, 3] = (x[:, 3] - x[:, 1]) / h  # height
    return y


def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    # Convert normalized segments into pixel segments, shape (n,2)
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * x[:, 0] + padw  # top left x
    y[:, 1] = h * x[:, 1] + padh  # top left y
    return y


def segment2box(segment, width=640, height=640):
    # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4))  # xyxy


def segments2boxes(segments):
    # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
    return xyxy2xywh(np.array(boxes))  # cls, xywh


def resample_segments(segments, n=1000):
    # Up-sample an (n,2) segment
    for i, s in enumerate(segments):
        x = np.linspace(0, len(s) - 1, n)
        xp = np.arange(len(s))
        segments[i] = np.concatenate([np.interp(x, xp, s[:, j]) for j in range(2)]).reshape(2, -1).T  # segment xy
    return segments


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords
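
# Example (illustrative): mapping detections from a 640x640 letterboxed input back to a
# 1280x720 original: gain = min(640 / 720, 640 / 1280) = 0.5 and pad = (0, 140), so a box
# [0, 140, 640, 500] in letterbox coordinates maps back to [0, 0, 1280, 720].
#   scale_coords((640, 640), det[:, :4], (720, 1280))  # 'det' assumed to be an (n, 6) tensor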


def clip_coords(boxes, shape):
    # Clip xyxy bounding boxes to image shape (height, width)
    if isinstance(boxes, torch.Tensor):  # faster individually
        boxes[:, 0].clamp_(0, shape[1])  # x1
        boxes[:, 1].clamp_(0, shape[0])  # y1
        boxes[:, 2].clamp_(0, shape[1])  # x2
        boxes[:, 3].clamp_(0, shape[0])  # y2
    else:  # np.array (faster grouped)
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1])  # x1, x2
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0])  # y1, y2


def non_max_suppression(prediction,
                        conf_thres=0.25,
                        iou_thres=0.45,
                        classes=None,
                        agnostic=False,
                        multi_label=False,
                        labels=(),
                        max_det=300):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping bounding boxes

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 0.1 + 0.03 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING: NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output
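
# Example (illustrative sketch): typical inference usage, assuming an existing YOLOv5 'model'
# and a preprocessed image tensor 'im' of shape (1, 3, 640, 640), neither defined here.
#   pred = model(im)[0]  # raw predictions, shape (1, num_boxes, num_classes + 5)
#   det = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, max_det=300)[0]
#   # det is an (n, 6) tensor for the first image: x1, y1, x2, y2, confidence, class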


def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'best_fitness', 'wandb_id', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    LOGGER.info(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")


def print_mutation(results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
    evolve_csv = save_dir / 'evolve.csv'
    evolve_yaml = save_dir / 'hyp_evolve.yaml'
    keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', 'val/box_loss',
            'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys())  # [results + hyps]
    keys = tuple(x.strip() for x in keys)
    vals = results + tuple(hyp.values())
    n = len(keys)

    # Download (optional)
    if bucket:
        url = f'gs://{bucket}/evolve.csv'
        if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
            os.system(f'gsutil cp {url} {save_dir}')  # download evolve.csv if larger than local

    # Log to evolve.csv
    s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n')  # add header
    with open(evolve_csv, 'a') as f:
        f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')

    # Save yaml
    with open(evolve_yaml, 'w') as f:
        data = pd.read_csv(evolve_csv)
        data = data.rename(columns=lambda x: x.strip())  # strip keys
        i = np.argmax(fitness(data.values[:, :4]))  # index of best generation
        generations = len(data)
        f.write('# YOLOv5 Hyperparameter Evolution Results\n' + f'# Best generation: {i}\n' +
                f'# Last generation: {generations - 1}\n' + '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) +
                '\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
        yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)

    # Print to screen
    LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix +
                ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' + prefix +
                ', '.join(f'{x:20.5g}' for x in vals) + '\n\n')

    if bucket:
        os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}')  # upload


def apply_classifier(x, model, img, im0):
    # Apply a second stage classifier to YOLO outputs
    # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                # cv2.imwrite('example%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x224x224
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x


def increment_path(path, exist_ok=False, sep='', mkdir=False):
    # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
        matches = [re.search(rf"{path.stem}{sep}(\d+)", d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]  # indices
        n = max(i) + 1 if i else 2  # increment number
        path = Path(f"{path}{sep}{n}{suffix}")  # increment path
    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # make directory
    return path
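
# Example (illustrative): experiment directory naming for runs/.
#   increment_path(Path('runs/exp'))  # -> runs/exp if it does not exist yet
#   increment_path(Path('runs/exp'), mkdir=True)  # -> runs/exp2, runs/exp3, ... on reruns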


# OpenCV Chinese-friendly functions ------------------------------------------------------------------------------------
imshow_ = cv2.imshow  # copy to avoid recursion errors


def imread(path):
    return cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_COLOR)


def imwrite(path, im):
    try:
        cv2.imencode(Path(path).suffix, im)[1].tofile(path)
        return True
    except Exception:
        return False


def imshow(path, im):
    imshow_(path.encode('unicode_escape').decode(), im)


cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow  # redefine

# Variables ------------------------------------------------------------------------------------------------------------
NCOLS = 0 if is_docker() else shutil.get_terminal_size().columns  # terminal window size for tqdm