You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

804 lines
33KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. General utils
  4. """
  5. import contextlib
  6. import glob
  7. import logging
  8. import math
  9. import os
  10. import platform
  11. import random
  12. import re
  13. import signal
  14. import time
  15. import urllib
  16. from itertools import repeat
  17. from multiprocessing.pool import ThreadPool
  18. from pathlib import Path
  19. from subprocess import check_output
  20. import cv2
  21. import numpy as np
  22. import pandas as pd
  23. import pkg_resources as pkg
  24. import torch
  25. import torchvision
  26. import yaml
  27. from utils.downloads import gsutil_getsize
  28. from utils.metrics import box_iou, fitness
# Settings: global library configuration applied at import time
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # cap NumExpr max threads at 8
  35. class Profile(contextlib.ContextDecorator):
  36. # Usage: @Profile() decorator or 'with Profile():' context manager
  37. def __enter__(self):
  38. self.start = time.time()
  39. def __exit__(self, type, value, traceback):
  40. print(f'Profile results: {time.time() - self.start:.5f}s')
class Timeout(contextlib.ContextDecorator):
    # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
    # NOTE(review): relies on signal.SIGALRM, which is Unix-only — this will fail on
    # Windows; confirm target platforms before use.
    def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
        self.seconds = int(seconds)  # alarm granularity is whole seconds
        self.timeout_message = timeout_msg  # message carried by the raised TimeoutError
        self.suppress = bool(suppress_timeout_errors)  # swallow TimeoutError on exit if True

    def _timeout_handler(self, signum, frame):
        # SIGALRM handler: convert the alarm into a TimeoutError
        raise TimeoutError(self.timeout_message)

    def __enter__(self):
        signal.signal(signal.SIGALRM, self._timeout_handler)  # Set handler for SIGALRM
        signal.alarm(self.seconds)  # start countdown for SIGALRM to be raised

    def __exit__(self, exc_type, exc_val, exc_tb):
        signal.alarm(0)  # Cancel SIGALRM if it's scheduled
        if self.suppress and exc_type is TimeoutError:  # Suppress TimeoutError
            return True
  56. def try_except(func):
  57. # try-except function. Usage: @try_except decorator
  58. def handler(*args, **kwargs):
  59. try:
  60. func(*args, **kwargs)
  61. except Exception as e:
  62. print(e)
  63. return handler
  64. def methods(instance):
  65. # Get class/instance methods
  66. return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
  67. def set_logging(rank=-1, verbose=True):
  68. logging.basicConfig(
  69. format="%(message)s",
  70. level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)
  71. def init_seeds(seed=0):
  72. # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
  73. # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
  74. import torch.backends.cudnn as cudnn
  75. random.seed(seed)
  76. np.random.seed(seed)
  77. torch.manual_seed(seed)
  78. cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
  79. def get_latest_run(search_dir='.'):
  80. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  81. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  82. return max(last_list, key=os.path.getctime) if last_list else ''
def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
    # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
    env = os.getenv(env_var)
    if env:
        path = Path(env)  # explicit override via environment variable
    else:
        # Per-OS user config roots; unknown systems fall back to the home directory itself
        cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'}  # 3 OS dirs
        path = Path.home() / cfg.get(platform.system(), '')  # OS-specific config dir
        path = (path if is_writeable(path) else Path('/tmp')) / dir  # GCP and AWS lambda fix, only /tmp is writeable
    path.mkdir(exist_ok=True)  # make if required
    return path
  94. def is_writeable(dir, test=False):
  95. # Return True if directory has write permissions, test opening a file with write permissions if test=True
  96. if test: # method 1
  97. file = Path(dir) / 'tmp.txt'
  98. try:
  99. with open(file, 'w'): # open file with write permissions
  100. pass
  101. file.unlink() # remove file
  102. return True
  103. except IOError:
  104. return False
  105. else: # method 2
  106. return os.access(dir, os.R_OK) # possible issues on Windows
  107. def is_docker():
  108. # Is environment a Docker container?
  109. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  110. def is_colab():
  111. # Is environment a Google Colab instance?
  112. try:
  113. import google.colab
  114. return True
  115. except Exception as e:
  116. return False
  117. def is_pip():
  118. # Is file in a pip package?
  119. return 'site-packages' in Path(__file__).resolve().parts
  120. def is_ascii(s=''):
  121. # Is string composed of all ASCII (no UTF) characters?
  122. s = str(s) # convert list, tuple, None, etc. to str
  123. return len(s.encode().decode('ascii', 'ignore')) == len(s)
  124. def emojis(str=''):
  125. # Return platform-dependent emoji-safe version of string
  126. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  127. def file_size(path):
  128. # Return file/dir size (MB)
  129. path = Path(path)
  130. if path.is_file():
  131. return path.stat().st_size / 1E6
  132. elif path.is_dir():
  133. return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / 1E6
  134. else:
  135. return 0.0
  136. def check_online():
  137. # Check internet connectivity
  138. import socket
  139. try:
  140. socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
  141. return True
  142. except OSError:
  143. return False
  144. @try_except
  145. def check_git_status():
  146. # Recommend 'git pull' if code is out of date
  147. msg = ', for updates see https://github.com/ultralytics/yolov5'
  148. print(colorstr('github: '), end='')
  149. assert Path('.git').exists(), 'skipping check (not a git repository)' + msg
  150. assert not is_docker(), 'skipping check (Docker image)' + msg
  151. assert check_online(), 'skipping check (offline)' + msg
  152. cmd = 'git fetch && git config --get remote.origin.url'
  153. url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
  154. branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  155. n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  156. if n > 0:
  157. s = f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
  158. else:
  159. s = f'up to date with {url} ✅'
  160. print(emojis(s)) # emoji-safe
def check_python(minimum='3.6.2'):
    # Check current python version vs. required python version (asserts via check_version)
    check_version(platform.python_version(), minimum, name='Python ')
  164. def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False):
  165. # Check version vs. required version
  166. current, minimum = (pkg.parse_version(x) for x in (current, minimum))
  167. result = (current == minimum) if pinned else (current >= minimum)
  168. assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
@try_except
def check_requirements(requirements='requirements.txt', exclude=(), install=True):
    # Check installed dependencies meet requirements (pass *.txt file or list of packages).
    # With install=True, missing packages are auto-installed via 'pip install' (requires network).
    prefix = colorstr('red', 'bold', 'requirements:')
    check_python()  # check python version
    if isinstance(requirements, (str, Path)):  # requirements.txt file
        file = Path(requirements)
        assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
        requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
    else:  # list or tuple of packages
        requirements = [x for x in requirements if x not in exclude]

    n = 0  # number of packages updated
    for r in requirements:
        try:
            pkg.require(r)
        except Exception as e:  # DistributionNotFound or VersionConflict if requirements not met
            s = f"{prefix} {r} not found and is required by YOLOv5"
            if install:
                print(f"{s}, attempting auto-update...")
                try:
                    assert check_online(), f"'pip install {r}' skipped (offline)"
                    # NOTE(review): single-quoting works on POSIX shells only — confirm Windows behavior
                    print(check_output(f"pip install '{r}'", shell=True).decode())
                    n += 1
                except Exception as e:
                    print(f'{prefix} {e}')
            else:
                print(f'{s}. Please install and rerun your command.')

    if n:  # if packages updated
        source = file.resolve() if 'file' in locals() else requirements
        s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
        print(emojis(s))
  201. def check_img_size(imgsz, s=32, floor=0):
  202. # Verify image size is a multiple of stride s in each dimension
  203. if isinstance(imgsz, int): # integer i.e. img_size=640
  204. new_size = max(make_divisible(imgsz, int(s)), floor)
  205. else: # list i.e. img_size=[640, 480]
  206. new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
  207. if new_size != imgsz:
  208. print(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
  209. return new_size
def check_imshow():
    # Check if environment supports image displays: probes cv2.imshow with a 1x1 dummy window.
    # Returns True on success; prints a warning and returns False otherwise.
    try:
        assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
        assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
        # open and immediately destroy a dummy window to probe GUI support
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
        return False
  223. def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
  224. # Check file(s) for acceptable suffixes
  225. if file and suffix:
  226. if isinstance(suffix, str):
  227. suffix = [suffix]
  228. for f in file if isinstance(file, (list, tuple)) else [file]:
  229. assert Path(f).suffix.lower() in suffix, f"{msg}{f} acceptable suffix is {suffix}"
def check_yaml(file, suffix=('.yaml', '.yml')):
    # Search/download YAML file (if necessary) and return path, checking suffix (delegates to check_file)
    return check_file(file, suffix)
def check_file(file, suffix=''):
    # Search/download file (if necessary) and return path.
    # Order: exact path -> http(s) download -> recursive glob search under cwd.
    check_suffix(file, suffix)  # optional suffix check
    file = str(file)  # convert to str()
    if Path(file).is_file() or file == '':  # exists (empty string passes through unchanged)
        return file
    elif file.startswith(('http:/', 'https:/')):  # download
        url = str(Path(file)).replace(':/', '://')  # Pathlib turns :// -> :/
        file = Path(urllib.parse.unquote(file)).name.split('?')[0]  # '%2F' to '/', split https://url.com/file.txt?auth
        print(f'Downloading {url} to {file}...')
        torch.hub.download_url_to_file(url, file)
        assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}'  # check
        return file
    else:  # search
        files = glob.glob('./**/' + file, recursive=True)  # find file
        assert len(files), f'File not found: {file}'  # assert file was found
        assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique
        return files[0]  # return file
def check_dataset(data, autodownload=True):
    # Download and/or unzip dataset if not found locally, returning the parsed dataset dict.
    # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip

    # Download (optional): a *.zip path/URL is downloaded and its bundled *.yaml located
    extract_dir = ''
    if isinstance(data, (str, Path)) and str(data).endswith('.zip'):  # i.e. gs://bucket/dir/coco128.zip
        download(data, dir='../datasets', unzip=True, delete=False, curl=False, threads=1)
        data = next((Path('../datasets') / Path(data).stem).rglob('*.yaml'))
        extract_dir, autodownload = data.parent, False

    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        with open(data, errors='ignore') as f:
            data = yaml.safe_load(f)  # dictionary

    # Parse yaml: resolve train/val/test paths relative to the optional 'path' key
    path = extract_dir or Path(data.get('path') or '')  # optional 'path' default to '.'
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]

    assert 'nc' in data, "Dataset 'nc' key missing."
    if 'names' not in data:
        data['names'] = [f'class{i}' for i in range(data['nc'])]  # assign class names if missing
    train, val, test, s = [data.get(x) for x in ('train', 'val', 'test', 'download')]
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):  # any missing val path triggers autodownload
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and autodownload:  # download script
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    print(f'Downloading {s} ...')
                    torch.hub.download_url_to_file(s, f)
                    root = path.parent if 'path' in data else '..'  # unzip directory i.e. '../'
                    Path(root).mkdir(parents=True, exist_ok=True)  # create root
                    r = os.system(f'unzip -q {f} -d {root} && rm {f}')  # unzip
                elif s.startswith('bash '):  # bash script
                    print(f'Running {s} ...')
                    r = os.system(s)
                else:  # python script
                    # SECURITY NOTE(review): exec() runs arbitrary code from the dataset yaml's
                    # 'download' key — only use dataset yamls from trusted sources.
                    r = exec(s, {'yaml': data})  # return None
                print('Dataset autodownload %s\n' % ('success' if r in (0, None) else 'failure'))  # print result
            else:
                raise Exception('Dataset not found.')
    return data  # dictionary
  294. def url2file(url):
  295. # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
  296. url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
  297. file = Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
  298. return file
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
    # Multi-threaded file download and unzip function, used in data.yaml for autodownload.
    # url may be a single str/Path or an iterable of URLs when threads > 1.
    def download_one(url, dir):
        # Download 1 file into dir (or move it there if it already exists locally), then optionally unzip
        f = dir / Path(url).name  # filename
        if Path(url).is_file():  # exists in current path
            Path(url).rename(f)  # move to dir
        elif not f.exists():
            print(f'Downloading {url} to {f}...')
            if curl:
                os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -")  # curl download, retry and resume on fail
            else:
                torch.hub.download_url_to_file(url, f, progress=True)  # torch download
        if unzip and f.suffix in ('.zip', '.gz'):
            print(f'Unzipping {f}...')
            if f.suffix == '.zip':
                s = f'unzip -qo {f} -d {dir}'  # unzip -quiet -overwrite
            elif f.suffix == '.gz':
                s = f'tar xfz {f} --directory {f.parent}'  # unzip
            if delete:  # delete zip file after unzip
                s += f' && rm {f}'
            os.system(s)

    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    if threads > 1:
        pool = ThreadPool(threads)
        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multi-threaded
        pool.close()
        pool.join()
    else:
        for u in [url] if isinstance(url, (str, Path)) else url:
            download_one(u, dir)
  331. def make_divisible(x, divisor):
  332. # Returns x evenly divisible by divisor
  333. return math.ceil(x / divisor) * divisor
  334. def clean_str(s):
  335. # Cleans a string by replacing special characters with underscore _
  336. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  337. def one_cycle(y1=0.0, y2=1.0, steps=100):
  338. # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
  339. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  340. def colorstr(*input):
  341. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  342. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  343. colors = {'black': '\033[30m', # basic colors
  344. 'red': '\033[31m',
  345. 'green': '\033[32m',
  346. 'yellow': '\033[33m',
  347. 'blue': '\033[34m',
  348. 'magenta': '\033[35m',
  349. 'cyan': '\033[36m',
  350. 'white': '\033[37m',
  351. 'bright_black': '\033[90m', # bright colors
  352. 'bright_red': '\033[91m',
  353. 'bright_green': '\033[92m',
  354. 'bright_yellow': '\033[93m',
  355. 'bright_blue': '\033[94m',
  356. 'bright_magenta': '\033[95m',
  357. 'bright_cyan': '\033[96m',
  358. 'bright_white': '\033[97m',
  359. 'end': '\033[0m', # misc
  360. 'bold': '\033[1m',
  361. 'underline': '\033[4m'}
  362. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  363. def labels_to_class_weights(labels, nc=80):
  364. # Get class weights (inverse frequency) from training labels
  365. if labels[0] is None: # no labels loaded
  366. return torch.Tensor()
  367. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  368. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  369. weights = np.bincount(classes, minlength=nc) # occurrences per class
  370. # Prepend gridpoint count (for uCE training)
  371. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  372. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  373. weights[weights == 0] = 1 # replace empty bins with 1
  374. weights = 1 / weights # number of targets per class
  375. weights /= weights.sum() # normalize
  376. return torch.from_numpy(weights)
  377. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  378. # Produces image weights based on class_weights and image contents
  379. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  380. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  381. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  382. return image_weights
  383. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  384. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  385. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  386. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  387. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  388. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  389. x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  390. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  391. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  392. return x
  393. def xyxy2xywh(x):
  394. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  395. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  396. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  397. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  398. y[:, 2] = x[:, 2] - x[:, 0] # width
  399. y[:, 3] = x[:, 3] - x[:, 1] # height
  400. return y
  401. def xywh2xyxy(x):
  402. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  403. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  404. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  405. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  406. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  407. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  408. return y
  409. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  410. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  411. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  412. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  413. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  414. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  415. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  416. return y
  417. def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
  418. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
  419. if clip:
  420. clip_coords(x, (h - eps, w - eps)) # warning: inplace clip
  421. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  422. y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
  423. y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
  424. y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
  425. y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
  426. return y
  427. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  428. # Convert normalized segments into pixel segments, shape (n,2)
  429. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  430. y[:, 0] = w * x[:, 0] + padw # top left x
  431. y[:, 1] = h * x[:, 1] + padh # top left y
  432. return y
  433. def segment2box(segment, width=640, height=640):
  434. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  435. x, y = segment.T # segment xy
  436. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  437. x, y, = x[inside], y[inside]
  438. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  439. def segments2boxes(segments):
  440. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  441. boxes = []
  442. for s in segments:
  443. x, y = s.T # segment xy
  444. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  445. return xyxy2xywh(np.array(boxes)) # cls, xywh
  446. def resample_segments(segments, n=1000):
  447. # Up-sample an (n,2) segment
  448. for i, s in enumerate(segments):
  449. x = np.linspace(0, len(s) - 1, n)
  450. xp = np.arange(len(s))
  451. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  452. return segments
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) in place from img1_shape (letterboxed) back to img0_shape (original).
    # ratio_pad, when given, is ((gain, ...), (padw, padh)) precomputed by the letterbox step.
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # remove x padding
    coords[:, [1, 3]] -= pad[1]  # remove y padding
    coords[:, :4] /= gain  # undo resize gain
    clip_coords(coords, img0_shape)  # keep boxes inside the original image
    return coords
  466. def clip_coords(boxes, shape):
  467. # Clip bounding xyxy bounding boxes to image shape (height, width)
  468. if isinstance(boxes, torch.Tensor): # faster individually
  469. boxes[:, 0].clamp_(0, shape[1]) # x1
  470. boxes[:, 1].clamp_(0, shape[0]) # y1
  471. boxes[:, 2].clamp_(0, shape[1]) # x2
  472. boxes[:, 3].clamp_(0, shape[0]) # y2
  473. else: # np.array (faster grouped)
  474. boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
  475. boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results.

    Args:
        prediction: (batch, n, 5 + nc) tensor of [xywh, obj_conf, cls_conf...]
        conf_thres: confidence threshold in [0, 1]
        iou_thres: IoU threshold in [0, 1] for suppression
        classes: optional list of class ids to keep
        agnostic: if True, NMS ignores class (no per-class box offset)
        multi_label: allow multiple labels per box (only when nc > 1)
        labels: optional per-image apriori labels to concatenate (autolabelling)
        max_det: maximum detections kept per image

    Returns:
        list of detections, one (n, 6) tensor per image [xyxy, conf, cls]
    """
    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidate mask from objectness

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # keep objectness candidates only

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # one-hot cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # keep highest-confidence boxes

        # Batched NMS: offsetting boxes by class * max_wh separates classes so one NMS call suffices
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # per-class offset (0 when class-agnostic)
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer state from checkpoint 'f' to finalize training, optionally saving as 's'.
    # Replaces the model with its EMA (if present), drops training-only keys, converts to FP16
    # and freezes parameters, then reports the resulting file size.
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # training-only keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False  # freeze for inference
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize in MB
    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
