Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

1015 lignes
41KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. General utils
  4. """
  5. import contextlib
  6. import glob
  7. import inspect
  8. import logging
  9. import math
  10. import os
  11. import platform
  12. import random
  13. import re
  14. import shutil
  15. import signal
  16. import threading
  17. import time
  18. import urllib
  19. from datetime import datetime
  20. from itertools import repeat
  21. from multiprocessing.pool import ThreadPool
  22. from pathlib import Path
  23. from subprocess import check_output
  24. from typing import Optional
  25. from zipfile import ZipFile
  26. import cv2
  27. import numpy as np
  28. import pandas as pd
  29. import pkg_resources as pkg
  30. import torch
  31. import torchvision
  32. import yaml
  33. from utils.downloads import gsutil_getsize
  34. from utils.metrics import box_iou, fitness
  35. FILE = Path(__file__).resolve()
  36. ROOT = FILE.parents[1] # YOLOv5 root directory
  37. RANK = int(os.getenv('RANK', -1))
  38. # Settings
  39. DATASETS_DIR = ROOT.parent / 'datasets' # YOLOv5 datasets directory
  40. NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
  41. AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
  42. VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global verbose mode
  43. FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
  44. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  45. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  46. pd.options.display.max_columns = 10
  47. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  48. os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
  49. os.environ['OMP_NUM_THREADS'] = str(NUM_THREADS) # OpenMP max threads (PyTorch and SciPy)
  50. def is_kaggle():
  51. # Is environment a Kaggle Notebook?
  52. try:
  53. assert os.environ.get('PWD') == '/kaggle/working'
  54. assert os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
  55. return True
  56. except AssertionError:
  57. return False
  58. def is_writeable(dir, test=False):
  59. # Return True if directory has write permissions, test opening a file with write permissions if test=True
  60. if not test:
  61. return os.access(dir, os.R_OK) # possible issues on Windows
  62. file = Path(dir) / 'tmp.txt'
  63. try:
  64. with open(file, 'w'): # open file with write permissions
  65. pass
  66. file.unlink() # remove file
  67. return True
  68. except OSError:
  69. return False
  70. def set_logging(name=None, verbose=VERBOSE):
  71. # Sets level and returns logger
  72. if is_kaggle():
  73. for h in logging.root.handlers:
  74. logging.root.removeHandler(h) # remove all handlers associated with the root logger object
  75. rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
  76. level = logging.INFO if verbose and rank in {-1, 0} else logging.WARNING
  77. log = logging.getLogger(name)
  78. log.setLevel(level)
  79. handler = logging.StreamHandler()
  80. handler.setFormatter(logging.Formatter("%(message)s"))
  81. handler.setLevel(level)
  82. log.addHandler(handler)
  83. set_logging() # run before defining LOGGER
  84. LOGGER = logging.getLogger("yolov5") # define globally (used in train.py, val.py, detect.py, etc.)
  85. def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
  86. # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
  87. env = os.getenv(env_var)
  88. if env:
  89. path = Path(env) # use environment variable
  90. else:
  91. cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs
  92. path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir
  93. path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable
  94. path.mkdir(exist_ok=True) # make if required
  95. return path
  96. CONFIG_DIR = user_config_dir() # Ultralytics settings dir
  97. class Profile(contextlib.ContextDecorator):
  98. # Usage: @Profile() decorator or 'with Profile():' context manager
  99. def __enter__(self):
  100. self.start = time.time()
  101. def __exit__(self, type, value, traceback):
  102. print(f'Profile results: {time.time() - self.start:.5f}s')
  103. class Timeout(contextlib.ContextDecorator):
  104. # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
  105. def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
  106. self.seconds = int(seconds)
  107. self.timeout_message = timeout_msg
  108. self.suppress = bool(suppress_timeout_errors)
  109. def _timeout_handler(self, signum, frame):
  110. raise TimeoutError(self.timeout_message)
  111. def __enter__(self):
  112. if platform.system() != 'Windows': # not supported on Windows
  113. signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
  114. signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
  115. def __exit__(self, exc_type, exc_val, exc_tb):
  116. if platform.system() != 'Windows':
  117. signal.alarm(0) # Cancel SIGALRM if it's scheduled
  118. if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
  119. return True
  120. class WorkingDirectory(contextlib.ContextDecorator):
  121. # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
  122. def __init__(self, new_dir):
  123. self.dir = new_dir # new dir
  124. self.cwd = Path.cwd().resolve() # current dir
  125. def __enter__(self):
  126. os.chdir(self.dir)
  127. def __exit__(self, exc_type, exc_val, exc_tb):
  128. os.chdir(self.cwd)
  129. def try_except(func):
  130. # try-except function. Usage: @try_except decorator
  131. def handler(*args, **kwargs):
  132. try:
  133. func(*args, **kwargs)
  134. except Exception as e:
  135. print(e)
  136. return handler
  137. def threaded(func):
  138. # Multi-threads a target function and returns thread. Usage: @threaded decorator
  139. def wrapper(*args, **kwargs):
  140. thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
  141. thread.start()
  142. return thread
  143. return wrapper
  144. def methods(instance):
  145. # Get class/instance methods
  146. return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
  147. def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False):
  148. # Print function arguments (optional args dict)
  149. x = inspect.currentframe().f_back # previous frame
  150. file, _, fcn, _, _ = inspect.getframeinfo(x)
  151. if args is None: # get args automatically
  152. args, _, _, frm = inspect.getargvalues(x)
  153. args = {k: v for k, v in frm.items() if k in args}
  154. s = (f'{Path(file).stem}: ' if show_file else '') + (f'{fcn}: ' if show_fcn else '')
  155. LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
  156. def init_seeds(seed=0):
  157. # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
  158. # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
  159. import torch.backends.cudnn as cudnn
  160. random.seed(seed)
  161. np.random.seed(seed)
  162. torch.manual_seed(seed)
  163. cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
  164. def intersect_dicts(da, db, exclude=()):
  165. # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
  166. return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
  167. def get_latest_run(search_dir='.'):
  168. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  169. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  170. return max(last_list, key=os.path.getctime) if last_list else ''
  171. def is_docker():
  172. # Is environment a Docker container?
  173. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  174. def is_colab():
  175. # Is environment a Google Colab instance?
  176. try:
  177. import google.colab
  178. return True
  179. except ImportError:
  180. return False
  181. def is_pip():
  182. # Is file in a pip package?
  183. return 'site-packages' in Path(__file__).resolve().parts
  184. def is_ascii(s=''):
  185. # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
  186. s = str(s) # convert list, tuple, None, etc. to str
  187. return len(s.encode().decode('ascii', 'ignore')) == len(s)
  188. def is_chinese(s='人工智能'):
  189. # Is string composed of any Chinese characters?
  190. return bool(re.search('[\u4e00-\u9fff]', str(s)))
  191. def emojis(str=''):
  192. # Return platform-dependent emoji-safe version of string
  193. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  194. def file_age(path=__file__):
  195. # Return days since last file update
  196. dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
  197. return dt.days # + dt.seconds / 86400 # fractional days
  198. def file_date(path=__file__):
  199. # Return human-readable file modification date, i.e. '2021-3-26'
  200. t = datetime.fromtimestamp(Path(path).stat().st_mtime)
  201. return f'{t.year}-{t.month}-{t.day}'
  202. def file_size(path):
  203. # Return file/dir size (MB)
  204. mb = 1 << 20 # bytes to MiB (1024 ** 2)
  205. path = Path(path)
  206. if path.is_file():
  207. return path.stat().st_size / mb
  208. elif path.is_dir():
  209. return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
  210. else:
  211. return 0.0
  212. def check_online():
  213. # Check internet connectivity
  214. import socket
  215. try:
  216. socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
  217. return True
  218. except OSError:
  219. return False
  220. def git_describe(path=ROOT): # path must be a directory
  221. # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
  222. try:
  223. assert (Path(path) / '.git').is_dir()
  224. return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
  225. except Exception:
  226. return ''
  227. @try_except
  228. @WorkingDirectory(ROOT)
  229. def check_git_status():
  230. # Recommend 'git pull' if code is out of date
  231. msg = ', for updates see https://github.com/ultralytics/yolov5'
  232. s = colorstr('github: ') # string
  233. assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
  234. assert not is_docker(), s + 'skipping check (Docker image)' + msg
  235. assert check_online(), s + 'skipping check (offline)' + msg
  236. cmd = 'git fetch && git config --get remote.origin.url'
  237. url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
  238. branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  239. n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  240. if n > 0:
  241. s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
  242. else:
  243. s += f'up to date with {url} ✅'
  244. LOGGER.info(emojis(s)) # emoji-safe
  245. def check_python(minimum='3.7.0'):
  246. # Check current python version vs. required python version
  247. check_version(platform.python_version(), minimum, name='Python ', hard=True)
  248. def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
  249. # Check version vs. required version
  250. current, minimum = (pkg.parse_version(x) for x in (current, minimum))
  251. result = (current == minimum) if pinned else (current >= minimum) # bool
  252. s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed' # string
  253. if hard:
  254. assert result, s # assert min requirements met
  255. if verbose and not result:
  256. LOGGER.warning(s)
  257. return result
  258. @try_except
  259. def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
  260. # Check installed dependencies meet requirements (pass *.txt file or list of packages)
  261. prefix = colorstr('red', 'bold', 'requirements:')
  262. check_python() # check python version
  263. if isinstance(requirements, (str, Path)): # requirements.txt file
  264. file = Path(requirements)
  265. assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
  266. with file.open() as f:
  267. requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
  268. else: # list or tuple of packages
  269. requirements = [x for x in requirements if x not in exclude]
  270. n = 0 # number of packages updates
  271. for i, r in enumerate(requirements):
  272. try:
  273. pkg.require(r)
  274. except Exception: # DistributionNotFound or VersionConflict if requirements not met
  275. s = f"{prefix} {r} not found and is required by YOLOv5"
  276. if install and AUTOINSTALL: # check environment variable
  277. LOGGER.info(f"{s}, attempting auto-update...")
  278. try:
  279. assert check_online(), f"'pip install {r}' skipped (offline)"
  280. LOGGER.info(check_output(f"pip install '{r}' {cmds[i] if cmds else ''}", shell=True).decode())
  281. n += 1
  282. except Exception as e:
  283. LOGGER.warning(f'{prefix} {e}')
  284. else:
  285. LOGGER.info(f'{s}. Please install and rerun your command.')
  286. if n: # if packages updated
  287. source = file.resolve() if 'file' in locals() else requirements
  288. s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
  289. f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
  290. LOGGER.info(emojis(s))
  291. def check_img_size(imgsz, s=32, floor=0):
  292. # Verify image size is a multiple of stride s in each dimension
  293. if isinstance(imgsz, int): # integer i.e. img_size=640
  294. new_size = max(make_divisible(imgsz, int(s)), floor)
  295. else: # list i.e. img_size=[640, 480]
  296. imgsz = list(imgsz) # convert to list if tuple
  297. new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
  298. if new_size != imgsz:
  299. LOGGER.warning(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
  300. return new_size
  301. def check_imshow():
  302. # Check if environment supports image displays
  303. try:
  304. assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
  305. assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
  306. cv2.imshow('test', np.zeros((1, 1, 3)))
  307. cv2.waitKey(1)
  308. cv2.destroyAllWindows()
  309. cv2.waitKey(1)
  310. return True
  311. except Exception as e:
  312. LOGGER.warning(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
  313. return False
  314. def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
  315. # Check file(s) for acceptable suffix
  316. if file and suffix:
  317. if isinstance(suffix, str):
  318. suffix = [suffix]
  319. for f in file if isinstance(file, (list, tuple)) else [file]:
  320. s = Path(f).suffix.lower() # file suffix
  321. if len(s):
  322. assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
  323. def check_yaml(file, suffix=('.yaml', '.yml')):
  324. # Search/download YAML file (if necessary) and return path, checking suffix
  325. return check_file(file, suffix)
  326. def check_file(file, suffix=''):
  327. # Search/download file (if necessary) and return path
  328. check_suffix(file, suffix) # optional
  329. file = str(file) # convert to str()
  330. if Path(file).is_file() or not file: # exists
  331. return file
  332. elif file.startswith(('http:/', 'https:/')): # download
  333. url = file # warning: Pathlib turns :// -> :/
  334. file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
  335. if Path(file).is_file():
  336. LOGGER.info(f'Found {url} locally at {file}') # file already exists
  337. else:
  338. LOGGER.info(f'Downloading {url} to {file}...')
  339. torch.hub.download_url_to_file(url, file)
  340. assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
  341. return file
  342. else: # search
  343. files = []
  344. for d in 'data', 'models', 'utils': # search directories
  345. files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
  346. assert len(files), f'File not found: {file}' # assert file was found
  347. assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
  348. return files[0] # return file
  349. def check_font(font=FONT, progress=False):
  350. # Download font to CONFIG_DIR if necessary
  351. font = Path(font)
  352. file = CONFIG_DIR / font.name
  353. if not font.exists() and not file.exists():
  354. url = "https://ultralytics.com/assets/" + font.name
  355. LOGGER.info(f'Downloading {url} to {file}...')
  356. torch.hub.download_url_to_file(url, str(file), progress=progress)
  357. def check_dataset(data, autodownload=True):
  358. # Download and/or unzip dataset if not found locally
  359. # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip
  360. # Download (optional)
  361. extract_dir = ''
  362. if isinstance(data, (str, Path)) and str(data).endswith('.zip'): # i.e. gs://bucket/dir/coco128.zip
  363. download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False, threads=1)
  364. data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
  365. extract_dir, autodownload = data.parent, False
  366. # Read yaml (optional)
  367. if isinstance(data, (str, Path)):
  368. with open(data, errors='ignore') as f:
  369. data = yaml.safe_load(f) # dictionary
  370. # Resolve paths
  371. path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.'
  372. if not path.is_absolute():
  373. path = (ROOT / path).resolve()
  374. for k in 'train', 'val', 'test':
  375. if data.get(k): # prepend path
  376. data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]
  377. # Parse yaml
  378. assert 'nc' in data, "Dataset 'nc' key missing."
  379. if 'names' not in data:
  380. data['names'] = [f'class{i}' for i in range(data['nc'])] # assign class names if missing
  381. train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
  382. if val:
  383. val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
  384. if not all(x.exists() for x in val):
  385. LOGGER.info(emojis('\nDataset not found ⚠, missing paths %s' % [str(x) for x in val if not x.exists()]))
  386. if not s or not autodownload:
  387. raise Exception(emojis('Dataset not found ❌'))
  388. t = time.time()
  389. root = path.parent if 'path' in data else '..' # unzip directory i.e. '../'
  390. if s.startswith('http') and s.endswith('.zip'): # URL
  391. f = Path(s).name # filename
  392. LOGGER.info(f'Downloading {s} to {f}...')
  393. torch.hub.download_url_to_file(s, f)
  394. Path(root).mkdir(parents=True, exist_ok=True) # create root
  395. ZipFile(f).extractall(path=root) # unzip
  396. Path(f).unlink() # remove zip
  397. r = None # success
  398. elif s.startswith('bash '): # bash script
  399. LOGGER.info(f'Running {s} ...')
  400. r = os.system(s)
  401. else: # python script
  402. r = exec(s, {'yaml': data}) # return None
  403. dt = f'({round(time.time() - t, 1)}s)'
  404. s = f"success ✅ {dt}, saved to {colorstr('bold', root)}" if r in (0, None) else f"failure {dt} ❌"
  405. LOGGER.info(emojis(f"Dataset download {s}"))
  406. check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts
  407. return data # dictionary
  408. def check_amp(model):
  409. # Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
  410. from models.common import AutoShape
  411. if next(model.parameters()).device.type == 'cpu': # get model device
  412. return False
  413. prefix = colorstr('AMP: ')
  414. file = ROOT / 'data' / 'images' / 'bus.jpg' # image to test
  415. if file.exists():
  416. im = cv2.imread(file)[..., ::-1] # OpenCV image (BGR to RGB)
  417. elif check_online():
  418. im = 'https://ultralytics.com/images/bus.jpg'
  419. else:
  420. LOGGER.warning(emojis(f'{prefix}checks skipped ⚠️, not online.'))
  421. return True
  422. m = AutoShape(model, verbose=False) # model
  423. a = m(im).xywhn[0] # FP32 inference
  424. m.amp = True
  425. b = m(im).xywhn[0] # AMP inference
  426. if (a.shape == b.shape) and torch.allclose(a, b, atol=0.05): # close to 5% absolute tolerance
  427. LOGGER.info(emojis(f'{prefix}checks passed ✅'))
  428. return True
  429. else:
  430. help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
  431. LOGGER.warning(emojis(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}'))
  432. return False
  433. def url2file(url):
  434. # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
  435. url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
  436. return Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
  437. def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
  438. # Multi-threaded file download and unzip function, used in data.yaml for autodownload
  439. def download_one(url, dir):
  440. # Download 1 file
  441. success = True
  442. f = dir / Path(url).name # filename
  443. if Path(url).is_file(): # exists in current path
  444. Path(url).rename(f) # move to dir
  445. elif not f.exists():
  446. LOGGER.info(f'Downloading {url} to {f}...')
  447. for i in range(retry + 1):
  448. if curl:
  449. s = 'sS' if threads > 1 else '' # silent
  450. r = os.system(f"curl -{s}L '{url}' -o '{f}' --retry 9 -C -") # curl download
  451. success = r == 0
  452. else:
  453. torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
  454. success = f.is_file()
  455. if success:
  456. break
  457. elif i < retry:
  458. LOGGER.warning(f'Download failure, retrying {i + 1}/{retry} {url}...')
  459. else:
  460. LOGGER.warning(f'Failed to download {url}...')
  461. if unzip and success and f.suffix in ('.zip', '.gz'):
  462. LOGGER.info(f'Unzipping {f}...')
  463. if f.suffix == '.zip':
  464. ZipFile(f).extractall(path=dir) # unzip
  465. elif f.suffix == '.gz':
  466. os.system(f'tar xfz {f} --directory {f.parent}') # unzip
  467. if delete:
  468. f.unlink() # remove zip
  469. dir = Path(dir)
  470. dir.mkdir(parents=True, exist_ok=True) # make directory
  471. if threads > 1:
  472. pool = ThreadPool(threads)
  473. pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
  474. pool.close()
  475. pool.join()
  476. else:
  477. for u in [url] if isinstance(url, (str, Path)) else url:
  478. download_one(u, dir)
  479. def make_divisible(x, divisor):
  480. # Returns nearest x divisible by divisor
  481. if isinstance(divisor, torch.Tensor):
  482. divisor = int(divisor.max()) # to int
  483. return math.ceil(x / divisor) * divisor
  484. def clean_str(s):
  485. # Cleans a string by replacing special characters with underscore _
  486. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  487. def one_cycle(y1=0.0, y2=1.0, steps=100):
  488. # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
  489. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  490. def colorstr(*input):
  491. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  492. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  493. colors = {
  494. 'black': '\033[30m', # basic colors
  495. 'red': '\033[31m',
  496. 'green': '\033[32m',
  497. 'yellow': '\033[33m',
  498. 'blue': '\033[34m',
  499. 'magenta': '\033[35m',
  500. 'cyan': '\033[36m',
  501. 'white': '\033[37m',
  502. 'bright_black': '\033[90m', # bright colors
  503. 'bright_red': '\033[91m',
  504. 'bright_green': '\033[92m',
  505. 'bright_yellow': '\033[93m',
  506. 'bright_blue': '\033[94m',
  507. 'bright_magenta': '\033[95m',
  508. 'bright_cyan': '\033[96m',
  509. 'bright_white': '\033[97m',
  510. 'end': '\033[0m', # misc
  511. 'bold': '\033[1m',
  512. 'underline': '\033[4m'}
  513. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  514. def labels_to_class_weights(labels, nc=80):
  515. # Get class weights (inverse frequency) from training labels
  516. if labels[0] is None: # no labels loaded
  517. return torch.Tensor()
  518. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  519. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  520. weights = np.bincount(classes, minlength=nc) # occurrences per class
  521. # Prepend gridpoint count (for uCE training)
  522. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  523. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  524. weights[weights == 0] = 1 # replace empty bins with 1
  525. weights = 1 / weights # number of targets per class
  526. weights /= weights.sum() # normalize
  527. return torch.from_numpy(weights)
  528. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  529. # Produces image weights based on class_weights and image contents
  530. # Usage: index = random.choices(range(n), weights=image_weights, k=1) # weighted image sample
  531. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  532. return (class_weights.reshape(1, nc) * class_counts).sum(1)
  533. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  534. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  535. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  536. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  537. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  538. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  539. return [
  540. 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  541. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  542. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  543. def xyxy2xywh(x):
  544. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  545. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  546. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  547. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  548. y[:, 2] = x[:, 2] - x[:, 0] # width
  549. y[:, 3] = x[:, 3] - x[:, 1] # height
  550. return y
  551. def xywh2xyxy(x):
  552. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  553. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  554. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  555. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  556. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  557. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  558. return y
  559. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  560. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  561. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  562. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  563. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  564. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  565. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  566. return y
  567. def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
  568. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
  569. if clip:
  570. clip_coords(x, (h - eps, w - eps)) # warning: inplace clip
  571. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  572. y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
  573. y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
  574. y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
  575. y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
  576. return y
  577. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  578. # Convert normalized segments into pixel segments, shape (n,2)
  579. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  580. y[:, 0] = w * x[:, 0] + padw # top left x
  581. y[:, 1] = h * x[:, 1] + padh # top left y
  582. return y
  583. def segment2box(segment, width=640, height=640):
  584. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  585. x, y = segment.T # segment xy
  586. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  587. x, y, = x[inside], y[inside]
  588. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  589. def segments2boxes(segments):
  590. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  591. boxes = []
  592. for s in segments:
  593. x, y = s.T # segment xy
  594. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  595. return xyxy2xywh(np.array(boxes)) # cls, xywh
  596. def resample_segments(segments, n=1000):
  597. # Up-sample an (n,2) segment
  598. for i, s in enumerate(segments):
  599. x = np.linspace(0, len(s) - 1, n)
  600. xp = np.arange(len(s))
  601. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  602. return segments
  603. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  604. # Rescale coords (xyxy) from img1_shape to img0_shape
  605. if ratio_pad is None: # calculate from img0_shape
  606. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  607. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  608. else:
  609. gain = ratio_pad[0][0]
  610. pad = ratio_pad[1]
  611. coords[:, [0, 2]] -= pad[0] # x padding
  612. coords[:, [1, 3]] -= pad[1] # y padding
  613. coords[:, :4] /= gain
  614. clip_coords(coords, img0_shape)
  615. return coords
  616. def clip_coords(boxes, shape):
  617. # Clip bounding xyxy bounding boxes to image shape (height, width)
  618. if isinstance(boxes, torch.Tensor): # faster individually
  619. boxes[:, 0].clamp_(0, shape[1]) # x1
  620. boxes[:, 1].clamp_(0, shape[0]) # y1
  621. boxes[:, 2].clamp_(0, shape[1]) # x2
  622. boxes[:, 3].clamp_(0, shape[0]) # y2
  623. else: # np.array (faster grouped)
  624. boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
  625. boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
def non_max_suppression(prediction,
                        conf_thres=0.25,
                        iou_thres=0.45,
                        classes=None,
                        agnostic=False,
                        multi_label=False,
                        labels=(),
                        max_det=300):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping bounding boxes.

    Args:
        prediction: raw model output indexed as prediction[image, candidate, column]. The columns read
            here are [0:4] box as (center x, center y, width, height), [4] objectness, and [5:] per-class
            scores, so the number of classes is prediction.shape[2] - 5.
        conf_thres: minimum confidence. It is applied first to objectness alone, then to
            obj_conf * cls_conf.
        iou_thres: IoU threshold for torchvision.ops.nms (also used for the merge-NMS IoU matrix).
        classes: optional sequence of class indices; detections of any other class are dropped.
        agnostic: if True, run NMS over all classes together instead of per class.
        multi_label: if True (and there is more than one class), a box may yield one detection per class
            scoring above conf_thres; otherwise only its best class is kept.
        labels: optional per-image apriori labels for autolabelling. Rows of labels[xi] are indexed as
            [cls, box(4)] and appended to that image's candidates with confidence 1.0.
        max_det: maximum number of detections kept per image.

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]. If the time limit is exceeded,
         images not yet processed keep an empty (0, 6) tensor.
    """

    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates (objectness-only pre-filter)

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height; also the per-class coordinate offset used below
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 0.1 + 0.05 * bs  # seconds to quit after
    redundant = True  # require redundant detections (only consulted in the merge-NMS branch)
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    # One empty (0, 6) result per image. The list repeats one tensor object, which is safe because entries
    # are only ever replaced (output[xi] = ...), never modified in place.
    output = [torch.zeros((0, 6), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence (boolean-mask indexing yields a new tensor, so the in-place ops below do not modify `prediction`)

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + 5), device=x.device)
            # NOTE(review): the label box presumably uses the same (cx, cy, w, h) form as prediction,
            # since it goes through xywh2xyxy below together with the predictions; confirm with caller
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls (one-hot)
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            # one row per (box, class) pair whose combined score exceeds conf_thres
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence, keep the top max_nms (n is not updated)

        # Batched NMS
        # Each box is shifted by cls * max_wh so that, as long as coordinates stay below max_wh, boxes of
        # different classes cannot overlap and a single nms() call acts per class. agnostic=True uses offset 0.
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS (kept indices)
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy (kept box must overlap at least one other candidate)

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING: NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded; remaining images keep their empty (0, 6) result

    return output
  710. def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
  711. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  712. x = torch.load(f, map_location=torch.device('cpu'))
  713. if x.get('ema'):
  714. x['model'] = x['ema'] # replace model with ema
  715. for k in 'optimizer', 'best_fitness', 'wandb_id', 'ema', 'updates': # keys
  716. x[k] = None
  717. x['epoch'] = -1
  718. x['model'].half() # to FP16
  719. for p in x['model'].parameters():
  720. p.requires_grad = False
  721. torch.save(x, s or f)
  722. mb = os.path.getsize(s or f) / 1E6 # filesize
  723. LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
  724. def print_mutation(results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
  725. evolve_csv = save_dir / 'evolve.csv'
  726. evolve_yaml = save_dir / 'hyp_evolve.yaml'
  727. keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', 'val/box_loss',
  728. 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys()) # [results + hyps]
  729. keys = tuple(x.strip() for x in keys)
  730. vals = results + tuple(hyp.values())
  731. n = len(keys)
  732. # Download (optional)
  733. if bucket:
  734. url = f'gs://{bucket}/evolve.csv'
  735. if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
  736. os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local
  737. # Log to evolve.csv
  738. s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header
  739. with open(evolve_csv, 'a') as f:
  740. f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')
  741. # Save yaml
  742. with open(evolve_yaml, 'w') as f:
  743. data = pd.read_csv(evolve_csv)
  744. data = data.rename(columns=lambda x: x.strip()) # strip keys
  745. i = np.argmax(fitness(data.values[:, :4])) #
  746. generations = len(data)
  747. f.write('# YOLOv5 Hyperparameter Evolution Results\n' + f'# Best generation: {i}\n' +
  748. f'# Last generation: {generations - 1}\n' + '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) +
  749. '\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
  750. yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)
  751. # Print to screen
  752. LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix +
  753. ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' + prefix + ', '.join(f'{x:20.5g}'
  754. for x in vals) + '\n\n')
  755. if bucket:
  756. os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload
  757. def apply_classifier(x, model, img, im0):
  758. # Apply a second stage classifier to YOLO outputs
  759. # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
  760. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  761. for i, d in enumerate(x): # per image
  762. if d is not None and len(d):
  763. d = d.clone()
  764. # Reshape and pad cutouts
  765. b = xyxy2xywh(d[:, :4]) # boxes
  766. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  767. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  768. d[:, :4] = xywh2xyxy(b).long()
  769. # Rescale boxes from img_size to im0 size
  770. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  771. # Classes
  772. pred_cls1 = d[:, 5].long()
  773. ims = []
  774. for a in d:
  775. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  776. im = cv2.resize(cutout, (224, 224)) # BGR
  777. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  778. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  779. im /= 255 # 0 - 255 to 0.0 - 1.0
  780. ims.append(im)
  781. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  782. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  783. return x
  784. def increment_path(path, exist_ok=False, sep='', mkdir=False):
  785. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  786. path = Path(path) # os-agnostic
  787. if path.exists() and not exist_ok:
  788. path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
  789. # Method 1
  790. for n in range(2, 9999):
  791. p = f'{path}{sep}{n}{suffix}' # increment path
  792. if not os.path.exists(p): #
  793. break
  794. path = Path(p)
  795. # Method 2 (deprecated)
  796. # dirs = glob.glob(f"{path}{sep}*") # similar paths
  797. # matches = [re.search(rf"{path.stem}{sep}(\d+)", d) for d in dirs]
  798. # i = [int(m.groups()[0]) for m in matches if m] # indices
  799. # n = max(i) + 1 if i else 2 # increment number
  800. # path = Path(f"{path}{sep}{n}{suffix}") # increment path
  801. if mkdir:
  802. path.mkdir(parents=True, exist_ok=True) # make directory
  803. return path
  804. # OpenCV Chinese-friendly functions ------------------------------------------------------------------------------------
  805. imshow_ = cv2.imshow # copy to avoid recursion errors
  806. def imread(path, flags=cv2.IMREAD_COLOR):
  807. return cv2.imdecode(np.fromfile(path, np.uint8), flags)
  808. def imwrite(path, im):
  809. try:
  810. cv2.imencode(Path(path).suffix, im)[1].tofile(path)
  811. return True
  812. except Exception:
  813. return False
  814. def imshow(path, im):
  815. imshow_(path.encode('unicode_escape').decode(), im)
  816. cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow # redefine
  817. # Variables ------------------------------------------------------------------------------------------------------------
  818. NCOLS = 0 if is_docker() else shutil.get_terminal_size().columns # terminal window size for tqdm