You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

795 lines
32KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. General utils
  4. """
  5. import contextlib
  6. import glob
  7. import logging
  8. import math
  9. import os
  10. import platform
  11. import random
  12. import re
  13. import signal
  14. import time
  15. import urllib
  16. from itertools import repeat
  17. from multiprocessing.pool import ThreadPool
  18. from pathlib import Path
  19. from subprocess import check_output
  20. import cv2
  21. import numpy as np
  22. import pandas as pd
  23. import pkg_resources as pkg
  24. import torch
  25. import torchvision
  26. import yaml
  27. from utils.downloads import gsutil_getsize
  28. from utils.metrics import box_iou, fitness
  29. from utils.torch_utils import init_torch_seeds
# Settings: global library configuration applied once at import time
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads
  36. class Profile(contextlib.ContextDecorator):
  37. # Usage: @Profile() decorator or 'with Profile():' context manager
  38. def __enter__(self):
  39. self.start = time.time()
  40. def __exit__(self, type, value, traceback):
  41. print(f'Profile results: {time.time() - self.start:.5f}s')
class Timeout(contextlib.ContextDecorator):
    # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
    # NOTE(review): relies on signal.SIGALRM, so this is Unix-only and only works in the main thread
    def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
        self.seconds = int(seconds)  # whole seconds only (signal.alarm granularity)
        self.timeout_message = timeout_msg
        self.suppress = bool(suppress_timeout_errors)

    def _timeout_handler(self, signum, frame):
        # SIGALRM handler: surface the timeout as a TimeoutError
        raise TimeoutError(self.timeout_message)

    def __enter__(self):
        signal.signal(signal.SIGALRM, self._timeout_handler)  # Set handler for SIGALRM
        signal.alarm(self.seconds)  # start countdown for SIGALRM to be raised

    def __exit__(self, exc_type, exc_val, exc_tb):
        signal.alarm(0)  # Cancel SIGALRM if it's scheduled
        if self.suppress and exc_type is TimeoutError:  # Suppress TimeoutError
            return True
  57. def try_except(func):
  58. # try-except function. Usage: @try_except decorator
  59. def handler(*args, **kwargs):
  60. try:
  61. func(*args, **kwargs)
  62. except Exception as e:
  63. print(e)
  64. return handler
  65. def methods(instance):
  66. # Get class/instance methods
  67. return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
  68. def set_logging(rank=-1, verbose=True):
  69. logging.basicConfig(
  70. format="%(message)s",
  71. level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)
def init_seeds(seed=0):
    # Initialize random number generator (RNG) seeds for python, numpy and torch
    random.seed(seed)
    np.random.seed(seed)
    init_torch_seeds(seed)  # torch-side seeding lives in utils.torch_utils
  77. def get_latest_run(search_dir='.'):
  78. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  79. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  80. return max(last_list, key=os.path.getctime) if last_list else ''
def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
    # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
    env = os.getenv(env_var)
    if env:
        path = Path(env)  # use environment variable
    else:
        cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'}  # 3 OS dirs
        path = Path.home() / cfg.get(platform.system(), '')  # OS-specific config dir
        path = (path if is_writeable(path) else Path('/tmp')) / dir  # GCP and AWS lambda fix, only /tmp is writeable
    path.mkdir(exist_ok=True)  # make if required (side effect: creates the directory)
    return path
  92. def is_writeable(dir, test=False):
  93. # Return True if directory has write permissions, test opening a file with write permissions if test=True
  94. if test: # method 1
  95. file = Path(dir) / 'tmp.txt'
  96. try:
  97. with open(file, 'w'): # open file with write permissions
  98. pass
  99. file.unlink() # remove file
  100. return True
  101. except IOError:
  102. return False
  103. else: # method 2
  104. return os.access(dir, os.R_OK) # possible issues on Windows
  105. def is_docker():
  106. # Is environment a Docker container?
  107. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  108. def is_colab():
  109. # Is environment a Google Colab instance?
  110. try:
  111. import google.colab
  112. return True
  113. except Exception as e:
  114. return False
  115. def is_pip():
  116. # Is file in a pip package?
  117. return 'site-packages' in Path(__file__).resolve().parts
  118. def is_ascii(s=''):
  119. # Is string composed of all ASCII (no UTF) characters?
  120. s = str(s) # convert list, tuple, None, etc. to str
  121. return len(s.encode().decode('ascii', 'ignore')) == len(s)
  122. def emojis(str=''):
  123. # Return platform-dependent emoji-safe version of string
  124. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  125. def file_size(path):
  126. # Return file/dir size (MB)
  127. path = Path(path)
  128. if path.is_file():
  129. return path.stat().st_size / 1E6
  130. elif path.is_dir():
  131. return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / 1E6
  132. else:
  133. return 0.0
  134. def check_online():
  135. # Check internet connectivity
  136. import socket
  137. try:
  138. socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
  139. return True
  140. except OSError:
  141. return False
  142. @try_except
  143. def check_git_status():
  144. # Recommend 'git pull' if code is out of date
  145. msg = ', for updates see https://github.com/ultralytics/yolov5'
  146. print(colorstr('github: '), end='')
  147. assert Path('.git').exists(), 'skipping check (not a git repository)' + msg
  148. assert not is_docker(), 'skipping check (Docker image)' + msg
  149. assert check_online(), 'skipping check (offline)' + msg
  150. cmd = 'git fetch && git config --get remote.origin.url'
  151. url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
  152. branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  153. n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  154. if n > 0:
  155. s = f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
  156. else:
  157. s = f'up to date with {url} ✅'
  158. print(emojis(s)) # emoji-safe
def check_python(minimum='3.6.2'):
    # Check current python version vs. required minimum python version
    check_version(platform.python_version(), minimum, name='Python ')
def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False):
    # Check version vs. required version; raises AssertionError when unsatisfied
    current, minimum = (pkg.parse_version(x) for x in (current, minimum))  # compare as versions, not strings
    result = (current == minimum) if pinned else (current >= minimum)  # pinned requires an exact match
    assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
@try_except
def check_requirements(requirements='requirements.txt', exclude=(), install=True):
    # Check installed dependencies meet requirements (pass *.txt file or list of packages),
    # optionally pip-installing anything missing. Errors are printed by @try_except.
    prefix = colorstr('red', 'bold', 'requirements:')
    check_python()  # check python version
    if isinstance(requirements, (str, Path)):  # requirements.txt file
        file = Path(requirements)
        assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
        requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
    else:  # list or tuple of packages
        requirements = [x for x in requirements if x not in exclude]
    n = 0  # number of packages updated
    for r in requirements:
        try:
            pkg.require(r)  # raises if requirement is not satisfied
        except Exception as e:  # DistributionNotFound or VersionConflict if requirements not met
            s = f"{prefix} {r} not found and is required by YOLOv5"
            if install:
                print(f"{s}, attempting auto-update...")
                try:
                    assert check_online(), f"'pip install {r}' skipped (offline)"
                    # NOTE(review): shell=True with an interpolated requirement string — assumes
                    # requirements come from a trusted file
                    print(check_output(f"pip install '{r}'", shell=True).decode())
                    n += 1
                except Exception as e:
                    print(f'{prefix} {e}')
            else:
                print(f'{s}. Please install and rerun your command.')
    if n:  # if packages updated
        source = file.resolve() if 'file' in locals() else requirements  # 'file' only bound on the txt-file path
        s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
        print(emojis(s))
def check_img_size(imgsz, s=32, floor=0):
    # Verify image size is a multiple of stride s in each dimension; return the adjusted size
    if isinstance(imgsz, int):  # integer i.e. img_size=640
        new_size = max(make_divisible(imgsz, int(s)), floor)  # round up to stride, at least `floor`
    else:  # list i.e. img_size=[640, 480]
        new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
    if new_size != imgsz:
        print(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
    return new_size
def check_imshow():
    # Check if environment supports image displays (GUI windows); returns bool
    try:
        assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
        assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
        cv2.imshow('test', np.zeros((1, 1, 3)))  # try opening a 1x1 test window
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
        return False
  221. def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
  222. # Check file(s) for acceptable suffixes
  223. if file and suffix:
  224. if isinstance(suffix, str):
  225. suffix = [suffix]
  226. for f in file if isinstance(file, (list, tuple)) else [file]:
  227. assert Path(f).suffix.lower() in suffix, f"{msg}{f} acceptable suffix is {suffix}"
def check_yaml(file, suffix=('.yaml', '.yml')):
    """Search/download a YAML file (if necessary) and return its path, checking suffix."""
    return check_file(file, suffix)
def check_file(file, suffix=''):
    # Search/download file (if necessary) and return path
    check_suffix(file, suffix)  # optional suffix check
    file = str(file)  # convert to str()
    if Path(file).is_file() or file == '':  # exists locally (or empty string passthrough)
        return file
    elif file.startswith(('http:/', 'https:/')):  # download
        url = str(Path(file)).replace(':/', '://')  # Pathlib turns :// -> :/
        file = Path(urllib.parse.unquote(file)).name.split('?')[0]  # '%2F' to '/', split https://url.com/file.txt?auth
        print(f'Downloading {url} to {file}...')
        torch.hub.download_url_to_file(url, file)
        assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}'  # check
        return file
    else:  # search the working tree
        files = glob.glob('./**/' + file, recursive=True)  # find file
        assert len(files), f'File not found: {file}'  # assert file was found
        assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique
        return files[0]  # return file
def check_dataset(data, autodownload=True):
    # Download and/or unzip dataset if not found locally; returns the parsed dataset dict
    # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip

    # Download (optional)
    extract_dir = ''
    if isinstance(data, (str, Path)) and str(data).endswith('.zip'):  # i.e. gs://bucket/dir/coco128.zip
        download(data, dir='../datasets', unzip=True, delete=False, curl=False, threads=1)
        data = next((Path('../datasets') / Path(data).stem).rglob('*.yaml'))  # locate the yaml inside the zip
        extract_dir, autodownload = data.parent, False

    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        with open(data, errors='ignore') as f:
            data = yaml.safe_load(f)  # dictionary

    # Parse yaml
    path = extract_dir or Path(data.get('path') or '')  # optional 'path' default to '.'
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]

    assert 'nc' in data, "Dataset 'nc' key missing."
    if 'names' not in data:
        data['names'] = [f'class{i}' for i in range(data['nc'])]  # assign class names if missing
    train, val, test, s = [data.get(x) for x in ('train', 'val', 'test', 'download')]
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and autodownload:  # download script
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    print(f'Downloading {s} ...')
                    torch.hub.download_url_to_file(s, f)
                    root = path.parent if 'path' in data else '..'  # unzip directory i.e. '../'
                    Path(root).mkdir(parents=True, exist_ok=True)  # create root
                    r = os.system(f'unzip -q {f} -d {root} && rm {f}')  # unzip
                elif s.startswith('bash '):  # bash script
                    print(f'Running {s} ...')
                    r = os.system(s)
                else:  # python script
                    # SECURITY NOTE(review): exec() runs arbitrary code from the dataset yaml —
                    # only use dataset configs from trusted sources
                    r = exec(s, {'yaml': data})  # return None
                print('Dataset autodownload %s\n' % ('success' if r in (0, None) else 'failure'))  # print result
            else:
                raise Exception('Dataset not found.')

    return data  # dictionary
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
    # Multi-threaded file download and unzip function, used in data.yaml for autodownload
    def download_one(url, dir):
        # Download 1 file into dir, then optionally unzip it
        f = dir / Path(url).name  # filename
        if Path(url).is_file():  # exists in current path
            Path(url).rename(f)  # move to dir
        elif not f.exists():
            print(f'Downloading {url} to {f}...')
            if curl:
                os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -")  # curl download, retry and resume on fail
            else:
                torch.hub.download_url_to_file(url, f, progress=True)  # torch download
        if unzip and f.suffix in ('.zip', '.gz'):
            print(f'Unzipping {f}...')
            # 's' is always bound here: the outer check restricts suffix to .zip or .gz
            if f.suffix == '.zip':
                s = f'unzip -qo {f} -d {dir}'  # unzip -quiet -overwrite
            elif f.suffix == '.gz':
                s = f'tar xfz {f} --directory {f.parent}'  # unzip
            if delete:  # delete zip file after unzip
                s += f' && rm {f}'
            os.system(s)

    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    if threads > 1:
        # NOTE(review): when threads > 1, 'url' is expected to be an iterable of URLs
        pool = ThreadPool(threads)
        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multi-threaded
        pool.close()
        pool.join()  # wait for all submitted downloads
    else:
        for u in [url] if isinstance(url, (str, Path)) else url:
            download_one(u, dir)
  324. def make_divisible(x, divisor):
  325. # Returns x evenly divisible by divisor
  326. return math.ceil(x / divisor) * divisor
  327. def clean_str(s):
  328. # Cleans a string by replacing special characters with underscore _
  329. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  330. def one_cycle(y1=0.0, y2=1.0, steps=100):
  331. # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
  332. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  333. def colorstr(*input):
  334. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  335. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  336. colors = {'black': '\033[30m', # basic colors
  337. 'red': '\033[31m',
  338. 'green': '\033[32m',
  339. 'yellow': '\033[33m',
  340. 'blue': '\033[34m',
  341. 'magenta': '\033[35m',
  342. 'cyan': '\033[36m',
  343. 'white': '\033[37m',
  344. 'bright_black': '\033[90m', # bright colors
  345. 'bright_red': '\033[91m',
  346. 'bright_green': '\033[92m',
  347. 'bright_yellow': '\033[93m',
  348. 'bright_blue': '\033[94m',
  349. 'bright_magenta': '\033[95m',
  350. 'bright_cyan': '\033[96m',
  351. 'bright_white': '\033[97m',
  352. 'end': '\033[0m', # misc
  353. 'bold': '\033[1m',
  354. 'underline': '\033[4m'}
  355. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  356. def labels_to_class_weights(labels, nc=80):
  357. # Get class weights (inverse frequency) from training labels
  358. if labels[0] is None: # no labels loaded
  359. return torch.Tensor()
  360. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  361. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  362. weights = np.bincount(classes, minlength=nc) # occurrences per class
  363. # Prepend gridpoint count (for uCE training)
  364. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  365. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  366. weights[weights == 0] = 1 # replace empty bins with 1
  367. weights = 1 / weights # number of targets per class
  368. weights /= weights.sum() # normalize
  369. return torch.from_numpy(weights)
  370. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  371. # Produces image weights based on class_weights and image contents
  372. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  373. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  374. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  375. return image_weights
  376. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  377. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  378. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  379. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  380. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  381. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  382. x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  383. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  384. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  385. return x
  386. def xyxy2xywh(x):
  387. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  388. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  389. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  390. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  391. y[:, 2] = x[:, 2] - x[:, 0] # width
  392. y[:, 3] = x[:, 3] - x[:, 1] # height
  393. return y
  394. def xywh2xyxy(x):
  395. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  396. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  397. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  398. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  399. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  400. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  401. return y
  402. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  403. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  404. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  405. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  406. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  407. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  408. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  409. return y
  410. def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
  411. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
  412. if clip:
  413. clip_coords(x, (h - eps, w - eps)) # warning: inplace clip
  414. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  415. y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
  416. y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
  417. y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
  418. y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
  419. return y
  420. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  421. # Convert normalized segments into pixel segments, shape (n,2)
  422. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  423. y[:, 0] = w * x[:, 0] + padw # top left x
  424. y[:, 1] = h * x[:, 1] + padh # top left y
  425. return y
  426. def segment2box(segment, width=640, height=640):
  427. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  428. x, y = segment.T # segment xy
  429. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  430. x, y, = x[inside], y[inside]
  431. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  432. def segments2boxes(segments):
  433. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  434. boxes = []
  435. for s in segments:
  436. x, y = s.T # segment xy
  437. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  438. return xyxy2xywh(np.array(boxes)) # cls, xywh
  439. def resample_segments(segments, n=1000):
  440. # Up-sample an (n,2) segment
  441. for i, s in enumerate(segments):
  442. x = np.linspace(0, len(s) - 1, n)
  443. xp = np.arange(len(s))
  444. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  445. return segments
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) in place from img1_shape (letterboxed) back to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:  # use precomputed (gain, pad) from the letterbox step
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)  # keep boxes inside the original image
    return coords
  459. def clip_coords(boxes, shape):
  460. # Clip bounding xyxy bounding boxes to image shape (height, width)
  461. if isinstance(boxes, torch.Tensor): # faster individually
  462. boxes[:, 0].clamp_(0, shape[1]) # x1
  463. boxes[:, 1].clamp_(0, shape[0]) # y1
  464. boxes[:, 2].clamp_(0, shape[1]) # x2
  465. boxes[:, 3].clamp_(0, shape[0]) # y2
  466. else: # np.array (faster grouped)
  467. boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
  468. boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results

    prediction: (batch, num_boxes, 5 + nc) tensor of [xywh, obj_conf, per-class confs]
    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """
    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates (objectness above threshold)

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # keep confidence candidates only

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls (one-hot)
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:  # one row per (box, class) pair above threshold
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS: offset boxes per class so cross-class boxes never suppress each other
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer state from checkpoint 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with EMA weights
    for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # training-only keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16 (halves file size)
    for p in x['model'].parameters():
        p.requires_grad = False  # freeze for inference
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
def print_mutation(results, hyp, save_dir, bucket):
    # Log one hyperparameter-evolution generation to evolve.csv / hyp_evolve.yaml
    # (and optionally sync with a GCS bucket via gsutil)
    evolve_csv, results_csv, evolve_yaml = save_dir / 'evolve.csv', save_dir / 'results.csv', save_dir / 'hyp_evolve.yaml'
    keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
            'val/box_loss', 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys())  # [results + hyps]
    keys = tuple(x.strip() for x in keys)
    vals = results + tuple(hyp.values())
    n = len(keys)

    # Download (optional)
    if bucket:
        url = f'gs://{bucket}/evolve.csv'
        if gsutil_getsize(url) > (os.path.getsize(evolve_csv) if os.path.exists(evolve_csv) else 0):
            os.system(f'gsutil cp {url} {save_dir}')  # download evolve.csv if larger than local

    # Log to evolve.csv
    s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n')  # add header on first write
    with open(evolve_csv, 'a') as f:
        f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')

    # Print to screen
    print(colorstr('evolve: ') + ', '.join(f'{x.strip():>20s}' for x in keys))
    print(colorstr('evolve: ') + ', '.join(f'{x:20.5g}' for x in vals), end='\n\n\n')

    # Save yaml
    with open(evolve_yaml, 'w') as f:
        data = pd.read_csv(evolve_csv)
        data = data.rename(columns=lambda x: x.strip())  # strip keys
        i = np.argmax(fitness(data.values[:, :7]))  # index of the best generation so far
        f.write(f'# YOLOv5 Hyperparameter Evolution Results\n' +
                f'# Best generation: {i}\n' +
                f'# Last generation: {len(data)}\n' +
                f'# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) + '\n' +
                f'# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
        yaml.safe_dump(hyp, f, sort_keys=False)

    if bucket:
        os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}')  # upload
