You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

821 line
33KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. General utils
  4. """
  5. import contextlib
  6. import glob
  7. import logging
  8. import math
  9. import os
  10. import platform
  11. import random
  12. import re
  13. import signal
  14. import time
  15. import urllib
  16. from itertools import repeat
  17. from multiprocessing.pool import ThreadPool
  18. from pathlib import Path
  19. from subprocess import check_output
  20. from zipfile import ZipFile
  21. import cv2
  22. import numpy as np
  23. import pandas as pd
  24. import pkg_resources as pkg
  25. import torch
  26. import torchvision
  27. import yaml
  28. from utils.downloads import gsutil_getsize
  29. from utils.metrics import box_iou, fitness
  30. # Settings
  31. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  32. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  33. pd.options.display.max_columns = 10
  34. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  35. os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
  36. FILE = Path(__file__).resolve()
  37. ROOT = FILE.parents[1] # YOLOv5 root directory
  38. class Profile(contextlib.ContextDecorator):
  39. # Usage: @Profile() decorator or 'with Profile():' context manager
  40. def __enter__(self):
  41. self.start = time.time()
  42. def __exit__(self, type, value, traceback):
  43. print(f'Profile results: {time.time() - self.start:.5f}s')
  44. class Timeout(contextlib.ContextDecorator):
  45. # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
  46. def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
  47. self.seconds = int(seconds)
  48. self.timeout_message = timeout_msg
  49. self.suppress = bool(suppress_timeout_errors)
  50. def _timeout_handler(self, signum, frame):
  51. raise TimeoutError(self.timeout_message)
  52. def __enter__(self):
  53. signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
  54. signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
  55. def __exit__(self, exc_type, exc_val, exc_tb):
  56. signal.alarm(0) # Cancel SIGALRM if it's scheduled
  57. if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
  58. return True
  59. def try_except(func):
  60. # try-except function. Usage: @try_except decorator
  61. def handler(*args, **kwargs):
  62. try:
  63. func(*args, **kwargs)
  64. except Exception as e:
  65. print(e)
  66. return handler
  67. def methods(instance):
  68. # Get class/instance methods
  69. return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
  70. def set_logging(rank=-1, verbose=True):
  71. logging.basicConfig(
  72. format="%(message)s",
  73. level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)
  74. def print_args(name, opt):
  75. # Print argparser arguments
  76. print(colorstr(f'{name}: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
  77. def init_seeds(seed=0):
  78. # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
  79. # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
  80. import torch.backends.cudnn as cudnn
  81. random.seed(seed)
  82. np.random.seed(seed)
  83. torch.manual_seed(seed)
  84. cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
  85. def get_latest_run(search_dir='.'):
  86. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  87. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  88. return max(last_list, key=os.path.getctime) if last_list else ''
  89. def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
  90. # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
  91. env = os.getenv(env_var)
  92. if env:
  93. path = Path(env) # use environment variable
  94. else:
  95. cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs
  96. path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir
  97. path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable
  98. path.mkdir(exist_ok=True) # make if required
  99. return path
  100. def is_writeable(dir, test=False):
  101. # Return True if directory has write permissions, test opening a file with write permissions if test=True
  102. if test: # method 1
  103. file = Path(dir) / 'tmp.txt'
  104. try:
  105. with open(file, 'w'): # open file with write permissions
  106. pass
  107. file.unlink() # remove file
  108. return True
  109. except IOError:
  110. return False
  111. else: # method 2
  112. return os.access(dir, os.R_OK) # possible issues on Windows
  113. def is_docker():
  114. # Is environment a Docker container?
  115. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  116. def is_colab():
  117. # Is environment a Google Colab instance?
  118. try:
  119. import google.colab
  120. return True
  121. except ImportError:
  122. return False
  123. def is_pip():
  124. # Is file in a pip package?
  125. return 'site-packages' in Path(__file__).resolve().parts
  126. def is_ascii(s=''):
  127. # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
  128. s = str(s) # convert list, tuple, None, etc. to str
  129. return len(s.encode().decode('ascii', 'ignore')) == len(s)
  130. def is_chinese(s='人工智能'):
  131. # Is string composed of any Chinese characters?
  132. return re.search('[\u4e00-\u9fff]', s)
  133. def emojis(str=''):
  134. # Return platform-dependent emoji-safe version of string
  135. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  136. def file_size(path):
  137. # Return file/dir size (MB)
  138. path = Path(path)
  139. if path.is_file():
  140. return path.stat().st_size / 1E6
  141. elif path.is_dir():
  142. return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / 1E6
  143. else:
  144. return 0.0
  145. def check_online():
  146. # Check internet connectivity
  147. import socket
  148. try:
  149. socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
  150. return True
  151. except OSError:
  152. return False
  153. @try_except
  154. def check_git_status():
  155. # Recommend 'git pull' if code is out of date
  156. msg = ', for updates see https://github.com/ultralytics/yolov5'
  157. print(colorstr('github: '), end='')
  158. assert Path('.git').exists(), 'skipping check (not a git repository)' + msg
  159. assert not is_docker(), 'skipping check (Docker image)' + msg
  160. assert check_online(), 'skipping check (offline)' + msg
  161. cmd = 'git fetch && git config --get remote.origin.url'
  162. url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
  163. branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  164. n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  165. if n > 0:
  166. s = f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update."
  167. else:
  168. s = f'up to date with {url} ✅'
  169. print(emojis(s)) # emoji-safe
  170. def check_python(minimum='3.6.2'):
  171. # Check current python version vs. required python version
  172. check_version(platform.python_version(), minimum, name='Python ')
  173. def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False):
  174. # Check version vs. required version
  175. current, minimum = (pkg.parse_version(x) for x in (current, minimum))
  176. result = (current == minimum) if pinned else (current >= minimum)
  177. assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
  178. @try_except
  179. def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True):
  180. # Check installed dependencies meet requirements (pass *.txt file or list of packages)
  181. prefix = colorstr('red', 'bold', 'requirements:')
  182. check_python() # check python version
  183. if isinstance(requirements, (str, Path)): # requirements.txt file
  184. file = Path(requirements)
  185. assert file.exists(), f"{prefix} {file.resolve()} not found, check failed."
  186. requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
  187. else: # list or tuple of packages
  188. requirements = [x for x in requirements if x not in exclude]
  189. n = 0 # number of packages updates
  190. for r in requirements:
  191. try:
  192. pkg.require(r)
  193. except Exception as e: # DistributionNotFound or VersionConflict if requirements not met
  194. s = f"{prefix} {r} not found and is required by YOLOv5"
  195. if install:
  196. print(f"{s}, attempting auto-update...")
  197. try:
  198. assert check_online(), f"'pip install {r}' skipped (offline)"
  199. print(check_output(f"pip install '{r}'", shell=True).decode())
  200. n += 1
  201. except Exception as e:
  202. print(f'{prefix} {e}')
  203. else:
  204. print(f'{s}. Please install and rerun your command.')
  205. if n: # if packages updated
  206. source = file.resolve() if 'file' in locals() else requirements
  207. s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
  208. f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
  209. print(emojis(s))
  210. def check_img_size(imgsz, s=32, floor=0):
  211. # Verify image size is a multiple of stride s in each dimension
  212. if isinstance(imgsz, int): # integer i.e. img_size=640
  213. new_size = max(make_divisible(imgsz, int(s)), floor)
  214. else: # list i.e. img_size=[640, 480]
  215. new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
  216. if new_size != imgsz:
  217. print(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
  218. return new_size
  219. def check_imshow():
  220. # Check if environment supports image displays
  221. try:
  222. assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
  223. assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
  224. cv2.imshow('test', np.zeros((1, 1, 3)))
  225. cv2.waitKey(1)
  226. cv2.destroyAllWindows()
  227. cv2.waitKey(1)
  228. return True
  229. except Exception as e:
  230. print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
  231. return False
  232. def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''):
  233. # Check file(s) for acceptable suffixes
  234. if file and suffix:
  235. if isinstance(suffix, str):
  236. suffix = [suffix]
  237. for f in file if isinstance(file, (list, tuple)) else [file]:
  238. assert Path(f).suffix.lower() in suffix, f"{msg}{f} acceptable suffix is {suffix}"
  239. def check_yaml(file, suffix=('.yaml', '.yml')):
  240. # Search/download YAML file (if necessary) and return path, checking suffix
  241. return check_file(file, suffix)
  242. def check_file(file, suffix=''):
  243. # Search/download file (if necessary) and return path
  244. check_suffix(file, suffix) # optional
  245. file = str(file) # convert to str()
  246. if Path(file).is_file() or file == '': # exists
  247. return file
  248. elif file.startswith(('http:/', 'https:/')): # download
  249. url = str(Path(file)).replace(':/', '://') # Pathlib turns :// -> :/
  250. file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
  251. print(f'Downloading {url} to {file}...')
  252. torch.hub.download_url_to_file(url, file)
  253. assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
  254. return file
  255. else: # search
  256. files = []
  257. for d in 'data', 'models', 'utils': # search directories
  258. files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
  259. assert len(files), f'File not found: {file}' # assert file was found
  260. assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
  261. return files[0] # return file
  262. def check_dataset(data, autodownload=True):
  263. # Download and/or unzip dataset if not found locally
  264. # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip
  265. # Download (optional)
  266. extract_dir = ''
  267. if isinstance(data, (str, Path)) and str(data).endswith('.zip'): # i.e. gs://bucket/dir/coco128.zip
  268. download(data, dir='../datasets', unzip=True, delete=False, curl=False, threads=1)
  269. data = next((Path('../datasets') / Path(data).stem).rglob('*.yaml'))
  270. extract_dir, autodownload = data.parent, False
  271. # Read yaml (optional)
  272. if isinstance(data, (str, Path)):
  273. with open(data, errors='ignore') as f:
  274. data = yaml.safe_load(f) # dictionary
  275. # Parse yaml
  276. path = extract_dir or Path(data.get('path') or '') # optional 'path' default to '.'
  277. for k in 'train', 'val', 'test':
  278. if data.get(k): # prepend path
  279. data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]
  280. assert 'nc' in data, "Dataset 'nc' key missing."
  281. if 'names' not in data:
  282. data['names'] = [f'class{i}' for i in range(data['nc'])] # assign class names if missing
  283. train, val, test, s = [data.get(x) for x in ('train', 'val', 'test', 'download')]
  284. if val:
  285. val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
  286. if not all(x.exists() for x in val):
  287. print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
  288. if s and autodownload: # download script
  289. root = path.parent if 'path' in data else '..' # unzip directory i.e. '../'
  290. if s.startswith('http') and s.endswith('.zip'): # URL
  291. f = Path(s).name # filename
  292. print(f'Downloading {s} to {f}...')
  293. torch.hub.download_url_to_file(s, f)
  294. Path(root).mkdir(parents=True, exist_ok=True) # create root
  295. ZipFile(f).extractall(path=root) # unzip
  296. Path(f).unlink() # remove zip
  297. r = None # success
  298. elif s.startswith('bash '): # bash script
  299. print(f'Running {s} ...')
  300. r = os.system(s)
  301. else: # python script
  302. r = exec(s, {'yaml': data}) # return None
  303. print(f"Dataset autodownload {f'success, saved to {root}' if r in (0, None) else 'failure'}\n")
  304. else:
  305. raise Exception('Dataset not found.')
  306. return data # dictionary
  307. def url2file(url):
  308. # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
  309. url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
  310. file = Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
  311. return file
  312. def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
  313. # Multi-threaded file download and unzip function, used in data.yaml for autodownload
  314. def download_one(url, dir):
  315. # Download 1 file
  316. f = dir / Path(url).name # filename
  317. if Path(url).is_file(): # exists in current path
  318. Path(url).rename(f) # move to dir
  319. elif not f.exists():
  320. print(f'Downloading {url} to {f}...')
  321. if curl:
  322. os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -") # curl download, retry and resume on fail
  323. else:
  324. torch.hub.download_url_to_file(url, f, progress=True) # torch download
  325. if unzip and f.suffix in ('.zip', '.gz'):
  326. print(f'Unzipping {f}...')
  327. if f.suffix == '.zip':
  328. ZipFile(f).extractall(path=dir) # unzip
  329. elif f.suffix == '.gz':
  330. os.system(f'tar xfz {f} --directory {f.parent}') # unzip
  331. if delete:
  332. f.unlink() # remove zip
  333. dir = Path(dir)
  334. dir.mkdir(parents=True, exist_ok=True) # make directory
  335. if threads > 1:
  336. pool = ThreadPool(threads)
  337. pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
  338. pool.close()
  339. pool.join()
  340. else:
  341. for u in [url] if isinstance(url, (str, Path)) else url:
  342. download_one(u, dir)
  343. def make_divisible(x, divisor):
  344. # Returns x evenly divisible by divisor
  345. return math.ceil(x / divisor) * divisor
  346. def clean_str(s):
  347. # Cleans a string by replacing special characters with underscore _
  348. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  349. def one_cycle(y1=0.0, y2=1.0, steps=100):
  350. # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
  351. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  352. def colorstr(*input):
  353. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  354. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  355. colors = {'black': '\033[30m', # basic colors
  356. 'red': '\033[31m',
  357. 'green': '\033[32m',
  358. 'yellow': '\033[33m',
  359. 'blue': '\033[34m',
  360. 'magenta': '\033[35m',
  361. 'cyan': '\033[36m',
  362. 'white': '\033[37m',
  363. 'bright_black': '\033[90m', # bright colors
  364. 'bright_red': '\033[91m',
  365. 'bright_green': '\033[92m',
  366. 'bright_yellow': '\033[93m',
  367. 'bright_blue': '\033[94m',
  368. 'bright_magenta': '\033[95m',
  369. 'bright_cyan': '\033[96m',
  370. 'bright_white': '\033[97m',
  371. 'end': '\033[0m', # misc
  372. 'bold': '\033[1m',
  373. 'underline': '\033[4m'}
  374. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  375. def labels_to_class_weights(labels, nc=80):
  376. # Get class weights (inverse frequency) from training labels
  377. if labels[0] is None: # no labels loaded
  378. return torch.Tensor()
  379. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  380. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  381. weights = np.bincount(classes, minlength=nc) # occurrences per class
  382. # Prepend gridpoint count (for uCE training)
  383. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  384. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  385. weights[weights == 0] = 1 # replace empty bins with 1
  386. weights = 1 / weights # number of targets per class
  387. weights /= weights.sum() # normalize
  388. return torch.from_numpy(weights)
  389. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  390. # Produces image weights based on class_weights and image contents
  391. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  392. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  393. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  394. return image_weights
  395. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  396. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  397. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  398. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  399. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  400. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  401. x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  402. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  403. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  404. return x
  405. def xyxy2xywh(x):
  406. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  407. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  408. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  409. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  410. y[:, 2] = x[:, 2] - x[:, 0] # width
  411. y[:, 3] = x[:, 3] - x[:, 1] # height
  412. return y
  413. def xywh2xyxy(x):
  414. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  415. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  416. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  417. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  418. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  419. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  420. return y
  421. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  422. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  423. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  424. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  425. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  426. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  427. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  428. return y
  429. def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
  430. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
  431. if clip:
  432. clip_coords(x, (h - eps, w - eps)) # warning: inplace clip
  433. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  434. y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
  435. y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
  436. y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
  437. y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
  438. return y
  439. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  440. # Convert normalized segments into pixel segments, shape (n,2)
  441. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  442. y[:, 0] = w * x[:, 0] + padw # top left x
  443. y[:, 1] = h * x[:, 1] + padh # top left y
  444. return y
  445. def segment2box(segment, width=640, height=640):
  446. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  447. x, y = segment.T # segment xy
  448. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  449. x, y, = x[inside], y[inside]
  450. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  451. def segments2boxes(segments):
  452. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  453. boxes = []
  454. for s in segments:
  455. x, y = s.T # segment xy
  456. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  457. return xyxy2xywh(np.array(boxes)) # cls, xywh
  458. def resample_segments(segments, n=1000):
  459. # Up-sample an (n,2) segment
  460. for i, s in enumerate(segments):
  461. x = np.linspace(0, len(s) - 1, n)
  462. xp = np.arange(len(s))
  463. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  464. return segments
  465. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  466. # Rescale coords (xyxy) from img1_shape to img0_shape
  467. if ratio_pad is None: # calculate from img0_shape
  468. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  469. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  470. else:
  471. gain = ratio_pad[0][0]
  472. pad = ratio_pad[1]
  473. coords[:, [0, 2]] -= pad[0] # x padding
  474. coords[:, [1, 3]] -= pad[1] # y padding
  475. coords[:, :4] /= gain
  476. clip_coords(coords, img0_shape)
  477. return coords
  478. def clip_coords(boxes, shape):
  479. # Clip bounding xyxy bounding boxes to image shape (height, width)
  480. if isinstance(boxes, torch.Tensor): # faster individually
  481. boxes[:, 0].clamp_(0, shape[1]) # x1
  482. boxes[:, 1].clamp_(0, shape[0]) # y1
  483. boxes[:, 2].clamp_(0, shape[1]) # x2
  484. boxes[:, 3].clamp_(0, shape[0]) # y2
  485. else: # np.array (faster grouped)
  486. boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
  487. boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Args:
        prediction: model output indexed as (batch, boxes, 5 + nc) with columns
            [x, y, w, h, obj_conf, cls_conf...]; layout inferred from the indexing below -- confirm
            against the detection head
        conf_thres: minimum objectness, and then obj_conf * cls_conf, to keep a candidate; in [0, 1]
        iou_thres: IoU threshold passed to torchvision.ops.nms; in [0, 1]
        classes: optional class indices to keep; detections of other classes are dropped
        agnostic: if True, boxes are not offset by class, so NMS suppresses across classes
        multi_label: if True (and nc > 1), emit one detection per (box, class) pair above conf_thres
        labels: optional per-image apriori labels with rows [cls, x, y, w, h], appended as
            candidates with confidence 1.0 (autolabelling)
        max_det: maximum number of detections kept per image

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates (objectness above threshold)

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height; min_wh only used by the disabled constraint below
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    # NOTE: list multiplication makes every entry the same empty tensor object; entries that get
    # detections are replaced below, the rest remain shared
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls (one-hot)
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            # every (box i, class j) pair whose combined confidence exceeds the threshold
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        # Shift each box by class_index * max_wh so boxes of different classes never overlap and a
        # single class-unaware nms() call behaves per class (offset is 0 when agnostic)
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy (keep boxes overlapping at least one other)

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded; remaining images keep their empty output

    return output
  564. def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
  565. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  566. x = torch.load(f, map_location=torch.device('cpu'))
  567. if x.get('ema'):
  568. x['model'] = x['ema'] # replace model with ema
  569. for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates': # keys
  570. x[k] = None
  571. x['epoch'] = -1
  572. x['model'].half() # to FP16
  573. for p in x['model'].parameters():
  574. p.requires_grad = False
  575. torch.save(x, s or f)
  576. mb = os.path.getsize(s or f) / 1E6 # filesize
  577. print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
  578. def print_mutation(results, hyp, save_dir, bucket):
  579. evolve_csv, results_csv, evolve_yaml = save_dir / 'evolve.csv', save_dir / 'results.csv', save_dir / 'hyp_evolve.yaml'
  580. keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
  581. 'val/box_loss', 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys()) # [results + hyps]
  582. keys = tuple(x.strip() for x in keys)
  583. vals = results + tuple(hyp.values())
  584. n = len(keys)
  585. # Download (optional)
  586. if bucket:
  587. url = f'gs://{bucket}/evolve.csv'
  588. if gsutil_getsize(url) > (os.path.getsize(evolve_csv) if os.path.exists(evolve_csv) else 0):
  589. os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local
  590. # Log to evolve.csv
  591. s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header
  592. with open(evolve_csv, 'a') as f:
  593. f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')
  594. # Print to screen
  595. print(colorstr('evolve: ') + ', '.join(f'{x.strip():>20s}' for x in keys))
  596. print(colorstr('evolve: ') + ', '.join(f'{x:20.5g}' for x in vals), end='\n\n\n')
  597. # Save yaml
  598. with open(evolve_yaml, 'w') as f:
  599. data = pd.read_csv(evolve_csv)
  600. data = data.rename(columns=lambda x: x.strip()) # strip keys
  601. i = np.argmax(fitness(data.values[:, :7])) #
  602. f.write('# YOLOv5 Hyperparameter Evolution Results\n' +
  603. f'# Best generation: {i}\n' +
  604. f'# Last generation: {len(data)}\n' +
  605. '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) + '\n' +
  606. '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
  607. yaml.safe_dump(hyp, f, sort_keys=False)
  608. if bucket:
  609. os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload
  610. def apply_classifier(x, model, img, im0):
  611. # Apply a second stage classifier to yolo outputs
  612. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  613. for i, d in enumerate(x): # per image
  614. if d is not None and len(d):
  615. d = d.clone()
  616. # Reshape and pad cutouts
  617. b = xyxy2xywh(d[:, :4]) # boxes
  618. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  619. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  620. d[:, :4] = xywh2xyxy(b).long()
  621. # Rescale boxes from img_size to im0 size
  622. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  623. # Classes
  624. pred_cls1 = d[:, 5].long()
  625. ims = []
  626. for j, a in enumerate(d): # per item
  627. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  628. im = cv2.resize(cutout, (224, 224)) # BGR
  629. # cv2.imwrite('example%i.jpg' % j, cutout)
  630. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  631. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  632. im /= 255.0 # 0 - 255 to 0.0 - 1.0
  633. ims.append(im)
  634. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  635. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  636. return x
  637. def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):
  638. # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
  639. xyxy = torch.tensor(xyxy).view(-1, 4)
  640. b = xyxy2xywh(xyxy) # boxes
  641. if square:
  642. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
  643. b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
  644. xyxy = xywh2xyxy(b).long()
  645. clip_coords(xyxy, im.shape)
  646. crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
  647. if save:
  648. cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop)
  649. return crop
  650. def increment_path(path, exist_ok=False, sep='', mkdir=False):
  651. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  652. path = Path(path) # os-agnostic
  653. if path.exists() and not exist_ok:
  654. suffix = path.suffix
  655. path = path.with_suffix('')
  656. dirs = glob.glob(f"{path}{sep}*") # similar paths
  657. matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
  658. i = [int(m.groups()[0]) for m in matches if m] # indices
  659. n = max(i) + 1 if i else 2 # increment number
  660. path = Path(f"{path}{sep}{n}{suffix}") # update path
  661. dir = path if path.suffix == '' else path.parent # directory
  662. if not dir.exists() and mkdir:
  663. dir.mkdir(parents=True, exist_ok=True) # make directory
  664. return path