def print_mutation(results, hyp, save_dir, bucket):
    # Log one hyperparameter-evolution generation: append to evolve.csv, print to screen,
    # write hyp_evolve.yaml, and optionally sync with a GCS bucket via gsutil.
    evolve_csv, results_csv, evolve_yaml = save_dir / 'evolve.csv', save_dir / 'results.csv', save_dir / 'hyp_evolve.yaml'
    keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
            'val/box_loss', 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys())  # [results + hyps]
    keys = tuple(x.strip() for x in keys)
    vals = results + tuple(hyp.values())
    n = len(keys)

    # Download (optional): pull the remote evolve.csv if it is larger (i.e. newer) than local
    if bucket:
        url = f'gs://{bucket}/evolve.csv'
        if gsutil_getsize(url) > (os.path.getsize(evolve_csv) if os.path.exists(evolve_csv) else 0):
            os.system(f'gsutil cp {url} {save_dir}')  # download evolve.csv if larger than local

    # Log to evolve.csv
    s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n')  # add header on first write
    with open(evolve_csv, 'a') as f:
        f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')

    # Print to screen
    print(colorstr('evolve: ') + ', '.join(f'{x.strip():>20s}' for x in keys))
    print(colorstr('evolve: ') + ', '.join(f'{x:20.5g}' for x in vals), end='\n\n\n')

    # Save yaml: header comments summarize the best generation, then the current hyps
    with open(evolve_yaml, 'w') as f:
        data = pd.read_csv(evolve_csv)
        data = data.rename(columns=lambda x: x.strip())  # strip keys
        i = np.argmax(fitness(data.values[:, :7]))  # index of best-fitness generation
        f.write(f'# YOLOv5 Hyperparameter Evolution Results\n' +
                f'# Best generation: {i}\n' +
                f'# Last generation: {len(data)}\n' +
                f'# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) + '\n' +
                f'# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
        yaml.safe_dump(hyp, f, sort_keys=False)

    if bucket:
        os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}')  # upload