def apply_classifier(x, model, img, im0):
    # Apply a second-stage classifier to YOLO detections; keeps only boxes whose
    # classifier prediction agrees with the detector's class
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR, classifier input size
                # cv2.imwrite('example%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255.0  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x
def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):
    # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
    xyxy = torch.tensor(xyxy).view(-1, 4)
    b = xyxy2xywh(xyxy)  # boxes
    if square:
        b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # attempt rectangle to square
    b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
    xyxy = xywh2xyxy(b).long()
    clip_coords(xyxy, im.shape)  # keep the crop inside the image
    crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]  # channel flip unless BGR
    if save:
        cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop)
    return crop
  631. def increment_path(path, exist_ok=False, sep='', mkdir=False):
  632. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  633. path = Path(path) # os-agnostic
  634. if path.exists() and not exist_ok:
  635. suffix = path.suffix
  636. path = path.with_suffix('')
  637. dirs = glob.glob(f"{path}{sep}*") # similar paths
  638. matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
  639. i = [int(m.groups()[0]) for m in matches if m] # indices
  640. n = max(i) + 1 if i else 2 # increment number
  641. path = Path(f"{path}{sep}{n}{suffix}") # update path
  642. dir = path if path.suffix == '' else path.parent # directory
  643. if not dir.exists() and mkdir:
  644. dir.mkdir(parents=True, exist_ok=True) # make directory
  645. return path