# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
General utils
"""

import contextlib
import glob
import logging
import math
import os
import platform
import random
import re
import shutil
import signal
import time
import urllib
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from subprocess import check_output
from zipfile import ZipFile

import cv2
import numpy as np
import pandas as pd
import pkg_resources as pkg
import torch
import torchvision
import yaml

from utils.downloads import gsutil_getsize
from utils.metrics import box_iou, fitness

# Settings
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLOv5 multiprocessing threads

torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS)  # NumExpr max threads


def set_logging(name=None, verbose=True):
    # Sets level and returns logger
    for h in logging.root.handlers:
        logging.root.removeHandler(h)  # remove all handlers associated with the root logger object
    rank = int(os.getenv('RANK', -1))  # rank in world for Multi-GPU trainings
    logging.basicConfig(format="%(message)s", level=logging.INFO if (verbose and rank in (-1, 0)) else logging.WARNING)
    return logging.getLogger(name)


LOGGER = set_logging(__name__)  # define globally (used in train.py, val.py, detect.py, etc.)


class Profile(contextlib.ContextDecorator):
    # Usage: @Profile() decorator or 'with Profile():' context manager
    def __enter__(self):
        self.start = time.time()

    def __exit__(self, type, value, traceback):
        print(f'Profile results: {time.time() - self.start:.5f}s')


class Timeout(contextlib.ContextDecorator):
    # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
    def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
        self.seconds = int(seconds)
        self.timeout_message = timeout_msg
        self.suppress = bool(suppress_timeout_errors)

    def _timeout_handler(self, signum, frame):
        raise TimeoutError(self.timeout_message)

    def __enter__(self):
        signal.signal(signal.SIGALRM, self._timeout_handler)  # Set handler for SIGALRM
        signal.alarm(self.seconds)  # start countdown for SIGALRM to be raised

    def __exit__(self, exc_type, exc_val, exc_tb):
        signal.alarm(0)  # Cancel SIGALRM if it's scheduled
        if self.suppress and exc_type is TimeoutError:  # Suppress TimeoutError
            return True
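
# Illustrative usage sketch (not part of the module): Profile and Timeout wrap arbitrary code
# as context managers; values below are made up for the example.
#
#     with Profile():
#         time.sleep(0.1)                        # prints elapsed time on exit
#     with Timeout(5, timeout_msg='slow call'):  # raises (and suppresses) TimeoutError after 5s
#         check_online()
#
# Note: Timeout relies on signal.SIGALRM and therefore only works on Unix-like systems.
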
class WorkingDirectory(contextlib.ContextDecorator):
    # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
    def __init__(self, new_dir):
        self.dir = new_dir  # new dir
        self.cwd = Path.cwd().resolve()  # current dir

    def __enter__(self):
        os.chdir(self.dir)

    def __exit__(self, exc_type, exc_val, exc_tb):
        os.chdir(self.cwd)


def try_except(func):
    # try-except function. Usage: @try_except decorator
    def handler(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception as e:
            print(e)

    return handler


def methods(instance):
    # Get class/instance methods
    return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]


def print_args(name, opt):
    # Print argparser arguments
    LOGGER.info(colorstr(f'{name}: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))


def init_seeds(seed=0):
    # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
    # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
    import torch.backends.cudnn as cudnn
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)


def intersect_dicts(da, db, exclude=()):
    # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
    return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}


def get_latest_run(search_dir='.'):
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''


def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
    # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
    env = os.getenv(env_var)
    if env:
        path = Path(env)  # use environment variable
    else:
        cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'}  # 3 OS dirs
        path = Path.home() / cfg.get(platform.system(), '')  # OS-specific config dir
    path = (path if is_writeable(path) else Path('/tmp')) / dir  # GCP and AWS lambda fix, only /tmp is writeable
    path.mkdir(exist_ok=True)  # make if required
    return path


def is_writeable(dir, test=False):
    # Return True if directory has write permissions, test opening a file with write permissions if test=True
    if test:  # method 1
        file = Path(dir) / 'tmp.txt'
        try:
            with open(file, 'w'):  # open file with write permissions
                pass
            file.unlink()  # remove file
            return True
        except OSError:
            return False
    else:  # method 2
        return os.access(dir, os.R_OK)  # possible issues on Windows


def is_docker():
    # Is environment a Docker container?
    return Path('/workspace').exists()  # or Path('/.dockerenv').exists()


def is_colab():
    # Is environment a Google Colab instance?
    try:
        import google.colab
        return True
    except ImportError:
        return False


def is_pip():
    # Is file in a pip package?
    return 'site-packages' in Path(__file__).resolve().parts


def is_ascii(s=''):
    # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
    s = str(s)  # convert list, tuple, None, etc. to str
    return len(s.encode().decode('ascii', 'ignore')) == len(s)


def is_chinese(s='人工智能'):
    # Is string composed of any Chinese characters?
    return re.search('[\u4e00-\u9fff]', s)


def emojis(str=''):
    # Return platform-dependent emoji-safe version of string
    return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str


def file_size(path):
    # Return file/dir size (MB)
    path = Path(path)
    if path.is_file():
        return path.stat().st_size / 1E6
    elif path.is_dir():
        return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / 1E6
    else:
        return 0.0


def check_online():
    # Check internet connectivity
    import socket
    try:
        socket.create_connection(("1.1.1.1", 443), 5)  # check host accessibility
        return True
    except OSError:
        return False


@try_except
@WorkingDirectory(ROOT)
def check_git_status():
    # Recommend 'git pull' if code is out of date
    msg = ', for updates see https://github.com/ultralytics/yolov5'
    print(colorstr('github: '), end='')
    assert Path('.git').exists(), 'skipping check (not a git repository)' + msg
    assert not is_docker(), 'skipping check (Docker image)' + msg
    assert check_online(), 'skipping check (offline)' + msg

    cmd = 'git fetch && git config --get remote.origin.url'
    url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git')  # git fetch
    branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
    n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True))  # commits behind
    if n > 0:
        s = f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
    else:
        s = f'up to date with {url} ✅'
    print(emojis(s))  # emoji-safe


def check_python(minimum='3.6.2'):
    # Check current python version vs. required python version
    check_version(platform.python_version(), minimum, name='Python ', hard=True)


def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
    # Check version vs. required version
    current, minimum = (pkg.parse_version(x) for x in (current, minimum))
    result = (current == minimum) if pinned else (current >= minimum)  # bool
    s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'  # string
    if hard:
        assert result, s  # assert min requirements met
    if verbose and not result:
        LOGGER.warning(s)
    return result
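
# Illustrative usage sketch (hypothetical version strings, not part of the module):
#
#     check_version(torch.__version__, minimum='1.7.0', name='torch ')  # returns True/False silently
#     check_version(torch.__version__, minimum='1.7.0', hard=True)      # AssertionError if requirement unmet
#     check_version('1.9.0', '1.9.0', pinned=True)                      # exact-match check
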
@try_except
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True):
    # Check installed dependencies meet requirements (pass *.txt file or list of packages)
    prefix = colorstr('red', 'bold', 'requirements:')
    check_python()  # check python version
    if isinstance(requirements, (str, Path)):  # requirements.txt file
        file = Path(requirements)
        assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
        with file.open() as f:
            requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
    else:  # list or tuple of packages
        requirements = [x for x in requirements if x not in exclude]

    n = 0  # number of packages updates
    for r in requirements:
        try:
            pkg.require(r)
        except Exception:  # DistributionNotFound or VersionConflict if requirements not met
            s = f"{prefix} {r} not found and is required by YOLOv5"
            if install:
                print(f"{s}, attempting auto-update...")
                try:
                    assert check_online(), f"'pip install {r}' skipped (offline)"
                    print(check_output(f"pip install '{r}'", shell=True).decode())
                    n += 1
                except Exception as e:
                    print(f'{prefix} {e}')
            else:
                print(f'{s}. Please install and rerun your command.')

    if n:  # if packages updated
        source = file.resolve() if 'file' in locals() else requirements
        s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
        print(emojis(s))


def check_img_size(imgsz, s=32, floor=0):
    # Verify image size is a multiple of stride s in each dimension
    if isinstance(imgsz, int):  # integer i.e. img_size=640
        new_size = max(make_divisible(imgsz, int(s)), floor)
    else:  # list i.e. img_size=[640, 480]
        new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
    if new_size != imgsz:
        print(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
    return new_size
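
# Illustrative usage sketch (example values only): make_divisible() rounds up to the nearest
# multiple of the divisor, and check_img_size() applies it per dimension.
#
#     make_divisible(97, 32)          # -> 128
#     check_img_size(640, s=32)       # -> 640 (already a multiple of 32)
#     check_img_size([640, 500], 32)  # -> [640, 512], prints a WARNING
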
def check_imshow():
    # Check if environment supports image displays
    try:
        assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
        assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
        return False


def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
    # Check file(s) for acceptable suffix
    if file and suffix:
        if isinstance(suffix, str):
            suffix = [suffix]
        for f in file if isinstance(file, (list, tuple)) else [file]:
            s = Path(f).suffix.lower()  # file suffix
            if len(s):
                assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"


def check_yaml(file, suffix=('.yaml', '.yml')):
    # Search/download YAML file (if necessary) and return path, checking suffix
    return check_file(file, suffix)


def check_file(file, suffix=''):
    # Search/download file (if necessary) and return path
    check_suffix(file, suffix)  # optional
    file = str(file)  # convert to str()
    if Path(file).is_file() or file == '':  # exists
        return file
    elif file.startswith(('http:/', 'https:/')):  # download
        url = str(Path(file)).replace(':/', '://')  # Pathlib turns :// -> :/
        file = Path(urllib.parse.unquote(file).split('?')[0]).name  # '%2F' to '/', split https://url.com/file.txt?auth
        if Path(file).is_file():
            print(f'Found {url} locally at {file}')  # file already exists
        else:
            print(f'Downloading {url} to {file}...')
            torch.hub.download_url_to_file(url, file)
            assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}'  # check
        return file
    else:  # search
        files = []
        for d in 'data', 'models', 'utils':  # search directories
            files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True))  # find file
        assert len(files), f'File not found: {file}'  # assert file was found
        assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique
        return files[0]  # return file


def check_dataset(data, autodownload=True):
    # Download and/or unzip dataset if not found locally
    # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip

    # Download (optional)
    extract_dir = ''
    if isinstance(data, (str, Path)) and str(data).endswith('.zip'):  # i.e. gs://bucket/dir/coco128.zip
        download(data, dir='../datasets', unzip=True, delete=False, curl=False, threads=1)
        data = next((Path('../datasets') / Path(data).stem).rglob('*.yaml'))
        extract_dir, autodownload = data.parent, False

    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        with open(data, errors='ignore') as f:
            data = yaml.safe_load(f)  # dictionary

    # Parse yaml
    path = extract_dir or Path(data.get('path') or '')  # optional 'path' default to '.'
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]

    assert 'nc' in data, "Dataset 'nc' key missing."
    if 'names' not in data:
        data['names'] = [f'class{i}' for i in range(data['nc'])]  # assign class names if missing
    train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and autodownload:  # download script
                root = path.parent if 'path' in data else '..'  # unzip directory i.e. '../'
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    print(f'Downloading {s} to {f}...')
                    torch.hub.download_url_to_file(s, f)
                    Path(root).mkdir(parents=True, exist_ok=True)  # create root
                    ZipFile(f).extractall(path=root)  # unzip
                    Path(f).unlink()  # remove zip
                    r = None  # success
                elif s.startswith('bash '):  # bash script
                    print(f'Running {s} ...')
                    r = os.system(s)
                else:  # python script
                    r = exec(s, {'yaml': data})  # return None
                print(f"Dataset autodownload {f'success, saved to {root}' if r in (0, None) else 'failure'}\n")
            else:
                raise Exception('Dataset not found.')

    return data  # dictionary


def url2file(url):
    # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
    url = str(Path(url)).replace(':/', '://')  # Pathlib turns :// -> :/
    file = Path(urllib.parse.unquote(url)).name.split('?')[0]  # '%2F' to '/', split https://url.com/file.txt?auth
    return file


def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
    # Multi-threaded file download and unzip function, used in data.yaml for autodownload
    def download_one(url, dir):
        # Download 1 file
        f = dir / Path(url).name  # filename
        if Path(url).is_file():  # exists in current path
            Path(url).rename(f)  # move to dir
        elif not f.exists():
            print(f'Downloading {url} to {f}...')
            if curl:
                os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -")  # curl download, retry and resume on fail
            else:
                torch.hub.download_url_to_file(url, f, progress=True)  # torch download
        if unzip and f.suffix in ('.zip', '.gz'):
            print(f'Unzipping {f}...')
            if f.suffix == '.zip':
                ZipFile(f).extractall(path=dir)  # unzip
            elif f.suffix == '.gz':
                os.system(f'tar xfz {f} --directory {f.parent}')  # unzip
            if delete:
                f.unlink()  # remove zip

    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    if threads > 1:
        pool = ThreadPool(threads)
        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multi-threaded
        pool.close()
        pool.join()
    else:
        for u in [url] if isinstance(url, (str, Path)) else url:
            download_one(u, dir)


def make_divisible(x, divisor):
    # Returns nearest x divisible by divisor
    if isinstance(divisor, torch.Tensor):
        divisor = int(divisor.max())  # to int
    return math.ceil(x / divisor) * divisor


def clean_str(s):
    # Cleans a string by replacing special characters with underscore _
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)


def one_cycle(y1=0.0, y2=1.0, steps=100):
    # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
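
# Illustrative usage sketch (example values): the returned lambda maps x in [0, steps] onto a
# cosine ramp from y1 to y2, e.g. as an epoch-indexed factor for a PyTorch LambdaLR schedule.
#
#     lf = one_cycle(1.0, 0.1, steps=100)  # cosine ramp 1.0 -> 0.1 over 100 epochs
#     lf(0), lf(50), lf(100)               # -> 1.0, 0.55, 0.1
#     # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
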
def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {'black': '\033[30m',  # basic colors
              'red': '\033[31m',
              'green': '\033[32m',
              'yellow': '\033[33m',
              'blue': '\033[34m',
              'magenta': '\033[35m',
              'cyan': '\033[36m',
              'white': '\033[37m',
              'bright_black': '\033[90m',  # bright colors
              'bright_red': '\033[91m',
              'bright_green': '\033[92m',
              'bright_yellow': '\033[93m',
              'bright_blue': '\033[94m',
              'bright_magenta': '\033[95m',
              'bright_cyan': '\033[96m',
              'bright_white': '\033[97m',
              'end': '\033[0m',  # misc
              'bold': '\033[1m',
              'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
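
# Illustrative usage sketch (example strings): with a single argument the style defaults to
# bold blue; otherwise the leading arguments select color/style and the last one is the text.
#
#     colorstr('hello world')                   # bold blue 'hello world'
#     colorstr('red', 'bold', 'requirements:')  # red bold prefix, as used by check_requirements()
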
def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(int)  # labels = [class xywh]
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # number of targets per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)


def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class_weights and image contents
    class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
    return image_weights


def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
         64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
    return x


def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y
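
# Illustrative round-trip sketch (made-up box): xyxy2xywh() and xywh2xyxy() are inverses for
# well-formed boxes.
#
#     b = np.array([[10., 20., 50., 80.]])  # x1, y1, x2, y2
#     xyxy2xywh(b)                          # -> [[30., 50., 40., 60.]] (xc, yc, w, h)
#     xywh2xyxy(xyxy2xywh(b))               # -> [[10., 20., 50., 80.]]
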
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # top left x
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh  # top left y
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw  # bottom right x
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh  # bottom right y
    return y


def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
    if clip:
        clip_coords(x, (h - eps, w - eps))  # warning: inplace clip
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w  # x center
    y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h  # y center
    y[:, 2] = (x[:, 2] - x[:, 0]) / w  # width
    y[:, 3] = (x[:, 3] - x[:, 1]) / h  # height
    return y


def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    # Convert normalized segments into pixel segments, shape (n,2)
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * x[:, 0] + padw  # top left x
    y[:, 1] = h * x[:, 1] + padh  # top left y
    return y


def segment2box(segment, width=640, height=640):
    # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4))  # xyxy


def segments2boxes(segments):
    # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
    return xyxy2xywh(np.array(boxes))  # cls, xywh


def resample_segments(segments, n=1000):
    # Up-sample an (n,2) segment
    for i, s in enumerate(segments):
        x = np.linspace(0, len(s) - 1, n)
        xp = np.arange(len(s))
        segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T  # segment xy
    return segments


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords


def clip_coords(boxes, shape):
    # Clip bounding xyxy bounding boxes to image shape (height, width)
    if isinstance(boxes, torch.Tensor):  # faster individually
        boxes[:, 0].clamp_(0, shape[1])  # x1
        boxes[:, 1].clamp_(0, shape[0])  # y1
        boxes[:, 2].clamp_(0, shape[1])  # x2
        boxes[:, 3].clamp_(0, shape[0])  # y2
    else:  # np.array (faster grouped)
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1])  # x1, x2
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0])  # y1, y2


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 7680  # (pixels) minimum and maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
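
# Illustrative usage sketch (hypothetical model output): 'pred' stands for the raw
# (batch, n, 5 + nc) tensor a YOLOv5 model produces at inference time.
#
#     # pred = model(im)[0]
#     # det = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, max_det=300)[0]
#     # det is an (n, 6) tensor per image: x1, y1, x2, y2, confidence, class
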
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'best_fitness', 'wandb_id', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")


def print_mutation(results, hyp, save_dir, bucket):
    evolve_csv, results_csv, evolve_yaml = save_dir / 'evolve.csv', save_dir / 'results.csv', save_dir / 'hyp_evolve.yaml'
    keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
            'val/box_loss', 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys())  # [results + hyps]
    keys = tuple(x.strip() for x in keys)
    vals = results + tuple(hyp.values())
    n = len(keys)

    # Download (optional)
    if bucket:
        url = f'gs://{bucket}/evolve.csv'
        if gsutil_getsize(url) > (os.path.getsize(evolve_csv) if os.path.exists(evolve_csv) else 0):
            os.system(f'gsutil cp {url} {save_dir}')  # download evolve.csv if larger than local

    # Log to evolve.csv
    s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n')  # add header
    with open(evolve_csv, 'a') as f:
        f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')

    # Print to screen
    print(colorstr('evolve: ') + ', '.join(f'{x.strip():>20s}' for x in keys))
    print(colorstr('evolve: ') + ', '.join(f'{x:20.5g}' for x in vals), end='\n\n\n')

    # Save yaml
    with open(evolve_yaml, 'w') as f:
        data = pd.read_csv(evolve_csv)
        data = data.rename(columns=lambda x: x.strip())  # strip keys
        i = np.argmax(fitness(data.values[:, :7]))
        f.write('# YOLOv5 Hyperparameter Evolution Results\n' +
                f'# Best generation: {i}\n' +
                f'# Last generation: {len(data) - 1}\n' +
                '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) + '\n' +
                '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
        yaml.safe_dump(hyp, f, sort_keys=False)

    if bucket:
        os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}')  # upload


def apply_classifier(x, model, img, im0):
    # Apply a second stage classifier to YOLO outputs
    # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                # cv2.imwrite('example%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x


def increment_path(path, exist_ok=False, sep='', mkdir=False):
    # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
        matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]  # indices
        n = max(i) + 1 if i else 2  # increment number
        path = Path(f"{path}{sep}{n}{suffix}")  # increment path
    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # make directory
    return path
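
# Illustrative usage sketch (example paths): successive calls yield runs/exp, runs/exp2,
# runs/exp3, ... when the earlier directories already exist.
#
#     save_dir = increment_path(Path('runs/train') / 'exp', exist_ok=False, mkdir=True)
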
# Variables
NCOLS = 0 if is_docker() else shutil.get_terminal_size().columns  # terminal window size for tqdm