You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1018 lines
41KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. General utils
  4. """
  5. import contextlib
  6. import glob
  7. import inspect
  8. import logging
  9. import math
  10. import os
  11. import platform
  12. import random
  13. import re
  14. import shutil
  15. import signal
  16. import threading
  17. import time
  18. import urllib
  19. from datetime import datetime
  20. from itertools import repeat
  21. from multiprocessing.pool import ThreadPool
  22. from pathlib import Path
  23. from subprocess import check_output
  24. from typing import Optional
  25. from zipfile import ZipFile
  26. import cv2
  27. import numpy as np
  28. import pandas as pd
  29. import pkg_resources as pkg
  30. import torch
  31. import torchvision
  32. import yaml
  33. from utils.downloads import gsutil_getsize
  34. from utils.metrics import box_iou, fitness
  35. FILE = Path(__file__).resolve()
  36. ROOT = FILE.parents[1] # YOLOv5 root directory
  37. RANK = int(os.getenv('RANK', -1))
  38. # Settings
  39. DATASETS_DIR = ROOT.parent / 'datasets' # YOLOv5 datasets directory
  40. NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
  41. AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
  42. VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global verbose mode
  43. FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
  44. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  45. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  46. pd.options.display.max_columns = 10
  47. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  48. os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
  49. os.environ['OMP_NUM_THREADS'] = str(NUM_THREADS) # OpenMP max threads (PyTorch and SciPy)
  50. def is_kaggle():
  51. # Is environment a Kaggle Notebook?
  52. try:
  53. assert os.environ.get('PWD') == '/kaggle/working'
  54. assert os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
  55. return True
  56. except AssertionError:
  57. return False
  58. def is_writeable(dir, test=False):
  59. # Return True if directory has write permissions, test opening a file with write permissions if test=True
  60. if not test:
  61. return os.access(dir, os.R_OK) # possible issues on Windows
  62. file = Path(dir) / 'tmp.txt'
  63. try:
  64. with open(file, 'w'): # open file with write permissions
  65. pass
  66. file.unlink() # remove file
  67. return True
  68. except OSError:
  69. return False
  70. def set_logging(name=None, verbose=VERBOSE):
  71. # Sets level and returns logger
  72. if is_kaggle():
  73. for h in logging.root.handlers:
  74. logging.root.removeHandler(h) # remove all handlers associated with the root logger object
  75. rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
  76. level = logging.INFO if verbose and rank in {-1, 0} else logging.WARNING
  77. log = logging.getLogger(name)
  78. log.setLevel(level)
  79. handler = logging.StreamHandler()
  80. handler.setFormatter(logging.Formatter("%(message)s"))
  81. handler.setLevel(level)
  82. log.addHandler(handler)
  83. set_logging() # run before defining LOGGER
  84. LOGGER = logging.getLogger("yolov5") # define globally (used in train.py, val.py, detect.py, etc.)
  85. def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
  86. # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
  87. env = os.getenv(env_var)
  88. if env:
  89. path = Path(env) # use environment variable
  90. else:
  91. cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs
  92. path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir
  93. path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable
  94. path.mkdir(exist_ok=True) # make if required
  95. return path
  96. CONFIG_DIR = user_config_dir() # Ultralytics settings dir
  97. class Profile(contextlib.ContextDecorator):
  98. # Usage: @Profile() decorator or 'with Profile():' context manager
  99. def __enter__(self):
  100. self.start = time.time()
  101. def __exit__(self, type, value, traceback):
  102. print(f'Profile results: {time.time() - self.start:.5f}s')
  103. class Timeout(contextlib.ContextDecorator):
  104. # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
  105. def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
  106. self.seconds = int(seconds)
  107. self.timeout_message = timeout_msg
  108. self.suppress = bool(suppress_timeout_errors)
  109. def _timeout_handler(self, signum, frame):
  110. raise TimeoutError(self.timeout_message)
  111. def __enter__(self):
  112. if platform.system() != 'Windows': # not supported on Windows
  113. signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
  114. signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
  115. def __exit__(self, exc_type, exc_val, exc_tb):
  116. if platform.system() != 'Windows':
  117. signal.alarm(0) # Cancel SIGALRM if it's scheduled
  118. if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
  119. return True
  120. class WorkingDirectory(contextlib.ContextDecorator):
  121. # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
  122. def __init__(self, new_dir):
  123. self.dir = new_dir # new dir
  124. self.cwd = Path.cwd().resolve() # current dir
  125. def __enter__(self):
  126. os.chdir(self.dir)
  127. def __exit__(self, exc_type, exc_val, exc_tb):
  128. os.chdir(self.cwd)
  129. def try_except(func):
  130. # try-except function. Usage: @try_except decorator
  131. def handler(*args, **kwargs):
  132. try:
  133. func(*args, **kwargs)
  134. except Exception as e:
  135. print(e)
  136. return handler
  137. def threaded(func):
  138. # Multi-threads a target function and returns thread. Usage: @threaded decorator
  139. def wrapper(*args, **kwargs):
  140. thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
  141. thread.start()
  142. return thread
  143. return wrapper
  144. def methods(instance):
  145. # Get class/instance methods
  146. return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
  147. def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False):
  148. # Print function arguments (optional args dict)
  149. x = inspect.currentframe().f_back # previous frame
  150. file, _, fcn, _, _ = inspect.getframeinfo(x)
  151. if args is None: # get args automatically
  152. args, _, _, frm = inspect.getargvalues(x)
  153. args = {k: v for k, v in frm.items() if k in args}
  154. s = (f'{Path(file).stem}: ' if show_file else '') + (f'{fcn}: ' if show_fcn else '')
  155. LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
  156. def init_seeds(seed=0):
  157. # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
  158. # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
  159. import torch.backends.cudnn as cudnn
  160. random.seed(seed)
  161. np.random.seed(seed)
  162. torch.manual_seed(seed)
  163. cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
  164. def intersect_dicts(da, db, exclude=()):
  165. # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
  166. return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
  167. def get_latest_run(search_dir='.'):
  168. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  169. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  170. return max(last_list, key=os.path.getctime) if last_list else ''
  171. def is_docker():
  172. # Is environment a Docker container?
  173. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  174. def is_colab():
  175. # Is environment a Google Colab instance?
  176. try:
  177. import google.colab
  178. return True
  179. except ImportError:
  180. return False
  181. def is_pip():
  182. # Is file in a pip package?
  183. return 'site-packages' in Path(__file__).resolve().parts
  184. def is_ascii(s=''):
  185. # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
  186. s = str(s) # convert list, tuple, None, etc. to str
  187. return len(s.encode().decode('ascii', 'ignore')) == len(s)
  188. def is_chinese(s='人工智能'):
  189. # Is string composed of any Chinese characters?
  190. return bool(re.search('[\u4e00-\u9fff]', str(s)))
  191. def emojis(str=''):
  192. # Return platform-dependent emoji-safe version of string
  193. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  194. def file_age(path=__file__):
  195. # Return days since last file update
  196. dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
  197. return dt.days # + dt.seconds / 86400 # fractional days
  198. def file_date(path=__file__):
  199. # Return human-readable file modification date, i.e. '2021-3-26'
  200. t = datetime.fromtimestamp(Path(path).stat().st_mtime)
  201. return f'{t.year}-{t.month}-{t.day}'
  202. def file_size(path):
  203. # Return file/dir size (MB)
  204. mb = 1 << 20 # bytes to MiB (1024 ** 2)
  205. path = Path(path)
  206. if path.is_file():
  207. return path.stat().st_size / mb
  208. elif path.is_dir():
  209. return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
  210. else:
  211. return 0.0
  212. def check_online():
  213. # Check internet connectivity
  214. import socket
  215. try:
  216. socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
  217. return True
  218. except OSError:
  219. return False
  220. def git_describe(path=ROOT): # path must be a directory
  221. # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
  222. try:
  223. assert (Path(path) / '.git').is_dir()
  224. return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
  225. except Exception:
  226. return ''
  227. @try_except
  228. @WorkingDirectory(ROOT)
  229. def check_git_status():
  230. # Recommend 'git pull' if code is out of date
  231. msg = ', for updates see https://github.com/ultralytics/yolov5'
  232. s = colorstr('github: ') # string
  233. assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
  234. assert not is_docker(), s + 'skipping check (Docker image)' + msg
  235. assert check_online(), s + 'skipping check (offline)' + msg
  236. cmd = 'git fetch && git config --get remote.origin.url'
  237. url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
  238. branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  239. n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  240. if n > 0:
  241. s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
  242. else:
  243. s += f'up to date with {url} ✅'
  244. LOGGER.info(emojis(s)) # emoji-safe
  245. def check_python(minimum='3.7.0'):
  246. # Check current python version vs. required python version
  247. check_version(platform.python_version(), minimum, name='Python ', hard=True)
  248. def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
  249. # Check version vs. required version
  250. current, minimum = (pkg.parse_version(x) for x in (current, minimum))
  251. result = (current == minimum) if pinned else (current >= minimum) # bool
  252. s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed' # string
  253. if hard:
  254. assert result, s # assert min requirements met
  255. if verbose and not result:
  256. LOGGER.warning(s)
  257. return result
  258. @try_except
  259. def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
  260. # Check installed dependencies meet requirements (pass *.txt file or list of packages)
  261. prefix = colorstr('red', 'bold', 'requirements:')
  262. check_python() # check python version
  263. if isinstance(requirements, (str, Path)): # requirements.txt file
  264. file = Path(requirements)
  265. assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
  266. with file.open() as f:
  267. requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
  268. else: # list or tuple of packages
  269. requirements = [x for x in requirements if x not in exclude]
  270. n = 0 # number of packages updates
  271. for i, r in enumerate(requirements):
  272. try:
  273. pkg.require(r)
  274. except Exception: # DistributionNotFound or VersionConflict if requirements not met
  275. s = f"{prefix} {r} not found and is required by YOLOv5"
  276. if install and AUTOINSTALL: # check environment variable
  277. LOGGER.info(f"{s}, attempting auto-update...")
  278. try:
  279. assert check_online(), f"'pip install {r}' skipped (offline)"
  280. LOGGER.info(check_output(f'pip install "{r}" {cmds[i] if cmds else ""}', shell=True).decode())
  281. n += 1
  282. except Exception as e:
  283. LOGGER.warning(f'{prefix} {e}')
  284. else:
  285. LOGGER.info(f'{s}. Please install and rerun your command.')
  286. if n: # if packages updated
  287. source = file.resolve() if 'file' in locals() else requirements
  288. s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
  289. f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
  290. LOGGER.info(emojis(s))
  291. def check_img_size(imgsz, s=32, floor=0):
  292. # Verify image size is a multiple of stride s in each dimension
  293. if isinstance(imgsz, int): # integer i.e. img_size=640
  294. new_size = max(make_divisible(imgsz, int(s)), floor)
  295. else: # list i.e. img_size=[640, 480]
  296. imgsz = list(imgsz) # convert to list if tuple
  297. new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
  298. if new_size != imgsz:
  299. LOGGER.warning(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
  300. return new_size
  301. def check_imshow():
  302. # Check if environment supports image displays
  303. try:
  304. assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
  305. assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
  306. cv2.imshow('test', np.zeros((1, 1, 3)))
  307. cv2.waitKey(1)
  308. cv2.destroyAllWindows()
  309. cv2.waitKey(1)
  310. return True
  311. except Exception as e:
  312. LOGGER.warning(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
  313. return False
  314. def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
  315. # Check file(s) for acceptable suffix
  316. if file and suffix:
  317. if isinstance(suffix, str):
  318. suffix = [suffix]
  319. for f in file if isinstance(file, (list, tuple)) else [file]:
  320. s = Path(f).suffix.lower() # file suffix
  321. if len(s):
  322. assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
  323. def check_yaml(file, suffix=('.yaml', '.yml')):
  324. # Search/download YAML file (if necessary) and return path, checking suffix
  325. return check_file(file, suffix)
  326. def check_file(file, suffix=''):
  327. # Search/download file (if necessary) and return path
  328. check_suffix(file, suffix) # optional
  329. file = str(file) # convert to str()
  330. if Path(file).is_file() or not file: # exists
  331. return file
  332. elif file.startswith(('http:/', 'https:/')): # download
  333. url = file # warning: Pathlib turns :// -> :/
  334. file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
  335. if Path(file).is_file():
  336. LOGGER.info(f'Found {url} locally at {file}') # file already exists
  337. else:
  338. LOGGER.info(f'Downloading {url} to {file}...')
  339. torch.hub.download_url_to_file(url, file)
  340. assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
  341. return file
  342. else: # search
  343. files = []
  344. for d in 'data', 'models', 'utils': # search directories
  345. files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
  346. assert len(files), f'File not found: {file}' # assert file was found
  347. assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
  348. return files[0] # return file
  349. def check_font(font=FONT, progress=False):
  350. # Download font to CONFIG_DIR if necessary
  351. font = Path(font)
  352. file = CONFIG_DIR / font.name
  353. if not font.exists() and not file.exists():
  354. url = "https://ultralytics.com/assets/" + font.name
  355. LOGGER.info(f'Downloading {url} to {file}...')
  356. torch.hub.download_url_to_file(url, str(file), progress=progress)
  357. def check_dataset(data, autodownload=True):
  358. # Download, check and/or unzip dataset if not found locally
  359. # Download (optional)
  360. extract_dir = ''
  361. if isinstance(data, (str, Path)) and str(data).endswith('.zip'): # i.e. gs://bucket/dir/coco128.zip
  362. download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False, threads=1)
  363. data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
  364. extract_dir, autodownload = data.parent, False
  365. # Read yaml (optional)
  366. if isinstance(data, (str, Path)):
  367. with open(data, errors='ignore') as f:
  368. data = yaml.safe_load(f) # dictionary
  369. # Checks
  370. for k in 'train', 'val', 'nc':
  371. assert k in data, emojis(f"data.yaml '{k}:' field missing ❌")
  372. if 'names' not in data:
  373. LOGGER.warning(emojis("data.yaml 'names:' field missing ⚠, assigning default names 'class0', 'class1', etc."))
  374. data['names'] = [f'class{i}' for i in range(data['nc'])] # default names
  375. # Resolve paths
  376. path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.'
  377. if not path.is_absolute():
  378. path = (ROOT / path).resolve()
  379. for k in 'train', 'val', 'test':
  380. if data.get(k): # prepend path
  381. data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]
  382. # Parse yaml
  383. train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
  384. if val:
  385. val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
  386. if not all(x.exists() for x in val):
  387. LOGGER.info(emojis('\nDataset not found ⚠, missing paths %s' % [str(x) for x in val if not x.exists()]))
  388. if not s or not autodownload:
  389. raise Exception(emojis('Dataset not found ❌'))
  390. t = time.time()
  391. root = path.parent if 'path' in data else '..' # unzip directory i.e. '../'
  392. if s.startswith('http') and s.endswith('.zip'): # URL
  393. f = Path(s).name # filename
  394. LOGGER.info(f'Downloading {s} to {f}...')
  395. torch.hub.download_url_to_file(s, f)
  396. Path(root).mkdir(parents=True, exist_ok=True) # create root
  397. ZipFile(f).extractall(path=root) # unzip
  398. Path(f).unlink() # remove zip
  399. r = None # success
  400. elif s.startswith('bash '): # bash script
  401. LOGGER.info(f'Running {s} ...')
  402. r = os.system(s)
  403. else: # python script
  404. r = exec(s, {'yaml': data}) # return None
  405. dt = f'({round(time.time() - t, 1)}s)'
  406. s = f"success ✅ {dt}, saved to {colorstr('bold', root)}" if r in (0, None) else f"failure {dt} ❌"
  407. LOGGER.info(emojis(f"Dataset download {s}"))
  408. check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts
  409. return data # dictionary
  410. def check_amp(model):
  411. # Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
  412. from models.common import AutoShape, DetectMultiBackend
  413. def amp_allclose(model, im):
  414. # All close FP32 vs AMP results
  415. m = AutoShape(model, verbose=False) # model
  416. a = m(im).xywhn[0] # FP32 inference
  417. m.amp = True
  418. b = m(im).xywhn[0] # AMP inference
  419. return a.shape == b.shape and torch.allclose(a, b, atol=0.1) # close to 10% absolute tolerance
  420. prefix = colorstr('AMP: ')
  421. device = next(model.parameters()).device # get model device
  422. if device.type == 'cpu':
  423. return False # AMP disabled on CPU
  424. f = ROOT / 'data' / 'images' / 'bus.jpg' # image to check
  425. im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
  426. try:
  427. assert amp_allclose(model, im) or amp_allclose(DetectMultiBackend('yolov5n.pt', device), im)
  428. LOGGER.info(emojis(f'{prefix}checks passed ✅'))
  429. return True
  430. except Exception:
  431. help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
  432. LOGGER.warning(emojis(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}'))
  433. return False
  434. def url2file(url):
  435. # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
  436. url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
  437. return Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
  438. def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
  439. # Multi-threaded file download and unzip function, used in data.yaml for autodownload
  440. def download_one(url, dir):
  441. # Download 1 file
  442. success = True
  443. f = dir / Path(url).name # filename
  444. if Path(url).is_file(): # exists in current path
  445. Path(url).rename(f) # move to dir
  446. elif not f.exists():
  447. LOGGER.info(f'Downloading {url} to {f}...')
  448. for i in range(retry + 1):
  449. if curl:
  450. s = 'sS' if threads > 1 else '' # silent
  451. r = os.system(f'curl -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue
  452. success = r == 0
  453. else:
  454. torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
  455. success = f.is_file()
  456. if success:
  457. break
  458. elif i < retry:
  459. LOGGER.warning(f'Download failure, retrying {i + 1}/{retry} {url}...')
  460. else:
  461. LOGGER.warning(f'Failed to download {url}...')
  462. if unzip and success and f.suffix in ('.zip', '.gz'):
  463. LOGGER.info(f'Unzipping {f}...')
  464. if f.suffix == '.zip':
  465. ZipFile(f).extractall(path=dir) # unzip
  466. elif f.suffix == '.gz':
  467. os.system(f'tar xfz {f} --directory {f.parent}') # unzip
  468. if delete:
  469. f.unlink() # remove zip
  470. dir = Path(dir)
  471. dir.mkdir(parents=True, exist_ok=True) # make directory
  472. if threads > 1:
  473. pool = ThreadPool(threads)
  474. pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
  475. pool.close()
  476. pool.join()
  477. else:
  478. for u in [url] if isinstance(url, (str, Path)) else url:
  479. download_one(u, dir)
  480. def make_divisible(x, divisor):
  481. # Returns nearest x divisible by divisor
  482. if isinstance(divisor, torch.Tensor):
  483. divisor = int(divisor.max()) # to int
  484. return math.ceil(x / divisor) * divisor
  485. def clean_str(s):
  486. # Cleans a string by replacing special characters with underscore _
  487. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  488. def one_cycle(y1=0.0, y2=1.0, steps=100):
  489. # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
  490. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  491. def colorstr(*input):
  492. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  493. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  494. colors = {
  495. 'black': '\033[30m', # basic colors
  496. 'red': '\033[31m',
  497. 'green': '\033[32m',
  498. 'yellow': '\033[33m',
  499. 'blue': '\033[34m',
  500. 'magenta': '\033[35m',
  501. 'cyan': '\033[36m',
  502. 'white': '\033[37m',
  503. 'bright_black': '\033[90m', # bright colors
  504. 'bright_red': '\033[91m',
  505. 'bright_green': '\033[92m',
  506. 'bright_yellow': '\033[93m',
  507. 'bright_blue': '\033[94m',
  508. 'bright_magenta': '\033[95m',
  509. 'bright_cyan': '\033[96m',
  510. 'bright_white': '\033[97m',
  511. 'end': '\033[0m', # misc
  512. 'bold': '\033[1m',
  513. 'underline': '\033[4m'}
  514. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  515. def labels_to_class_weights(labels, nc=80):
  516. # Get class weights (inverse frequency) from training labels
  517. if labels[0] is None: # no labels loaded
  518. return torch.Tensor()
  519. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  520. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  521. weights = np.bincount(classes, minlength=nc) # occurrences per class
  522. # Prepend gridpoint count (for uCE training)
  523. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  524. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  525. weights[weights == 0] = 1 # replace empty bins with 1
  526. weights = 1 / weights # number of targets per class
  527. weights /= weights.sum() # normalize
  528. return torch.from_numpy(weights)
  529. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  530. # Produces image weights based on class_weights and image contents
  531. # Usage: index = random.choices(range(n), weights=image_weights, k=1) # weighted image sample
  532. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  533. return (class_weights.reshape(1, nc) * class_counts).sum(1)
  534. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  535. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  536. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  537. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  538. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  539. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  540. return [
  541. 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  542. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  543. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  544. def xyxy2xywh(x):
  545. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  546. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  547. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  548. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  549. y[:, 2] = x[:, 2] - x[:, 0] # width
  550. y[:, 3] = x[:, 3] - x[:, 1] # height
  551. return y
  552. def xywh2xyxy(x):
  553. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  554. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  555. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  556. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  557. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  558. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  559. return y
  560. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  561. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  562. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  563. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  564. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  565. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  566. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  567. return y
  568. def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
  569. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
  570. if clip:
  571. clip_coords(x, (h - eps, w - eps)) # warning: inplace clip
  572. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  573. y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
  574. y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
  575. y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
  576. y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
  577. return y
  578. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  579. # Convert normalized segments into pixel segments, shape (n,2)
  580. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  581. y[:, 0] = w * x[:, 0] + padw # top left x
  582. y[:, 1] = h * x[:, 1] + padh # top left y
  583. return y
  584. def segment2box(segment, width=640, height=640):
  585. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  586. x, y = segment.T # segment xy
  587. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  588. x, y, = x[inside], y[inside]
  589. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  590. def segments2boxes(segments):
  591. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  592. boxes = []
  593. for s in segments:
  594. x, y = s.T # segment xy
  595. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  596. return xyxy2xywh(np.array(boxes)) # cls, xywh
  597. def resample_segments(segments, n=1000):
  598. # Up-sample an (n,2) segment
  599. for i, s in enumerate(segments):
  600. x = np.linspace(0, len(s) - 1, n)
  601. xp = np.arange(len(s))
  602. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  603. return segments
  604. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  605. # Rescale coords (xyxy) from img1_shape to img0_shape
  606. if ratio_pad is None: # calculate from img0_shape
  607. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  608. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  609. else:
  610. gain = ratio_pad[0][0]
  611. pad = ratio_pad[1]
  612. coords[:, [0, 2]] -= pad[0] # x padding
  613. coords[:, [1, 3]] -= pad[1] # y padding
  614. coords[:, :4] /= gain
  615. clip_coords(coords, img0_shape)
  616. return coords
  617. def clip_coords(boxes, shape):
  618. # Clip bounding xyxy bounding boxes to image shape (height, width)
  619. if isinstance(boxes, torch.Tensor): # faster individually
  620. boxes[:, 0].clamp_(0, shape[1]) # x1
  621. boxes[:, 1].clamp_(0, shape[0]) # y1
  622. boxes[:, 2].clamp_(0, shape[1]) # x2
  623. boxes[:, 3].clamp_(0, shape[0]) # y2
  624. else: # np.array (faster grouped)
  625. boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
  626. boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
def non_max_suppression(prediction,
                        conf_thres=0.25,
                        iou_thres=0.45,
                        classes=None,
                        agnostic=False,
                        multi_label=False,
                        labels=(),
                        max_det=300):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping bounding boxes

    Args:
        prediction: batched raw detections indexed as (batch, candidates, 5 + nc). Each row is read as
            [x_center, y_center, width, height, objectness, class scores...] (per the slicing below).
        conf_thres: minimum score in [0, 1]. It is applied first to objectness, then to objectness * class score.
        iou_thres: IoU threshold in [0, 1] passed to torchvision.ops.nms.
        classes: optional sequence of class indices to keep; detections of other classes are dropped before NMS.
        agnostic: if True, boxes of different classes can suppress each other (no per-class offset is applied).
        multi_label: if True (only honoured when nc > 1), emit one row per (box, class) pair whose score exceeds
            conf_thres; otherwise only the best class per box is kept.
        labels: optional per-image a-priori labels for autolabelling. Rows of labels[xi] are read as
            [cls, x_center, y_center, width, height] and appended as candidates with confidence 1.0.
        max_det: maximum number of detections kept per image.

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
         If the time limit is exceeded, a warning is logged and the images not yet processed keep an empty
         (0, 6) tensor.
    """

    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 0.3 + 0.03 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    # NOTE: `[...] * bs` puts the same empty tensor object in every slot. Slots are only ever reassigned
    # (output[xi] = ...), never modified in place, so the sharing is harmless here.
    output = [torch.zeros((0, 6), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        # Boolean-mask indexing copies, so the in-place edits below do not modify `prediction`.
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            # Build rows in the same layout as `x`: box, objectness 1.0, one-hot class score 1.0
            v = torch.zeros((len(lb), nc + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            # i = candidate row, j = class index for every (row, class) score above the threshold
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes (pre-cap count; it is also what the merge condition below sees)
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        # Shift each box by class_index * max_wh. As long as coordinates stay below max_wh, boxes of different
        # classes then never overlap, so one nms() call acts per class. With `agnostic` the offset is 0.
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        # torchvision documents the kept indices as sorted by decreasing score, so this keeps the top max_det
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                # keep only boxes that overlap at least one other box besides themselves
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING: NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output
  711. def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
  712. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  713. x = torch.load(f, map_location=torch.device('cpu'))
  714. if x.get('ema'):
  715. x['model'] = x['ema'] # replace model with ema
  716. for k in 'optimizer', 'best_fitness', 'wandb_id', 'ema', 'updates': # keys
  717. x[k] = None
  718. x['epoch'] = -1
  719. x['model'].half() # to FP16
  720. for p in x['model'].parameters():
  721. p.requires_grad = False
  722. torch.save(x, s or f)
  723. mb = os.path.getsize(s or f) / 1E6 # filesize
  724. LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
  725. def print_mutation(results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
  726. evolve_csv = save_dir / 'evolve.csv'
  727. evolve_yaml = save_dir / 'hyp_evolve.yaml'
  728. keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', 'val/box_loss',
  729. 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys()) # [results + hyps]
  730. keys = tuple(x.strip() for x in keys)
  731. vals = results + tuple(hyp.values())
  732. n = len(keys)
  733. # Download (optional)
  734. if bucket:
  735. url = f'gs://{bucket}/evolve.csv'
  736. if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
  737. os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local
  738. # Log to evolve.csv
  739. s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header
  740. with open(evolve_csv, 'a') as f:
  741. f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')
  742. # Save yaml
  743. with open(evolve_yaml, 'w') as f:
  744. data = pd.read_csv(evolve_csv)
  745. data = data.rename(columns=lambda x: x.strip()) # strip keys
  746. i = np.argmax(fitness(data.values[:, :4])) #
  747. generations = len(data)
  748. f.write('# YOLOv5 Hyperparameter Evolution Results\n' + f'# Best generation: {i}\n' +
  749. f'# Last generation: {generations - 1}\n' + '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) +
  750. '\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
  751. yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)
  752. # Print to screen
  753. LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix +
  754. ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' + prefix + ', '.join(f'{x:20.5g}'
  755. for x in vals) + '\n\n')
  756. if bucket:
  757. os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload
  758. def apply_classifier(x, model, img, im0):
  759. # Apply a second stage classifier to YOLO outputs
  760. # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
  761. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  762. for i, d in enumerate(x): # per image
  763. if d is not None and len(d):
  764. d = d.clone()
  765. # Reshape and pad cutouts
  766. b = xyxy2xywh(d[:, :4]) # boxes
  767. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  768. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  769. d[:, :4] = xywh2xyxy(b).long()
  770. # Rescale boxes from img_size to im0 size
  771. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  772. # Classes
  773. pred_cls1 = d[:, 5].long()
  774. ims = []
  775. for a in d:
  776. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  777. im = cv2.resize(cutout, (224, 224)) # BGR
  778. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  779. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  780. im /= 255 # 0 - 255 to 0.0 - 1.0
  781. ims.append(im)
  782. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  783. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  784. return x
  785. def increment_path(path, exist_ok=False, sep='', mkdir=False):
  786. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  787. path = Path(path) # os-agnostic
  788. if path.exists() and not exist_ok:
  789. path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
  790. # Method 1
  791. for n in range(2, 9999):
  792. p = f'{path}{sep}{n}{suffix}' # increment path
  793. if not os.path.exists(p): #
  794. break
  795. path = Path(p)
  796. # Method 2 (deprecated)
  797. # dirs = glob.glob(f"{path}{sep}*") # similar paths
  798. # matches = [re.search(rf"{path.stem}{sep}(\d+)", d) for d in dirs]
  799. # i = [int(m.groups()[0]) for m in matches if m] # indices
  800. # n = max(i) + 1 if i else 2 # increment number
  801. # path = Path(f"{path}{sep}{n}{suffix}") # increment path
  802. if mkdir:
  803. path.mkdir(parents=True, exist_ok=True) # make directory
  804. return path
  805. # OpenCV Chinese-friendly functions ------------------------------------------------------------------------------------
  806. imshow_ = cv2.imshow # copy to avoid recursion errors
  807. def imread(path, flags=cv2.IMREAD_COLOR):
  808. return cv2.imdecode(np.fromfile(path, np.uint8), flags)
  809. def imwrite(path, im):
  810. try:
  811. cv2.imencode(Path(path).suffix, im)[1].tofile(path)
  812. return True
  813. except Exception:
  814. return False
  815. def imshow(path, im):
  816. imshow_(path.encode('unicode_escape').decode(), im)
  817. cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow # redefine
  818. # Variables ------------------------------------------------------------------------------------------------------------
  819. NCOLS = 0 if is_docker() else shutil.get_terminal_size().columns # terminal window size for tqdm