You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

756 lines
31KB

  1. # YOLOv5 general utils
  2. import contextlib
  3. import glob
  4. import logging
  5. import math
  6. import os
  7. import platform
  8. import random
  9. import re
  10. import signal
  11. import time
  12. import urllib
  13. from itertools import repeat
  14. from multiprocessing.pool import ThreadPool
  15. from pathlib import Path
  16. from subprocess import check_output
  17. import cv2
  18. import numpy as np
  19. import pandas as pd
  20. import pkg_resources as pkg
  21. import torch
  22. import torchvision
  23. import yaml
  24. from utils.google_utils import gsutil_getsize
  25. from utils.metrics import fitness
  26. from utils.torch_utils import init_torch_seeds
  27. # Settings
  28. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  29. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  30. pd.options.display.max_columns = 10
  31. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  32. os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
  33. class timeout(contextlib.ContextDecorator):
  34. # Usage: @timeout(seconds) decorator or 'with timeout(seconds):' context manager
  35. def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
  36. self.seconds = int(seconds)
  37. self.timeout_message = timeout_msg
  38. self.suppress = bool(suppress_timeout_errors)
  39. def _timeout_handler(self, signum, frame):
  40. raise TimeoutError(self.timeout_message)
  41. def __enter__(self):
  42. signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
  43. signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
  44. def __exit__(self, exc_type, exc_val, exc_tb):
  45. signal.alarm(0) # Cancel SIGALRM if it's scheduled
  46. if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
  47. return True
  48. def set_logging(rank=-1, verbose=True):
  49. logging.basicConfig(
  50. format="%(message)s",
  51. level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)
  52. def init_seeds(seed=0):
  53. # Initialize random number generator (RNG) seeds
  54. random.seed(seed)
  55. np.random.seed(seed)
  56. init_torch_seeds(seed)
  57. def get_latest_run(search_dir='.'):
  58. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  59. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  60. return max(last_list, key=os.path.getctime) if last_list else ''
  61. def is_docker():
  62. # Is environment a Docker container?
  63. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  64. def is_colab():
  65. # Is environment a Google Colab instance?
  66. try:
  67. import google.colab
  68. return True
  69. except Exception as e:
  70. return False
  71. def is_pip():
  72. # Is file in a pip package?
  73. return 'site-packages' in Path(__file__).absolute().parts
  74. def emojis(str=''):
  75. # Return platform-dependent emoji-safe version of string
  76. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  77. def file_size(file):
  78. # Return file size in MB
  79. return Path(file).stat().st_size / 1e6
  80. def check_online():
  81. # Check internet connectivity
  82. import socket
  83. try:
  84. socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
  85. return True
  86. except OSError:
  87. return False
  88. def check_git_status(err_msg=', for updates see https://github.com/ultralytics/yolov5'):
  89. # Recommend 'git pull' if code is out of date
  90. print(colorstr('github: '), end='')
  91. try:
  92. assert Path('.git').exists(), 'skipping check (not a git repository)'
  93. assert not is_docker(), 'skipping check (Docker image)'
  94. assert check_online(), 'skipping check (offline)'
  95. cmd = 'git fetch && git config --get remote.origin.url'
  96. url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
  97. branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  98. n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  99. if n > 0:
  100. s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
  101. f"Use 'git pull' to update or 'git clone {url}' to download latest."
  102. else:
  103. s = f'up to date with {url} ✅'
  104. print(emojis(s)) # emoji-safe
  105. except Exception as e:
  106. print(f'{e}{err_msg}')
  107. def check_python(minimum='3.6.2', required=True):
  108. # Check current python version vs. required python version
  109. current = platform.python_version()
  110. result = pkg.parse_version(current) >= pkg.parse_version(minimum)
  111. if required:
  112. assert result, f'Python {minimum} required by YOLOv5, but Python {current} is currently installed'
  113. return result
  114. def check_requirements(requirements='requirements.txt', exclude=()):
  115. # Check installed dependencies meet requirements (pass *.txt file or list of packages)
  116. prefix = colorstr('red', 'bold', 'requirements:')
  117. check_python() # check python version
  118. if isinstance(requirements, (str, Path)): # requirements.txt file
  119. file = Path(requirements)
  120. if not file.exists():
  121. print(f"{prefix} {file.resolve()} not found, check failed.")
  122. return
  123. requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
  124. else: # list or tuple of packages
  125. requirements = [x for x in requirements if x not in exclude]
  126. n = 0 # number of packages updates
  127. for r in requirements:
  128. try:
  129. pkg.require(r)
  130. except Exception as e: # DistributionNotFound or VersionConflict if requirements not met
  131. print(f"{prefix} {r} not found and is required by YOLOv5, attempting auto-update...")
  132. try:
  133. assert check_online(), f"'pip install {r}' skipped (offline)"
  134. print(check_output(f"pip install '{r}'", shell=True).decode())
  135. n += 1
  136. except Exception as e:
  137. print(f'{prefix} {e}')
  138. if n: # if packages updated
  139. source = file.resolve() if 'file' in locals() else requirements
  140. s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
  141. f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
  142. print(emojis(s)) # emoji-safe
  143. def check_img_size(img_size, s=32):
  144. # Verify img_size is a multiple of stride s
  145. new_size = make_divisible(img_size, int(s)) # ceil gs-multiple
  146. if new_size != img_size:
  147. print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
  148. return new_size
  149. def check_imshow():
  150. # Check if environment supports image displays
  151. try:
  152. assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
  153. assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
  154. cv2.imshow('test', np.zeros((1, 1, 3)))
  155. cv2.waitKey(1)
  156. cv2.destroyAllWindows()
  157. cv2.waitKey(1)
  158. return True
  159. except Exception as e:
  160. print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
  161. return False
  162. def check_file(file):
  163. # Search/download file (if necessary) and return path
  164. file = str(file) # convert to str()
  165. if Path(file).is_file() or file == '': # exists
  166. return file
  167. elif file.startswith(('http:/', 'https:/')): # download
  168. url = str(Path(file)).replace(':/', '://') # Pathlib turns :// -> :/
  169. file = Path(urllib.parse.unquote(file)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
  170. print(f'Downloading {url} to {file}...')
  171. torch.hub.download_url_to_file(url, file)
  172. assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
  173. return file
  174. else: # search
  175. files = glob.glob('./**/' + file, recursive=True) # find file
  176. assert len(files), f'File not found: {file}' # assert file was found
  177. assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
  178. return files[0] # return file
  179. def check_dataset(data, autodownload=True):
  180. # Download dataset if not found locally
  181. path = Path(data.get('path', '')) # optional 'path' field
  182. if path:
  183. for k in 'train', 'val', 'test':
  184. if data.get(k): # prepend path
  185. data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]
  186. train, val, test, s = [data.get(x) for x in ('train', 'val', 'test', 'download')]
  187. if val:
  188. val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
  189. if not all(x.exists() for x in val):
  190. print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
  191. if s and autodownload: # download script
  192. if s.startswith('http') and s.endswith('.zip'): # URL
  193. f = Path(s).name # filename
  194. print(f'Downloading {s} ...')
  195. torch.hub.download_url_to_file(s, f)
  196. root = path.parent if 'path' in data else '..' # unzip directory i.e. '../'
  197. Path(root).mkdir(parents=True, exist_ok=True) # create root
  198. r = os.system(f'unzip -q {f} -d {root} && rm {f}') # unzip
  199. elif s.startswith('bash '): # bash script
  200. print(f'Running {s} ...')
  201. r = os.system(s)
  202. else: # python script
  203. r = exec(s, {'yaml': data}) # return None
  204. print('Dataset autodownload %s\n' % ('success' if r in (0, None) else 'failure')) # print result
  205. else:
  206. raise Exception('Dataset not found.')
  207. def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
  208. # Multi-threaded file download and unzip function
  209. def download_one(url, dir):
  210. # Download 1 file
  211. f = dir / Path(url).name # filename
  212. if not f.exists():
  213. print(f'Downloading {url} to {f}...')
  214. if curl:
  215. os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -") # curl download, retry and resume on fail
  216. else:
  217. torch.hub.download_url_to_file(url, f, progress=True) # torch download
  218. if unzip and f.suffix in ('.zip', '.gz'):
  219. print(f'Unzipping {f}...')
  220. if f.suffix == '.zip':
  221. s = f'unzip -qo {f} -d {dir}' # unzip -quiet -overwrite
  222. elif f.suffix == '.gz':
  223. s = f'tar xfz {f} --directory {f.parent}' # unzip
  224. if delete: # delete zip file after unzip
  225. s += f' && rm {f}'
  226. os.system(s)
  227. dir = Path(dir)
  228. dir.mkdir(parents=True, exist_ok=True) # make directory
  229. if threads > 1:
  230. pool = ThreadPool(threads)
  231. pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
  232. pool.close()
  233. pool.join()
  234. else:
  235. for u in tuple(url) if isinstance(url, str) else url:
  236. download_one(u, dir)
  237. def make_divisible(x, divisor):
  238. # Returns x evenly divisible by divisor
  239. return math.ceil(x / divisor) * divisor
  240. def clean_str(s):
  241. # Cleans a string by replacing special characters with underscore _
  242. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  243. def one_cycle(y1=0.0, y2=1.0, steps=100):
  244. # lambda function for sinusoidal ramp from y1 to y2
  245. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  246. def colorstr(*input):
  247. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  248. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  249. colors = {'black': '\033[30m', # basic colors
  250. 'red': '\033[31m',
  251. 'green': '\033[32m',
  252. 'yellow': '\033[33m',
  253. 'blue': '\033[34m',
  254. 'magenta': '\033[35m',
  255. 'cyan': '\033[36m',
  256. 'white': '\033[37m',
  257. 'bright_black': '\033[90m', # bright colors
  258. 'bright_red': '\033[91m',
  259. 'bright_green': '\033[92m',
  260. 'bright_yellow': '\033[93m',
  261. 'bright_blue': '\033[94m',
  262. 'bright_magenta': '\033[95m',
  263. 'bright_cyan': '\033[96m',
  264. 'bright_white': '\033[97m',
  265. 'end': '\033[0m', # misc
  266. 'bold': '\033[1m',
  267. 'underline': '\033[4m'}
  268. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  269. def labels_to_class_weights(labels, nc=80):
  270. # Get class weights (inverse frequency) from training labels
  271. if labels[0] is None: # no labels loaded
  272. return torch.Tensor()
  273. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  274. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  275. weights = np.bincount(classes, minlength=nc) # occurrences per class
  276. # Prepend gridpoint count (for uCE training)
  277. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  278. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  279. weights[weights == 0] = 1 # replace empty bins with 1
  280. weights = 1 / weights # number of targets per class
  281. weights /= weights.sum() # normalize
  282. return torch.from_numpy(weights)
  283. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  284. # Produces image weights based on class_weights and image contents
  285. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  286. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  287. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  288. return image_weights
  289. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  290. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  291. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  292. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  293. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  294. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  295. x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  296. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  297. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  298. return x
  299. def xyxy2xywh(x):
  300. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  301. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  302. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  303. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  304. y[:, 2] = x[:, 2] - x[:, 0] # width
  305. y[:, 3] = x[:, 3] - x[:, 1] # height
  306. return y
  307. def xywh2xyxy(x):
  308. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  309. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  310. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  311. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  312. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  313. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  314. return y
  315. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  316. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  317. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  318. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  319. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  320. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  321. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  322. return y
  323. def xyxy2xywhn(x, w=640, h=640, clip=False):
  324. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
  325. if clip:
  326. clip_coords(x, (h, w)) # warning: inplace clip
  327. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  328. y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
  329. y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
  330. y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
  331. y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
  332. return y
  333. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  334. # Convert normalized segments into pixel segments, shape (n,2)
  335. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  336. y[:, 0] = w * x[:, 0] + padw # top left x
  337. y[:, 1] = h * x[:, 1] + padh # top left y
  338. return y
  339. def segment2box(segment, width=640, height=640):
  340. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  341. x, y = segment.T # segment xy
  342. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  343. x, y, = x[inside], y[inside]
  344. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  345. def segments2boxes(segments):
  346. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  347. boxes = []
  348. for s in segments:
  349. x, y = s.T # segment xy
  350. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  351. return xyxy2xywh(np.array(boxes)) # cls, xywh
  352. def resample_segments(segments, n=1000):
  353. # Up-sample an (n,2) segment
  354. for i, s in enumerate(segments):
  355. x = np.linspace(0, len(s) - 1, n)
  356. xp = np.arange(len(s))
  357. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  358. return segments
  359. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  360. # Rescale coords (xyxy) from img1_shape to img0_shape
  361. if ratio_pad is None: # calculate from img0_shape
  362. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  363. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  364. else:
  365. gain = ratio_pad[0][0]
  366. pad = ratio_pad[1]
  367. coords[:, [0, 2]] -= pad[0] # x padding
  368. coords[:, [1, 3]] -= pad[1] # y padding
  369. coords[:, :4] /= gain
  370. clip_coords(coords, img0_shape)
  371. return coords
  372. def clip_coords(boxes, img_shape):
  373. # Clip bounding xyxy bounding boxes to image shape (height, width)
  374. if isinstance(boxes, torch.Tensor):
  375. boxes[:, 0].clamp_(0, img_shape[1]) # x1
  376. boxes[:, 1].clamp_(0, img_shape[0]) # y1
  377. boxes[:, 2].clamp_(0, img_shape[1]) # x2
  378. boxes[:, 3].clamp_(0, img_shape[0]) # y2
  379. else: # np.array
  380. boxes[:, 0].clip(0, img_shape[1], out=boxes[:, 0]) # x1
  381. boxes[:, 1].clip(0, img_shape[0], out=boxes[:, 1]) # y1
  382. boxes[:, 2].clip(0, img_shape[1], out=boxes[:, 2]) # x2
  383. boxes[:, 3].clip(0, img_shape[0], out=boxes[:, 3]) # y2
  384. def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
  385. # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
  386. box2 = box2.T
  387. # Get the coordinates of bounding boxes
  388. if x1y1x2y2: # x1, y1, x2, y2 = box1
  389. b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
  390. b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
  391. else: # transform from xywh to xyxy
  392. b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
  393. b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
  394. b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
  395. b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
  396. # Intersection area
  397. inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
  398. (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
  399. # Union Area
  400. w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
  401. w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
  402. union = w1 * h1 + w2 * h2 - inter + eps
  403. iou = inter / union
  404. if GIoU or DIoU or CIoU:
  405. cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
  406. ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
  407. if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
  408. c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
  409. rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
  410. (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center distance squared
  411. if DIoU:
  412. return iou - rho2 / c2 # DIoU
  413. elif CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
  414. v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
  415. with torch.no_grad():
  416. alpha = v / (v - iou + (1 + eps))
  417. return iou - (rho2 / c2 + v * alpha) # CIoU
  418. else: # GIoU https://arxiv.org/pdf/1902.09630.pdf
  419. c_area = cw * ch + eps # convex area
  420. return iou - (c_area - union) / c_area # GIoU
  421. else:
  422. return iou # IoU
  423. def box_iou(box1, box2):
  424. # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
  425. """
  426. Return intersection-over-union (Jaccard index) of boxes.
  427. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
  428. Arguments:
  429. box1 (Tensor[N, 4])
  430. box2 (Tensor[M, 4])
  431. Returns:
  432. iou (Tensor[N, M]): the NxM matrix containing the pairwise
  433. IoU values for every element in boxes1 and boxes2
  434. """
  435. def box_area(box):
  436. # box = 4xn
  437. return (box[2] - box[0]) * (box[3] - box[1])
  438. area1 = box_area(box1.T)
  439. area2 = box_area(box2.T)
  440. # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
  441. inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
  442. return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
  443. def wh_iou(wh1, wh2):
  444. # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
  445. wh1 = wh1[:, None] # [N,1,2]
  446. wh2 = wh2[None] # [1,M,2]
  447. inter = torch.min(wh1, wh2).prod(2) # [N,M]
  448. return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
  449. def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
  450. labels=(), max_det=300):
  451. """Runs Non-Maximum Suppression (NMS) on inference results
  452. Returns:
  453. list of detections, on (n,6) tensor per image [xyxy, conf, cls]
  454. """
  455. nc = prediction.shape[2] - 5 # number of classes
  456. xc = prediction[..., 4] > conf_thres # candidates
  457. # Checks
  458. assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
  459. assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
  460. # Settings
  461. min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
  462. max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
  463. time_limit = 10.0 # seconds to quit after
  464. redundant = True # require redundant detections
  465. multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
  466. merge = False # use merge-NMS
  467. t = time.time()
  468. output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
  469. for xi, x in enumerate(prediction): # image index, image inference
  470. # Apply constraints
  471. # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
  472. x = x[xc[xi]] # confidence
  473. # Cat apriori labels if autolabelling
  474. if labels and len(labels[xi]):
  475. l = labels[xi]
  476. v = torch.zeros((len(l), nc + 5), device=x.device)
  477. v[:, :4] = l[:, 1:5] # box
  478. v[:, 4] = 1.0 # conf
  479. v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls
  480. x = torch.cat((x, v), 0)
  481. # If none remain process next image
  482. if not x.shape[0]:
  483. continue
  484. # Compute conf
  485. x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
  486. # Box (center x, center y, width, height) to (x1, y1, x2, y2)
  487. box = xywh2xyxy(x[:, :4])
  488. # Detections matrix nx6 (xyxy, conf, cls)
  489. if multi_label:
  490. i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
  491. x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
  492. else: # best class only
  493. conf, j = x[:, 5:].max(1, keepdim=True)
  494. x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
  495. # Filter by class
  496. if classes is not None:
  497. x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
  498. # Apply finite constraint
  499. # if not torch.isfinite(x).all():
  500. # x = x[torch.isfinite(x).all(1)]
  501. # Check shape
  502. n = x.shape[0] # number of boxes
  503. if not n: # no boxes
  504. continue
  505. elif n > max_nms: # excess boxes
  506. x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
  507. # Batched NMS
  508. c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
  509. boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
  510. i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
  511. if i.shape[0] > max_det: # limit detections
  512. i = i[:max_det]
  513. if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
  514. # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
  515. iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
  516. weights = iou * scores[None] # box weights
  517. x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
  518. if redundant:
  519. i = i[iou.sum(1) > 1] # require redundancy
  520. output[xi] = x[i]
  521. if (time.time() - t) > time_limit:
  522. print(f'WARNING: NMS time limit {time_limit}s exceeded')
  523. break # time limit exceeded
  524. return output
  525. def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
  526. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  527. x = torch.load(f, map_location=torch.device('cpu'))
  528. if x.get('ema'):
  529. x['model'] = x['ema'] # replace model with ema
  530. for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates': # keys
  531. x[k] = None
  532. x['epoch'] = -1
  533. x['model'].half() # to FP16
  534. for p in x['model'].parameters():
  535. p.requires_grad = False
  536. torch.save(x, s or f)
  537. mb = os.path.getsize(s or f) / 1E6 # filesize
  538. print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
  539. def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
  540. # Print mutation results to evolve.txt (for use with train.py --evolve)
  541. a = '%10s' * len(hyp) % tuple(hyp.keys()) # hyperparam keys
  542. b = '%10.3g' * len(hyp) % tuple(hyp.values()) # hyperparam values
  543. c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
  544. print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))
  545. if bucket:
  546. url = 'gs://%s/evolve.txt' % bucket
  547. if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
  548. os.system('gsutil cp %s .' % url) # download evolve.txt if larger than local
  549. with open('evolve.txt', 'a') as f: # append result
  550. f.write(c + b + '\n')
  551. x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0) # load unique rows
  552. x = x[np.argsort(-fitness(x))] # sort
  553. np.savetxt('evolve.txt', x, '%10.3g') # save sort by fitness
  554. # Save yaml
  555. for i, k in enumerate(hyp.keys()):
  556. hyp[k] = float(x[0, i + 7])
  557. with open(yaml_file, 'w') as f:
  558. results = tuple(x[0, :7])
  559. c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
  560. f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
  561. yaml.safe_dump(hyp, f, sort_keys=False)
  562. if bucket:
  563. os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket)) # upload
  564. def apply_classifier(x, model, img, im0):
  565. # Apply a second stage classifier to yolo outputs
  566. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  567. for i, d in enumerate(x): # per image
  568. if d is not None and len(d):
  569. d = d.clone()
  570. # Reshape and pad cutouts
  571. b = xyxy2xywh(d[:, :4]) # boxes
  572. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  573. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  574. d[:, :4] = xywh2xyxy(b).long()
  575. # Rescale boxes from img_size to im0 size
  576. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  577. # Classes
  578. pred_cls1 = d[:, 5].long()
  579. ims = []
  580. for j, a in enumerate(d): # per item
  581. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  582. im = cv2.resize(cutout, (224, 224)) # BGR
  583. # cv2.imwrite('test%i.jpg' % j, cutout)
  584. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  585. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  586. im /= 255.0 # 0 - 255 to 0.0 - 1.0
  587. ims.append(im)
  588. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  589. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  590. return x
  591. def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):
  592. # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
  593. xyxy = torch.tensor(xyxy).view(-1, 4)
  594. b = xyxy2xywh(xyxy) # boxes
  595. if square:
  596. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
  597. b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
  598. xyxy = xywh2xyxy(b).long()
  599. clip_coords(xyxy, im.shape)
  600. crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
  601. if save:
  602. cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop)
  603. return crop
  604. def increment_path(path, exist_ok=False, sep='', mkdir=False):
  605. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  606. path = Path(path) # os-agnostic
  607. if path.exists() and not exist_ok:
  608. suffix = path.suffix
  609. path = path.with_suffix('')
  610. dirs = glob.glob(f"{path}{sep}*") # similar paths
  611. matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
  612. i = [int(m.groups()[0]) for m in matches if m] # indices
  613. n = max(i) + 1 if i else 2 # increment number
  614. path = Path(f"{path}{sep}{n}{suffix}") # update path
  615. dir = path if path.suffix == '' else path.parent # directory
  616. if not dir.exists() and mkdir:
  617. dir.mkdir(parents=True, exist_ok=True) # make directory
  618. return path