# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
General utils
"""

import contextlib
import glob
import logging
import math
import os
import platform
import random
import re
import signal
import time
import urllib
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from subprocess import check_output
from zipfile import ZipFile

import cv2
import numpy as np
import pandas as pd
import pkg_resources as pkg
import torch
import torchvision
import yaml

from utils.downloads import gsutil_getsize
from utils.metrics import box_iou, fitness

# Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory


class Profile(contextlib.ContextDecorator):
    # Usage: @Profile() decorator or 'with Profile():' context manager
    def __enter__(self):
        self.start = time.time()

    def __exit__(self, type, value, traceback):
        print(f'Profile results: {time.time() - self.start:.5f}s')


class Timeout(contextlib.ContextDecorator):
    # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
    def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
        self.seconds = int(seconds)
        self.timeout_message = timeout_msg
        self.suppress = bool(suppress_timeout_errors)

    def _timeout_handler(self, signum, frame):
        raise TimeoutError(self.timeout_message)

    def __enter__(self):
        signal.signal(signal.SIGALRM, self._timeout_handler)  # Set handler for SIGALRM
        signal.alarm(self.seconds)  # start countdown for SIGALRM to be raised

    def __exit__(self, exc_type, exc_val, exc_tb):
        signal.alarm(0)  # Cancel SIGALRM if it's scheduled
        if self.suppress and exc_type is TimeoutError:  # Suppress TimeoutError
            return True
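
# Example usage (illustrative sketch, not executed on import; SIGALRM requires a Unix system):
#   with Profile():  # prints elapsed seconds on exit
#       time.sleep(0.1)
#   with Timeout(3, timeout_msg='check timed out'):  # TimeoutError raised (and suppressed) after 3s
#       check_online()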


def try_except(func):
    # try-except function. Usage: @try_except decorator
    def handler(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception as e:
            print(e)

    return handler
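
# Example usage (illustrative): exceptions inside the wrapped function are printed, not raised
#   @try_except
#   def risky():
#       raise ValueError('printed, not raised')
#   risky()  # note: the wrapper discards the return value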


def methods(instance):
    # Get class/instance methods
    return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]


def set_logging(rank=-1, verbose=True):
    logging.basicConfig(
        format="%(message)s",
        level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)


def print_args(name, opt):
    # Print argparser arguments
    print(colorstr(f'{name}: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))


def init_seeds(seed=0):
    # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
    # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
    import torch.backends.cudnn as cudnn
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)


def get_latest_run(search_dir='.'):
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''


def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
    # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
    env = os.getenv(env_var)
    if env:
        path = Path(env)  # use environment variable
    else:
        cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'}  # 3 OS dirs
        path = Path.home() / cfg.get(platform.system(), '')  # OS-specific config dir
    path = (path if is_writeable(path) else Path('/tmp')) / dir  # GCP and AWS lambda fix, only /tmp is writeable
    path.mkdir(exist_ok=True)  # make if required
    return path


def is_writeable(dir, test=False):
    # Return True if directory has write permissions, test opening a file with write permissions if test=True
    if test:  # method 1
        file = Path(dir) / 'tmp.txt'
        try:
            with open(file, 'w'):  # open file with write permissions
                pass
            file.unlink()  # remove file
            return True
        except IOError:
            return False
    else:  # method 2
        return os.access(dir, os.W_OK)  # check write access (os.R_OK only tests readability); possible issues on Windows


def is_docker():
    # Is environment a Docker container?
    return Path('/workspace').exists()  # or Path('/.dockerenv').exists()


def is_colab():
    # Is environment a Google Colab instance?
    try:
        import google.colab
        return True
    except Exception:
        return False


def is_pip():
    # Is file in a pip package?
    return 'site-packages' in Path(__file__).resolve().parts


def is_ascii(s=''):
    # Is string composed of all ASCII (no UTF) characters?
    s = str(s)  # convert list, tuple, None, etc. to str
    return len(s.encode().decode('ascii', 'ignore')) == len(s)


def emojis(str=''):
    # Return platform-dependent emoji-safe version of string
    return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str


def file_size(path):
    # Return file/dir size (MB)
    path = Path(path)
    if path.is_file():
        return path.stat().st_size / 1E6
    elif path.is_dir():
        return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / 1E6
    else:
        return 0.0


def check_online():
    # Check internet connectivity
    import socket
    try:
        socket.create_connection(("1.1.1.1", 443), 5)  # check host accessibility
        return True
    except OSError:
        return False


@try_except
def check_git_status():
    # Recommend 'git pull' if code is out of date
    msg = ', for updates see https://github.com/ultralytics/yolov5'
    print(colorstr('github: '), end='')
    assert Path('.git').exists(), 'skipping check (not a git repository)' + msg
    assert not is_docker(), 'skipping check (Docker image)' + msg
    assert check_online(), 'skipping check (offline)' + msg

    cmd = 'git fetch && git config --get remote.origin.url'
    url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git')  # git fetch
    branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
    n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True))  # commits behind
    if n > 0:
        s = f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
    else:
        s = f'up to date with {url} ✅'
    print(emojis(s))  # emoji-safe


def check_python(minimum='3.6.2'):
    # Check current python version vs. required python version
    check_version(platform.python_version(), minimum, name='Python ')


def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False):
    # Check version vs. required version
    current, minimum = (pkg.parse_version(x) for x in (current, minimum))
    result = (current == minimum) if pinned else (current >= minimum)
    assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
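
# Example usage (illustrative): enforce a minimum torch version before using a newer API
#   check_version(torch.__version__, minimum='1.7.0', name='torch ')  # raises AssertionError if too old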


@try_except
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True):
    # Check installed dependencies meet requirements (pass *.txt file or list of packages)
    prefix = colorstr('red', 'bold', 'requirements:')
    check_python()  # check python version
    if isinstance(requirements, (str, Path)):  # requirements.txt file
        file = Path(requirements)
        assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
        requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
    else:  # list or tuple of packages
        requirements = [x for x in requirements if x not in exclude]

    n = 0  # number of packages updated
    for r in requirements:
        try:
            pkg.require(r)
        except Exception:  # DistributionNotFound or VersionConflict if requirements not met
            s = f"{prefix} {r} not found and is required by YOLOv5"
            if install:
                print(f"{s}, attempting auto-update...")
                try:
                    assert check_online(), f"'pip install {r}' skipped (offline)"
                    print(check_output(f"pip install '{r}'", shell=True).decode())
                    n += 1
                except Exception as e:
                    print(f'{prefix} {e}')
            else:
                print(f'{s}. Please install and rerun your command.')

    if n:  # if packages updated
        source = file.resolve() if 'file' in locals() else requirements
        s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
        print(emojis(s))


def check_img_size(imgsz, s=32, floor=0):
    # Verify image size is a multiple of stride s in each dimension
    if isinstance(imgsz, int):  # integer i.e. img_size=640
        new_size = max(make_divisible(imgsz, int(s)), floor)
    else:  # list i.e. img_size=[640, 480]
        new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
    if new_size != imgsz:
        print(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
    return new_size
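
# Example usage (illustrative): sizes are rounded up to the nearest multiple of the stride
#   check_img_size(640, s=32)         # -> 640 (already divisible)
#   check_img_size([641, 480], s=32)  # -> [672, 480], with a WARNING printed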


def check_imshow():
    # Check if environment supports image displays
    try:
        assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
        assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
        return False


def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
    # Check file(s) for acceptable suffixes
    if file and suffix:
        if isinstance(suffix, str):
            suffix = [suffix]
        for f in file if isinstance(file, (list, tuple)) else [file]:
            assert Path(f).suffix.lower() in suffix, f"{msg}{f} acceptable suffix is {suffix}"


def check_yaml(file, suffix=('.yaml', '.yml')):
    # Search/download YAML file (if necessary) and return path, checking suffix
    return check_file(file, suffix)


def check_file(file, suffix=''):
    # Search/download file (if necessary) and return path
    check_suffix(file, suffix)  # optional
    file = str(file)  # convert to str()
    if Path(file).is_file() or file == '':  # exists
        return file
    elif file.startswith(('http:/', 'https:/')):  # download
        url = str(Path(file)).replace(':/', '://')  # Pathlib turns :// -> :/
        file = Path(urllib.parse.unquote(file)).name.split('?')[0]  # '%2F' to '/', split https://url.com/file.txt?auth
        print(f'Downloading {url} to {file}...')
        torch.hub.download_url_to_file(url, file)
        assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}'  # check
        return file
    else:  # search
        files = glob.glob('./**/' + file, recursive=True)  # find file
        assert len(files), f'File not found: {file}'  # assert file was found
        assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique
        return files[0]  # return file


def check_dataset(data, autodownload=True):
    # Download and/or unzip dataset if not found locally
    # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip

    # Download (optional)
    extract_dir = ''
    if isinstance(data, (str, Path)) and str(data).endswith('.zip'):  # i.e. gs://bucket/dir/coco128.zip
        download(data, dir='../datasets', unzip=True, delete=False, curl=False, threads=1)
        data = next((Path('../datasets') / Path(data).stem).rglob('*.yaml'))
        extract_dir, autodownload = data.parent, False

    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        with open(data, errors='ignore') as f:
            data = yaml.safe_load(f)  # dictionary

    # Parse yaml
    path = extract_dir or Path(data.get('path') or '')  # optional 'path' default to '.'
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]

    assert 'nc' in data, "Dataset 'nc' key missing."
    if 'names' not in data:
        data['names'] = [f'class{i}' for i in range(data['nc'])]  # assign class names if missing
    train, val, test, s = [data.get(x) for x in ('train', 'val', 'test', 'download')]
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and autodownload:  # download script
                root = path.parent if 'path' in data else '..'  # unzip directory i.e. '../' (computed up front so the message below can reference it)
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    print(f'Downloading {s} to {f}...')
                    torch.hub.download_url_to_file(s, f)
                    Path(root).mkdir(parents=True, exist_ok=True)  # create root
                    ZipFile(f).extractall(path=root)  # unzip
                    Path(f).unlink()  # remove zip
                    r = None  # success
                elif s.startswith('bash '):  # bash script
                    print(f'Running {s} ...')
                    r = os.system(s)
                else:  # python script
                    r = exec(s, {'yaml': data})  # return None
                print(f"Dataset autodownload {f'success, saved to {root}' if r in (0, None) else 'failure'}")
            else:
                raise Exception('Dataset not found.')

    return data  # dictionary
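
# Example usage (illustrative sketch): resolve a dataset dict from a yaml shipped with the repo
#   data = check_dataset('data/coco128.yaml')  # returns dict with 'train', 'val', 'nc', 'names' keys,
#                                              # downloading the dataset first if it is missing locally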


def url2file(url):
    # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
    url = str(Path(url)).replace(':/', '://')  # Pathlib turns :// -> :/
    file = Path(urllib.parse.unquote(url)).name.split('?')[0]  # '%2F' to '/', split https://url.com/file.txt?auth
    return file
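
# Example usage (illustrative):
#   url2file('https://example.com/weights/yolov5s.pt?dl=1')  # -> 'yolov5s.pt'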


def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
    # Multi-threaded file download and unzip function, used in data.yaml for autodownload
    def download_one(url, dir):
        # Download 1 file
        f = dir / Path(url).name  # filename
        if Path(url).is_file():  # exists in current path
            Path(url).rename(f)  # move to dir
        elif not f.exists():
            print(f'Downloading {url} to {f}...')
            if curl:
                os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -")  # curl download, retry and resume on fail
            else:
                torch.hub.download_url_to_file(url, f, progress=True)  # torch download
        if unzip and f.suffix in ('.zip', '.gz'):
            print(f'Unzipping {f}...')
            if f.suffix == '.zip':
                ZipFile(f).extractall(path=dir)  # unzip
            elif f.suffix == '.gz':
                os.system(f'tar xfz {f} --directory {f.parent}')  # unzip
            if delete:
                f.unlink()  # remove zip

    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    if threads > 1:
        pool = ThreadPool(threads)
        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multi-threaded
        pool.close()
        pool.join()
    else:
        for u in [url] if isinstance(url, (str, Path)) else url:
            download_one(u, dir)
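
# Example usage (illustrative sketch; release URL follows the pattern in check_dataset's usage comment):
#   download('https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip',
#            dir='../datasets', unzip=True, delete=False)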


def make_divisible(x, divisor):
    # Returns x evenly divisible by divisor
    return math.ceil(x / divisor) * divisor
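
# Example usage (illustrative):
#   make_divisible(100, 32)  # -> 128 (rounds up to the nearest multiple)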


def clean_str(s):
    # Cleans a string by replacing special characters with underscore _
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)


def one_cycle(y1=0.0, y2=1.0, steps=100):
    # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
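
# Example usage (illustrative): cosine learning-rate ramp for a scheduler
#   lf = one_cycle(1.0, 0.1, steps=100)
#   lf(0), lf(50), lf(100)  # -> 1.0, 0.55, 0.1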


def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {'black': '\033[30m',  # basic colors
              'red': '\033[31m',
              'green': '\033[32m',
              'yellow': '\033[33m',
              'blue': '\033[34m',
              'magenta': '\033[35m',
              'cyan': '\033[36m',
              'white': '\033[37m',
              'bright_black': '\033[90m',  # bright colors
              'bright_red': '\033[91m',
              'bright_green': '\033[92m',
              'bright_yellow': '\033[93m',
              'bright_blue': '\033[94m',
              'bright_magenta': '\033[95m',
              'bright_cyan': '\033[96m',
              'bright_white': '\033[97m',
              'end': '\033[0m',  # misc
              'bold': '\033[1m',
              'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
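
# Example usage (illustrative):
#   print(colorstr('bold', 'red', 'warning'))  # bold red text
#   print(colorstr('hello'))                   # single argument defaults to blue + bold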


def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(int)  # labels = [class xywh]
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # number of targets per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)


def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class_weights and image contents
    class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
    return image_weights


def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
         64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
    return x


def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y
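
# Example usage (illustrative): the two conversions are inverses of each other
#   b = np.array([[10., 20., 30., 60.]])     # xyxy box
#   xyxy2xywh(b)                             # -> [[20., 40., 20., 40.]] (center x, center y, w, h)
#   np.allclose(xywh2xyxy(xyxy2xywh(b)), b)  # -> True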


def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # top left x
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh  # top left y
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw  # bottom right x
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh  # bottom right y
    return y


def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
    if clip:
        clip_coords(x, (h - eps, w - eps))  # warning: inplace clip
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w  # x center
    y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h  # y center
    y[:, 2] = (x[:, 2] - x[:, 0]) / w  # width
    y[:, 3] = (x[:, 3] - x[:, 1]) / h  # height
    return y


def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    # Convert normalized segments into pixel segments, shape (n,2)
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * x[:, 0] + padw  # top left x
    y[:, 1] = h * x[:, 1] + padh  # top left y
    return y


def segment2box(segment, width=640, height=640):
    # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4))  # xyxy


def segments2boxes(segments):
    # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
    return xyxy2xywh(np.array(boxes))  # cls, xywh


def resample_segments(segments, n=1000):
    # Up-sample an (n,2) segment
    for i, s in enumerate(segments):
        x = np.linspace(0, len(s) - 1, n)
        xp = np.arange(len(s))
        segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T  # segment xy
    return segments


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords


def clip_coords(boxes, shape):
    # Clip xyxy bounding boxes to image shape (height, width)
    if isinstance(boxes, torch.Tensor):  # faster individually
        boxes[:, 0].clamp_(0, shape[1])  # x1
        boxes[:, 1].clamp_(0, shape[0])  # y1
        boxes[:, 2].clamp_(0, shape[1])  # x2
        boxes[:, 3].clamp_(0, shape[0])  # y2
    else:  # np.array (faster grouped)
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1])  # x1, x2
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0])  # y1, y2
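
# Example usage (illustrative sketch; `det` is a placeholder (n,6) detections tensor): map boxes
# from a 640x640 letterboxed input back to the original 720x1280 frame, as done after NMS
#   det[:, :4] = scale_coords((640, 640), det[:, :4], (720, 1280)).round()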


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
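
# Example usage (illustrative sketch; `model` and `img` are placeholders):
#   pred = model(img)[0]  # raw predictions, shape (batch, n, nc + 5)
#   for det in non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, max_det=300):
#       pass  # det is an (n, 6) tensor per image: x1, y1, x2, y2, conf, cls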


def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
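
# Example usage (illustrative):
#   strip_optimizer('runs/train/exp/weights/best.pt')  # overwrites in place; FP16 conversion and
#                                                      # optimizer removal substantially shrink the file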


def print_mutation(results, hyp, save_dir, bucket):
    evolve_csv, results_csv, evolve_yaml = save_dir / 'evolve.csv', save_dir / 'results.csv', save_dir / 'hyp_evolve.yaml'
    keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
            'val/box_loss', 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys())  # [results + hyps]
    keys = tuple(x.strip() for x in keys)
    vals = results + tuple(hyp.values())
    n = len(keys)

    # Download (optional)
    if bucket:
        url = f'gs://{bucket}/evolve.csv'
        if gsutil_getsize(url) > (os.path.getsize(evolve_csv) if os.path.exists(evolve_csv) else 0):
            os.system(f'gsutil cp {url} {save_dir}')  # download evolve.csv if larger than local

    # Log to evolve.csv
    s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n')  # add header
    with open(evolve_csv, 'a') as f:
        f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')

    # Print to screen
    print(colorstr('evolve: ') + ', '.join(f'{x.strip():>20s}' for x in keys))
    print(colorstr('evolve: ') + ', '.join(f'{x:20.5g}' for x in vals), end='\n\n\n')

    # Save yaml
    with open(evolve_yaml, 'w') as f:
        data = pd.read_csv(evolve_csv)
        data = data.rename(columns=lambda x: x.strip())  # strip keys
        i = np.argmax(fitness(data.values[:, :7]))  # index of best fitness
        f.write('# YOLOv5 Hyperparameter Evolution Results\n' +
                f'# Best generation: {i}\n' +
                f'# Last generation: {len(data)}\n' +
                '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) + '\n' +
                '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
        yaml.safe_dump(hyp, f, sort_keys=False)

    if bucket:
        os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}')  # upload


def apply_classifier(x, model, img, im0):
    # Apply a second stage classifier to yolo outputs
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                # cv2.imwrite('example%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255.0  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x


def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):
    # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
    xyxy = torch.tensor(xyxy).view(-1, 4)
    b = xyxy2xywh(xyxy)  # boxes
    if square:
        b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # attempt rectangle to square
    b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
    xyxy = xywh2xyxy(b).long()
    clip_coords(xyxy, im.shape)
    crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
    if save:
        cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop)
    return crop


def increment_path(path, exist_ok=False, sep='', mkdir=False):
    # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        suffix = path.suffix
        path = path.with_suffix('')
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
        matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]  # indices
        n = max(i) + 1 if i else 2  # increment number
        path = Path(f"{path}{sep}{n}{suffix}")  # update path
    dir = path if path.suffix == '' else path.parent  # directory
    if not dir.exists() and mkdir:
        dir.mkdir(parents=True, exist_ok=True)  # make directory
    return path
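
# Example usage (illustrative): successive calls yield runs/exp, runs/exp2, runs/exp3, ...
#   save_dir = increment_path(Path('runs/train') / 'exp', exist_ok=False, mkdir=True)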