def apply_classifier(x, model, img, im0):
    # Apply a second-stage classifier to YOLO detections, keeping only detections whose
    # classifier prediction agrees with the detector's class.
    # x: per-image detection tensors; img: letterboxed network input; im0: original image(s).
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes predicted by the detector
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR, classifier input size
                # cv2.imwrite('example%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255.0  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x
def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):
    # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.
    # xyxy: one box; im: HWC image array; BGR selects channel order of the returned crop.
    xyxy = torch.tensor(xyxy).view(-1, 4)
    b = xyxy2xywh(xyxy)  # box as center + size
    if square:
        b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # attempt rectangle to square
    b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
    xyxy = xywh2xyxy(b).long()
    clip_coords(xyxy, im.shape)  # keep crop inside the image
    crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
    if save:
        cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop)
    return crop
  638. def increment_path(path, exist_ok=False, sep='', mkdir=False):
  639. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  640. path = Path(path) # os-agnostic
  641. if path.exists() and not exist_ok:
  642. suffix = path.suffix
  643. path = path.with_suffix('')
  644. dirs = glob.glob(f"{path}{sep}*") # similar paths
  645. matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
  646. i = [int(m.groups()[0]) for m in matches if m] # indices
  647. n = max(i) + 1 if i else 2 # increment number
  648. path = Path(f"{path}{sep}{n}{suffix}") # update path
  649. dir = path if path.suffix == '' else path.parent # directory
  650. if not dir.exists() and mkdir:
  651. dir.mkdir(parents=True, exist_ok=True) # make directory
  652. return path