You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1015 lines
41KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. General utils
  4. """
  5. import contextlib
  6. import glob
  7. import inspect
  8. import logging
  9. import math
  10. import os
  11. import platform
  12. import random
  13. import re
  14. import shutil
  15. import signal
  16. import threading
  17. import time
  18. import urllib
  19. from datetime import datetime
  20. from itertools import repeat
  21. from multiprocessing.pool import ThreadPool
  22. from pathlib import Path
  23. from subprocess import check_output
  24. from typing import Optional
  25. from zipfile import ZipFile
  26. import cv2
  27. import numpy as np
  28. import pandas as pd
  29. import pkg_resources as pkg
  30. import torch
  31. import torchvision
  32. import yaml
  33. from utils.downloads import gsutil_getsize
  34. from utils.metrics import box_iou, fitness
  35. FILE = Path(__file__).resolve()
  36. ROOT = FILE.parents[1] # YOLOv5 root directory
  37. RANK = int(os.getenv('RANK', -1))
  38. # Settings
  39. DATASETS_DIR = ROOT.parent / 'datasets' # YOLOv5 datasets directory
  40. NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
  41. AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
  42. VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global verbose mode
  43. FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
  44. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  45. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  46. pd.options.display.max_columns = 10
  47. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  48. os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
  49. os.environ['OMP_NUM_THREADS'] = str(NUM_THREADS) # OpenMP max threads (PyTorch and SciPy)
  50. def is_kaggle():
  51. # Is environment a Kaggle Notebook?
  52. try:
  53. assert os.environ.get('PWD') == '/kaggle/working'
  54. assert os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
  55. return True
  56. except AssertionError:
  57. return False
  58. def is_writeable(dir, test=False):
  59. # Return True if directory has write permissions, test opening a file with write permissions if test=True
  60. if not test:
  61. return os.access(dir, os.R_OK) # possible issues on Windows
  62. file = Path(dir) / 'tmp.txt'
  63. try:
  64. with open(file, 'w'): # open file with write permissions
  65. pass
  66. file.unlink() # remove file
  67. return True
  68. except OSError:
  69. return False
  70. def set_logging(name=None, verbose=VERBOSE):
  71. # Sets level and returns logger
  72. if is_kaggle():
  73. for h in logging.root.handlers:
  74. logging.root.removeHandler(h) # remove all handlers associated with the root logger object
  75. rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
  76. level = logging.INFO if verbose and rank in {-1, 0} else logging.WARNING
  77. log = logging.getLogger(name)
  78. log.setLevel(level)
  79. handler = logging.StreamHandler()
  80. handler.setFormatter(logging.Formatter("%(message)s"))
  81. handler.setLevel(level)
  82. log.addHandler(handler)
# Configure logging once at import time, then expose the shared project logger.
# NOTE(review): set_logging() is called with name=None, so the level and handler are applied to the logger returned
# by logging.getLogger(None) rather than to "yolov5" itself - confirm this is intended.
set_logging()  # run before defining LOGGER
LOGGER = logging.getLogger("yolov5")  # define globally (used in train.py, val.py, detect.py, etc.)
  85. def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
  86. # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
  87. env = os.getenv(env_var)
  88. if env:
  89. path = Path(env) # use environment variable
  90. else:
  91. cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs
  92. path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir
  93. path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable
  94. path.mkdir(exist_ok=True) # make if required
  95. return path
# Resolved once at import time; user_config_dir() also creates the directory when it does not exist yet
CONFIG_DIR = user_config_dir()  # Ultralytics settings dir
  97. class Profile(contextlib.ContextDecorator):
  98. # Usage: @Profile() decorator or 'with Profile():' context manager
  99. def __enter__(self):
  100. self.start = time.time()
  101. def __exit__(self, type, value, traceback):
  102. print(f'Profile results: {time.time() - self.start:.5f}s')
  103. class Timeout(contextlib.ContextDecorator):
  104. # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
  105. def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
  106. self.seconds = int(seconds)
  107. self.timeout_message = timeout_msg
  108. self.suppress = bool(suppress_timeout_errors)
  109. def _timeout_handler(self, signum, frame):
  110. raise TimeoutError(self.timeout_message)
  111. def __enter__(self):
  112. if platform.system() != 'Windows': # not supported on Windows
  113. signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
  114. signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
  115. def __exit__(self, exc_type, exc_val, exc_tb):
  116. if platform.system() != 'Windows':
  117. signal.alarm(0) # Cancel SIGALRM if it's scheduled
  118. if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
  119. return True
  120. class WorkingDirectory(contextlib.ContextDecorator):
  121. # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
  122. def __init__(self, new_dir):
  123. self.dir = new_dir # new dir
  124. self.cwd = Path.cwd().resolve() # current dir
  125. def __enter__(self):
  126. os.chdir(self.dir)
  127. def __exit__(self, exc_type, exc_val, exc_tb):
  128. os.chdir(self.cwd)
  129. def try_except(func):
  130. # try-except function. Usage: @try_except decorator
  131. def handler(*args, **kwargs):
  132. try:
  133. func(*args, **kwargs)
  134. except Exception as e:
  135. print(e)
  136. return handler
  137. def threaded(func):
  138. # Multi-threads a target function and returns thread. Usage: @threaded decorator
  139. def wrapper(*args, **kwargs):
  140. thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
  141. thread.start()
  142. return thread
  143. return wrapper
  144. def methods(instance):
  145. # Get class/instance methods
  146. return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
  147. def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False):
  148. # Print function arguments (optional args dict)
  149. x = inspect.currentframe().f_back # previous frame
  150. file, _, fcn, _, _ = inspect.getframeinfo(x)
  151. if args is None: # get args automatically
  152. args, _, _, frm = inspect.getargvalues(x)
  153. args = {k: v for k, v in frm.items() if k in args}
  154. s = (f'{Path(file).stem}: ' if show_file else '') + (f'{fcn}: ' if show_fcn else '')
  155. LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
  156. def init_seeds(seed=0):
  157. # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
  158. # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
  159. import torch.backends.cudnn as cudnn
  160. random.seed(seed)
  161. np.random.seed(seed)
  162. torch.manual_seed(seed)
  163. cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
  164. def intersect_dicts(da, db, exclude=()):
  165. # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
  166. return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
  167. def get_latest_run(search_dir='.'):
  168. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  169. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  170. return max(last_list, key=os.path.getctime) if last_list else ''
  171. def is_docker():
  172. # Is environment a Docker container?
  173. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  174. def is_colab():
  175. # Is environment a Google Colab instance?
  176. try:
  177. import google.colab
  178. return True
  179. except ImportError:
  180. return False
  181. def is_pip():
  182. # Is file in a pip package?
  183. return 'site-packages' in Path(__file__).resolve().parts
  184. def is_ascii(s=''):
  185. # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
  186. s = str(s) # convert list, tuple, None, etc. to str
  187. return len(s.encode().decode('ascii', 'ignore')) == len(s)
  188. def is_chinese(s='人工智能'):
  189. # Is string composed of any Chinese characters?
  190. return bool(re.search('[\u4e00-\u9fff]', str(s)))
  191. def emojis(str=''):
  192. # Return platform-dependent emoji-safe version of string
  193. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  194. def file_age(path=__file__):
  195. # Return days since last file update
  196. dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
  197. return dt.days # + dt.seconds / 86400 # fractional days
  198. def file_date(path=__file__):
  199. # Return human-readable file modification date, i.e. '2021-3-26'
  200. t = datetime.fromtimestamp(Path(path).stat().st_mtime)
  201. return f'{t.year}-{t.month}-{t.day}'
  202. def file_size(path):
  203. # Return file/dir size (MB)
  204. mb = 1 << 20 # bytes to MiB (1024 ** 2)
  205. path = Path(path)
  206. if path.is_file():
  207. return path.stat().st_size / mb
  208. elif path.is_dir():
  209. return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
  210. else:
  211. return 0.0
  212. def check_online():
  213. # Check internet connectivity
  214. import socket
  215. try:
  216. socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
  217. return True
  218. except OSError:
  219. return False
  220. def git_describe(path=ROOT): # path must be a directory
  221. # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
  222. try:
  223. assert (Path(path) / '.git').is_dir()
  224. return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
  225. except Exception:
  226. return ''
  227. @try_except
  228. @WorkingDirectory(ROOT)
  229. def check_git_status():
  230. # Recommend 'git pull' if code is out of date
  231. msg = ', for updates see https://github.com/ultralytics/yolov5'
  232. s = colorstr('github: ') # string
  233. assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
  234. assert not is_docker(), s + 'skipping check (Docker image)' + msg
  235. assert check_online(), s + 'skipping check (offline)' + msg
  236. cmd = 'git fetch && git config --get remote.origin.url'
  237. url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
  238. branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  239. n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  240. if n > 0:
  241. s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
  242. else:
  243. s += f'up to date with {url} ✅'
  244. LOGGER.info(emojis(s)) # emoji-safe
  245. def check_python(minimum='3.7.0'):
  246. # Check current python version vs. required python version
  247. check_version(platform.python_version(), minimum, name='Python ', hard=True)
  248. def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
  249. # Check version vs. required version
  250. current, minimum = (pkg.parse_version(x) for x in (current, minimum))
  251. result = (current == minimum) if pinned else (current >= minimum) # bool
  252. s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed' # string
  253. if hard:
  254. assert result, s # assert min requirements met
  255. if verbose and not result:
  256. LOGGER.warning(s)
  257. return result
  258. @try_except
  259. def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
  260. # Check installed dependencies meet requirements (pass *.txt file or list of packages)
  261. prefix = colorstr('red', 'bold', 'requirements:')
  262. check_python() # check python version
  263. if isinstance(requirements, (str, Path)): # requirements.txt file
  264. file = Path(requirements)
  265. assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
  266. with file.open() as f:
  267. requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
  268. else: # list or tuple of packages
  269. requirements = [x for x in requirements if x not in exclude]
  270. n = 0 # number of packages updates
  271. for i, r in enumerate(requirements):
  272. try:
  273. pkg.require(r)
  274. except Exception: # DistributionNotFound or VersionConflict if requirements not met
  275. s = f"{prefix} {r} not found and is required by YOLOv5"
  276. if install and AUTOINSTALL: # check environment variable
  277. LOGGER.info(f"{s}, attempting auto-update...")
  278. try:
  279. assert check_online(), f"'pip install {r}' skipped (offline)"
  280. LOGGER.info(check_output(f"pip install '{r}' {cmds[i] if cmds else ''}", shell=True).decode())
  281. n += 1
  282. except Exception as e:
  283. LOGGER.warning(f'{prefix} {e}')
  284. else:
  285. LOGGER.info(f'{s}. Please install and rerun your command.')
  286. if n: # if packages updated
  287. source = file.resolve() if 'file' in locals() else requirements
  288. s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
  289. f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
  290. LOGGER.info(emojis(s))
  291. def check_img_size(imgsz, s=32, floor=0):
  292. # Verify image size is a multiple of stride s in each dimension
  293. if isinstance(imgsz, int): # integer i.e. img_size=640
  294. new_size = max(make_divisible(imgsz, int(s)), floor)
  295. else: # list i.e. img_size=[640, 480]
  296. imgsz = list(imgsz) # convert to list if tuple
  297. new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
  298. if new_size != imgsz:
  299. LOGGER.warning(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
  300. return new_size
def check_imshow():
    """Check if the environment supports image displays.

    Returns:
        bool: True if a 1x1 test window could be shown and destroyed; False (after logging a warning) if
            is_docker() or is_colab() is True or any of the cv2 calls raised.
    """
    # Check if environment supports image displays
    try:
        # The asserts are used as control flow: a failure lands in the except branch below and returns False
        assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
        assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        LOGGER.warning(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
        return False
  314. def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
  315. # Check file(s) for acceptable suffix
  316. if file and suffix:
  317. if isinstance(suffix, str):
  318. suffix = [suffix]
  319. for f in file if isinstance(file, (list, tuple)) else [file]:
  320. s = Path(f).suffix.lower() # file suffix
  321. if len(s):
  322. assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
  323. def check_yaml(file, suffix=('.yaml', '.yml')):
  324. # Search/download YAML file (if necessary) and return path, checking suffix
  325. return check_file(file, suffix)
  326. def check_file(file, suffix=''):
  327. # Search/download file (if necessary) and return path
  328. check_suffix(file, suffix) # optional
  329. file = str(file) # convert to str()
  330. if Path(file).is_file() or not file: # exists
  331. return file
  332. elif file.startswith(('http:/', 'https:/')): # download
  333. url = file # warning: Pathlib turns :// -> :/
  334. file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
  335. if Path(file).is_file():
  336. LOGGER.info(f'Found {url} locally at {file}') # file already exists
  337. else:
  338. LOGGER.info(f'Downloading {url} to {file}...')
  339. torch.hub.download_url_to_file(url, file)
  340. assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
  341. return file
  342. else: # search
  343. files = []
  344. for d in 'data', 'models', 'utils': # search directories
  345. files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
  346. assert len(files), f'File not found: {file}' # assert file was found
  347. assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
  348. return files[0] # return file
  349. def check_font(font=FONT, progress=False):
  350. # Download font to CONFIG_DIR if necessary
  351. font = Path(font)
  352. file = CONFIG_DIR / font.name
  353. if not font.exists() and not file.exists():
  354. url = "https://ultralytics.com/assets/" + font.name
  355. LOGGER.info(f'Downloading {url} to {file}...')
  356. torch.hub.download_url_to_file(url, str(file), progress=progress)
def check_dataset(data, autodownload=True):
    """Download and/or unzip a dataset if it is not found locally, and return its parsed yaml dictionary.

    Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip

    Args:
        data: A '.zip' location (str or Path), a path to a dataset yaml file, or an already-loaded dict.
        autodownload: Whether the 'download' entry of the yaml may be used when the 'val' paths are missing.

    Returns:
        dict: The dataset dictionary, with 'train'/'val'/'test' prefixed by the resolved dataset path and
            'names' filled in when absent.

    Raises:
        AssertionError: If the 'nc' key is missing.
        Exception: If the 'val' paths are missing and there is no 'download' entry or autodownload is False.
    """
    # Download and/or unzip dataset if not found locally
    # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip

    # Download (optional)
    extract_dir = ''
    if isinstance(data, (str, Path)) and str(data).endswith('.zip'):  # i.e. gs://bucket/dir/coco128.zip
        download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False, threads=1)
        # Use the first *.yaml found under the extracted folder; autodownload is switched off for zip input
        data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
        extract_dir, autodownload = data.parent, False

    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        with open(data, errors='ignore') as f:
            data = yaml.safe_load(f)  # dictionary

    # Resolve paths
    path = Path(extract_dir or data.get('path') or '')  # optional 'path' default to '.'
    if not path.is_absolute():
        path = (ROOT / path).resolve()  # relative paths are resolved against ROOT
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]

    # Parse yaml
    assert 'nc' in data, "Dataset 'nc' key missing."
    if 'names' not in data:
        data['names'] = [f'class{i}' for i in range(data['nc'])]  # assign class names if missing
    train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            LOGGER.info(emojis('\nDataset not found ⚠, missing paths %s' % [str(x) for x in val if not x.exists()]))
            if not s or not autodownload:
                raise Exception(emojis('Dataset not found ❌'))
            t = time.time()
            root = path.parent if 'path' in data else '..'  # unzip directory i.e. '../'
            # The 'download' entry is either a zip URL, a 'bash ...' command, or Python source
            if s.startswith('http') and s.endswith('.zip'):  # URL
                f = Path(s).name  # filename
                LOGGER.info(f'Downloading {s} to {f}...')
                torch.hub.download_url_to_file(s, f)
                Path(root).mkdir(parents=True, exist_ok=True)  # create root
                # NOTE(review): the ZipFile is never explicitly closed - consider a `with` block
                ZipFile(f).extractall(path=root)  # unzip
                Path(f).unlink()  # remove zip
                r = None  # success
            elif s.startswith('bash '):  # bash script
                LOGGER.info(f'Running {s} ...')
                # NOTE(review): runs a shell command taken from the yaml - only use trusted dataset files
                r = os.system(s)
            else:  # python script
                # NOTE(review): executes Python source taken from the yaml - only use trusted dataset files
                r = exec(s, {'yaml': data})  # return None
            dt = f'({round(time.time() - t, 1)}s)'
            # r is the os.system() status for bash and None for the other two branches
            s = f"success ✅ {dt}, saved to {colorstr('bold', root)}" if r in (0, None) else f"failure {dt} ❌"
            LOGGER.info(emojis(f"Dataset download {s}"))
    check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True)  # download fonts
    return data  # dictionary
def check_amp(model):
    """Check PyTorch Automatic Mixed Precision (AMP) functionality.

    Args:
        model: A model exposing .parameters(); its device decides whether the check runs at all.

    Returns:
        bool: True if AMP results are close to FP32 results, False on CPU or when the check fails or raises.
    """
    # Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
    from models.common import AutoShape, DetectMultiBackend

    def amp_allclose(model, im):
        # All close FP32 vs AMP results
        m = AutoShape(model, verbose=False)  # model
        a = m(im).xywhn[0]  # FP32 inference
        m.amp = True
        b = m(im).xywhn[0]  # AMP inference
        return a.shape == b.shape and torch.allclose(a, b, atol=0.1)  # close to 10% absolute tolerance

    prefix = colorstr('AMP: ')
    device = next(model.parameters()).device  # get model device
    if device.type == 'cpu':
        return False  # AMP disabled on CPU
    f = ROOT / 'data' / 'images' / 'bus.jpg'  # image to check
    # Input preference: local image, else the hosted image when online, else an all-ones array
    im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
    try:
        # Passes if either the given model or a 'yolov5n.pt' DetectMultiBackend gives close FP32/AMP results
        assert amp_allclose(model, im) or amp_allclose(DetectMultiBackend('yolov5n.pt', device), im)
        LOGGER.info(emojis(f'{prefix}checks passed ✅'))
        return True
    except Exception:
        help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
        LOGGER.warning(emojis(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}'))
        return False
  432. def url2file(url):
  433. # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
  434. url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
  435. return Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
    """Multi-threaded file download and unzip function, used in data.yaml for autodownload.

    Args:
        url: A single URL/path (str or Path) or an iterable of them.
        dir: Destination directory; created if missing.
        unzip: Extract downloaded '.zip' / '.gz' files.
        delete: Remove the archive after extraction.
        curl: Download with the `curl` command instead of torch.hub.download_url_to_file.
        threads: Number of worker threads; values > 1 use a ThreadPool.
        retry: Number of additional attempts per file.
    """
    # Multi-threaded file download and unzip function, used in data.yaml for autodownload
    def download_one(url, dir):
        # Download 1 file
        success = True
        f = dir / Path(url).name  # filename
        if Path(url).is_file():  # exists in current path
            Path(url).rename(f)  # move to dir
        elif not f.exists():
            LOGGER.info(f'Downloading {url} to {f}...')
            for i in range(retry + 1):
                if curl:
                    s = 'sS' if threads > 1 else ''  # silent
                    r = os.system(f"curl -{s}L '{url}' -o '{f}' --retry 9 -C -")  # curl download
                    success = r == 0
                else:
                    # NOTE(review): an exception raised here is not caught, so it would bypass the retry loop
                    torch.hub.download_url_to_file(url, f, progress=threads == 1)  # torch download
                    success = f.is_file()
                if success:
                    break
                elif i < retry:
                    LOGGER.warning(f'Download failure, retrying {i + 1}/{retry} {url}...')
                else:
                    LOGGER.warning(f'Failed to download {url}...')

        if unzip and success and f.suffix in ('.zip', '.gz'):
            LOGGER.info(f'Unzipping {f}...')
            if f.suffix == '.zip':
                # NOTE(review): the ZipFile is never explicitly closed - consider a `with` block
                ZipFile(f).extractall(path=dir)  # unzip
            elif f.suffix == '.gz':
                os.system(f'tar xfz {f} --directory {f.parent}')  # unzip
            if delete:
                f.unlink()  # remove zip

    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    if threads > 1:
        pool = ThreadPool(threads)
        # NOTE(review): `url` is zipped directly here, so a single str would be iterated per character -
        # presumably callers pass a list when threads > 1; confirm
        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multi-threaded
        pool.close()
        pool.join()
    else:
        for u in [url] if isinstance(url, (str, Path)) else url:
            download_one(u, dir)
  478. def make_divisible(x, divisor):
  479. # Returns nearest x divisible by divisor
  480. if isinstance(divisor, torch.Tensor):
  481. divisor = int(divisor.max()) # to int
  482. return math.ceil(x / divisor) * divisor
  483. def clean_str(s):
  484. # Cleans a string by replacing special characters with underscore _
  485. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  486. def one_cycle(y1=0.0, y2=1.0, steps=100):
  487. # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
  488. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  489. def colorstr(*input):
  490. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  491. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  492. colors = {
  493. 'black': '\033[30m', # basic colors
  494. 'red': '\033[31m',
  495. 'green': '\033[32m',
  496. 'yellow': '\033[33m',
  497. 'blue': '\033[34m',
  498. 'magenta': '\033[35m',
  499. 'cyan': '\033[36m',
  500. 'white': '\033[37m',
  501. 'bright_black': '\033[90m', # bright colors
  502. 'bright_red': '\033[91m',
  503. 'bright_green': '\033[92m',
  504. 'bright_yellow': '\033[93m',
  505. 'bright_blue': '\033[94m',
  506. 'bright_magenta': '\033[95m',
  507. 'bright_cyan': '\033[96m',
  508. 'bright_white': '\033[97m',
  509. 'end': '\033[0m', # misc
  510. 'bold': '\033[1m',
  511. 'underline': '\033[4m'}
  512. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  513. def labels_to_class_weights(labels, nc=80):
  514. # Get class weights (inverse frequency) from training labels
  515. if labels[0] is None: # no labels loaded
  516. return torch.Tensor()
  517. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  518. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  519. weights = np.bincount(classes, minlength=nc) # occurrences per class
  520. # Prepend gridpoint count (for uCE training)
  521. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  522. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  523. weights[weights == 0] = 1 # replace empty bins with 1
  524. weights = 1 / weights # number of targets per class
  525. weights /= weights.sum() # normalize
  526. return torch.from_numpy(weights)
  527. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  528. # Produces image weights based on class_weights and image contents
  529. # Usage: index = random.choices(range(n), weights=image_weights, k=1) # weighted image sample
  530. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  531. return (class_weights.reshape(1, nc) * class_counts).sum(1)
  532. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  533. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  534. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  535. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  536. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  537. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  538. return [
  539. 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  540. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  541. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  542. def xyxy2xywh(x):
  543. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  544. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  545. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  546. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  547. y[:, 2] = x[:, 2] - x[:, 0] # width
  548. y[:, 3] = x[:, 3] - x[:, 1] # height
  549. return y
  550. def xywh2xyxy(x):
  551. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  552. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  553. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  554. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  555. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  556. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  557. return y
  558. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  559. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  560. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  561. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  562. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  563. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  564. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  565. return y
  566. def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
  567. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
  568. if clip:
  569. clip_coords(x, (h - eps, w - eps)) # warning: inplace clip
  570. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  571. y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
  572. y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
  573. y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
  574. y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
  575. return y
  576. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  577. # Convert normalized segments into pixel segments, shape (n,2)
  578. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  579. y[:, 0] = w * x[:, 0] + padw # top left x
  580. y[:, 1] = h * x[:, 1] + padh # top left y
  581. return y
  582. def segment2box(segment, width=640, height=640):
  583. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  584. x, y = segment.T # segment xy
  585. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  586. x, y, = x[inside], y[inside]
  587. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  588. def segments2boxes(segments):
  589. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  590. boxes = []
  591. for s in segments:
  592. x, y = s.T # segment xy
  593. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  594. return xyxy2xywh(np.array(boxes)) # cls, xywh
  595. def resample_segments(segments, n=1000):
  596. # Up-sample an (n,2) segment
  597. for i, s in enumerate(segments):
  598. x = np.linspace(0, len(s) - 1, n)
  599. xp = np.arange(len(s))
  600. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  601. return segments
  602. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  603. # Rescale coords (xyxy) from img1_shape to img0_shape
  604. if ratio_pad is None: # calculate from img0_shape
  605. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  606. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  607. else:
  608. gain = ratio_pad[0][0]
  609. pad = ratio_pad[1]
  610. coords[:, [0, 2]] -= pad[0] # x padding
  611. coords[:, [1, 3]] -= pad[1] # y padding
  612. coords[:, :4] /= gain
  613. clip_coords(coords, img0_shape)
  614. return coords
  615. def clip_coords(boxes, shape):
  616. # Clip bounding xyxy bounding boxes to image shape (height, width)
  617. if isinstance(boxes, torch.Tensor): # faster individually
  618. boxes[:, 0].clamp_(0, shape[1]) # x1
  619. boxes[:, 1].clamp_(0, shape[0]) # y1
  620. boxes[:, 2].clamp_(0, shape[1]) # x2
  621. boxes[:, 3].clamp_(0, shape[0]) # y2
  622. else: # np.array (faster grouped)
  623. boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
  624. boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
  625. def non_max_suppression(prediction,
  626. conf_thres=0.25,
  627. iou_thres=0.45,
  628. classes=None,
  629. agnostic=False,
  630. multi_label=False,
  631. labels=(),
  632. max_det=300):
  633. """Non-Maximum Suppression (NMS) on inference results to reject overlapping bounding boxes
  634. Returns:
  635. list of detections, on (n,6) tensor per image [xyxy, conf, cls]
  636. """
  637. bs = prediction.shape[0] # batch size
  638. nc = prediction.shape[2] - 5 # number of classes
  639. xc = prediction[..., 4] > conf_thres # candidates
  640. # Checks
  641. assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
  642. assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
  643. # Settings
  644. # min_wh = 2 # (pixels) minimum box width and height
  645. max_wh = 7680 # (pixels) maximum box width and height
  646. max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
  647. time_limit = 0.3 + 0.03 * bs # seconds to quit after
  648. redundant = True # require redundant detections
  649. multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
  650. merge = False # use merge-NMS
  651. t = time.time()
  652. output = [torch.zeros((0, 6), device=prediction.device)] * bs
  653. for xi, x in enumerate(prediction): # image index, image inference
  654. # Apply constraints
  655. # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
  656. x = x[xc[xi]] # confidence
  657. # Cat apriori labels if autolabelling
  658. if labels and len(labels[xi]):
  659. lb = labels[xi]
  660. v = torch.zeros((len(lb), nc + 5), device=x.device)
  661. v[:, :4] = lb[:, 1:5] # box
  662. v[:, 4] = 1.0 # conf
  663. v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls
  664. x = torch.cat((x, v), 0)
  665. # If none remain process next image
  666. if not x.shape[0]:
  667. continue
  668. # Compute conf
  669. x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
  670. # Box (center x, center y, width, height) to (x1, y1, x2, y2)
  671. box = xywh2xyxy(x[:, :4])
  672. # Detections matrix nx6 (xyxy, conf, cls)
  673. if multi_label:
  674. i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
  675. x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
  676. else: # best class only
  677. conf, j = x[:, 5:].max(1, keepdim=True)
  678. x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
  679. # Filter by class
  680. if classes is not None:
  681. x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
  682. # Apply finite constraint
  683. # if not torch.isfinite(x).all():
  684. # x = x[torch.isfinite(x).all(1)]
  685. # Check shape
  686. n = x.shape[0] # number of boxes
  687. if not n: # no boxes
  688. continue
  689. elif n > max_nms: # excess boxes
  690. x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
  691. # Batched NMS
  692. c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
  693. boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
  694. i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
  695. if i.shape[0] > max_det: # limit detections
  696. i = i[:max_det]
  697. if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
  698. # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
  699. iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
  700. weights = iou * scores[None] # box weights
  701. x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
  702. if redundant:
  703. i = i[iou.sum(1) > 1] # require redundancy
  704. output[xi] = x[i]
  705. if (time.time() - t) > time_limit:
  706. LOGGER.warning(f'WARNING: NMS time limit {time_limit:.3f}s exceeded')
  707. break # time limit exceeded
  708. return output
  709. def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
  710. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  711. x = torch.load(f, map_location=torch.device('cpu'))
  712. if x.get('ema'):
  713. x['model'] = x['ema'] # replace model with ema
  714. for k in 'optimizer', 'best_fitness', 'wandb_id', 'ema', 'updates': # keys
  715. x[k] = None
  716. x['epoch'] = -1
  717. x['model'].half() # to FP16
  718. for p in x['model'].parameters():
  719. p.requires_grad = False
  720. torch.save(x, s or f)
  721. mb = os.path.getsize(s or f) / 1E6 # filesize
  722. LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
  723. def print_mutation(results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
  724. evolve_csv = save_dir / 'evolve.csv'
  725. evolve_yaml = save_dir / 'hyp_evolve.yaml'
  726. keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', 'val/box_loss',
  727. 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys()) # [results + hyps]
  728. keys = tuple(x.strip() for x in keys)
  729. vals = results + tuple(hyp.values())
  730. n = len(keys)
  731. # Download (optional)
  732. if bucket:
  733. url = f'gs://{bucket}/evolve.csv'
  734. if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
  735. os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local
  736. # Log to evolve.csv
  737. s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header
  738. with open(evolve_csv, 'a') as f:
  739. f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')
  740. # Save yaml
  741. with open(evolve_yaml, 'w') as f:
  742. data = pd.read_csv(evolve_csv)
  743. data = data.rename(columns=lambda x: x.strip()) # strip keys
  744. i = np.argmax(fitness(data.values[:, :4])) #
  745. generations = len(data)
  746. f.write('# YOLOv5 Hyperparameter Evolution Results\n' + f'# Best generation: {i}\n' +
  747. f'# Last generation: {generations - 1}\n' + '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) +
  748. '\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
  749. yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)
  750. # Print to screen
  751. LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix +
  752. ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' + prefix + ', '.join(f'{x:20.5g}'
  753. for x in vals) + '\n\n')
  754. if bucket:
  755. os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload
  756. def apply_classifier(x, model, img, im0):
  757. # Apply a second stage classifier to YOLO outputs
  758. # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
  759. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  760. for i, d in enumerate(x): # per image
  761. if d is not None and len(d):
  762. d = d.clone()
  763. # Reshape and pad cutouts
  764. b = xyxy2xywh(d[:, :4]) # boxes
  765. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  766. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  767. d[:, :4] = xywh2xyxy(b).long()
  768. # Rescale boxes from img_size to im0 size
  769. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  770. # Classes
  771. pred_cls1 = d[:, 5].long()
  772. ims = []
  773. for a in d:
  774. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  775. im = cv2.resize(cutout, (224, 224)) # BGR
  776. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  777. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  778. im /= 255 # 0 - 255 to 0.0 - 1.0
  779. ims.append(im)
  780. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  781. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  782. return x
  783. def increment_path(path, exist_ok=False, sep='', mkdir=False):
  784. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  785. path = Path(path) # os-agnostic
  786. if path.exists() and not exist_ok:
  787. path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
  788. # Method 1
  789. for n in range(2, 9999):
  790. p = f'{path}{sep}{n}{suffix}' # increment path
  791. if not os.path.exists(p): #
  792. break
  793. path = Path(p)
  794. # Method 2 (deprecated)
  795. # dirs = glob.glob(f"{path}{sep}*") # similar paths
  796. # matches = [re.search(rf"{path.stem}{sep}(\d+)", d) for d in dirs]
  797. # i = [int(m.groups()[0]) for m in matches if m] # indices
  798. # n = max(i) + 1 if i else 2 # increment number
  799. # path = Path(f"{path}{sep}{n}{suffix}") # increment path
  800. if mkdir:
  801. path.mkdir(parents=True, exist_ok=True) # make directory
  802. return path
  803. # OpenCV Chinese-friendly functions ------------------------------------------------------------------------------------
  804. imshow_ = cv2.imshow # copy to avoid recursion errors
  805. def imread(path, flags=cv2.IMREAD_COLOR):
  806. return cv2.imdecode(np.fromfile(path, np.uint8), flags)
  807. def imwrite(path, im):
  808. try:
  809. cv2.imencode(Path(path).suffix, im)[1].tofile(path)
  810. return True
  811. except Exception:
  812. return False
  813. def imshow(path, im):
  814. imshow_(path.encode('unicode_escape').decode(), im)
  815. cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow # redefine
  816. # Variables ------------------------------------------------------------------------------------------------------------
  817. NCOLS = 0 if is_docker() else shutil.get_terminal_size().columns # terminal window size for tqdm