You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

748 lines
31KB

  1. # YOLOv5 general utils
  2. import contextlib
  3. import glob
  4. import logging
  5. import math
  6. import os
  7. import platform
  8. import random
  9. import re
  10. import signal
  11. import time
  12. import urllib
  13. from itertools import repeat
  14. from multiprocessing.pool import ThreadPool
  15. from pathlib import Path
  16. from subprocess import check_output
  17. import cv2
  18. import numpy as np
  19. import pandas as pd
  20. import pkg_resources as pkg
  21. import torch
  22. import torchvision
  23. import yaml
  24. from utils.google_utils import gsutil_getsize
  25. from utils.metrics import fitness
  26. from utils.torch_utils import init_torch_seeds
  27. # Settings
  28. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  29. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  30. pd.options.display.max_columns = 10
  31. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  32. os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
  33. class timeout(contextlib.ContextDecorator):
  34. # Usage: @timeout(seconds) decorator or 'with timeout(seconds):' context manager
  35. def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
  36. self.seconds = int(seconds)
  37. self.timeout_message = timeout_msg
  38. self.suppress = bool(suppress_timeout_errors)
  39. def _timeout_handler(self, signum, frame):
  40. raise TimeoutError(self.timeout_message)
  41. def __enter__(self):
  42. signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
  43. signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
  44. def __exit__(self, exc_type, exc_val, exc_tb):
  45. signal.alarm(0) # Cancel SIGALRM if it's scheduled
  46. if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
  47. return True
  48. def set_logging(rank=-1, verbose=True):
  49. logging.basicConfig(
  50. format="%(message)s",
  51. level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)
  52. def init_seeds(seed=0):
  53. # Initialize random number generator (RNG) seeds
  54. random.seed(seed)
  55. np.random.seed(seed)
  56. init_torch_seeds(seed)
  57. def get_latest_run(search_dir='.'):
  58. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  59. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  60. return max(last_list, key=os.path.getctime) if last_list else ''
  61. def is_docker():
  62. # Is environment a Docker container?
  63. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  64. def is_colab():
  65. # Is environment a Google Colab instance?
  66. try:
  67. import google.colab
  68. return True
  69. except Exception as e:
  70. return False
  71. def is_pip():
  72. # Is file in a pip package?
  73. return 'site-packages' in Path(__file__).absolute().parts
  74. def emojis(str=''):
  75. # Return platform-dependent emoji-safe version of string
  76. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  77. def file_size(file):
  78. # Return file size in MB
  79. return Path(file).stat().st_size / 1e6
  80. def check_online():
  81. # Check internet connectivity
  82. import socket
  83. try:
  84. socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
  85. return True
  86. except OSError:
  87. return False
  88. def check_git_status(err_msg=', for updates see https://github.com/ultralytics/yolov5'):
  89. # Recommend 'git pull' if code is out of date
  90. print(colorstr('github: '), end='')
  91. try:
  92. assert Path('.git').exists(), 'skipping check (not a git repository)'
  93. assert not is_docker(), 'skipping check (Docker image)'
  94. assert check_online(), 'skipping check (offline)'
  95. cmd = 'git fetch && git config --get remote.origin.url'
  96. url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
  97. branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  98. n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  99. if n > 0:
  100. s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
  101. f"Use 'git pull' to update or 'git clone {url}' to download latest."
  102. else:
  103. s = f'up to date with {url} ✅'
  104. print(emojis(s)) # emoji-safe
  105. except Exception as e:
  106. print(f'{e}{err_msg}')
  107. def check_python(minimum='3.6.2', required=True):
  108. # Check current python version vs. required python version
  109. current = platform.python_version()
  110. result = pkg.parse_version(current) >= pkg.parse_version(minimum)
  111. if required:
  112. assert result, f'Python {minimum} required by YOLOv5, but Python {current} is currently installed'
  113. return result
  114. def check_requirements(requirements='requirements.txt', exclude=()):
  115. # Check installed dependencies meet requirements (pass *.txt file or list of packages)
  116. prefix = colorstr('red', 'bold', 'requirements:')
  117. check_python() # check python version
  118. if isinstance(requirements, (str, Path)): # requirements.txt file
  119. file = Path(requirements)
  120. if not file.exists():
  121. print(f"{prefix} {file.resolve()} not found, check failed.")
  122. return
  123. requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
  124. else: # list or tuple of packages
  125. requirements = [x for x in requirements if x not in exclude]
  126. n = 0 # number of packages updates
  127. for r in requirements:
  128. try:
  129. pkg.require(r)
  130. except Exception as e: # DistributionNotFound or VersionConflict if requirements not met
  131. print(f"{prefix} {r} not found and is required by YOLOv5, attempting auto-update...")
  132. try:
  133. assert check_online(), f"'pip install {r}' skipped (offline)"
  134. print(check_output(f"pip install '{r}'", shell=True).decode())
  135. n += 1
  136. except Exception as e:
  137. print(f'{prefix} {e}')
  138. if n: # if packages updated
  139. source = file.resolve() if 'file' in locals() else requirements
  140. s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
  141. f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
  142. print(emojis(s)) # emoji-safe
  143. def check_img_size(img_size, s=32):
  144. # Verify img_size is a multiple of stride s
  145. new_size = make_divisible(img_size, int(s)) # ceil gs-multiple
  146. if new_size != img_size:
  147. print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
  148. return new_size
  149. def check_imshow():
  150. # Check if environment supports image displays
  151. try:
  152. assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
  153. assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
  154. cv2.imshow('test', np.zeros((1, 1, 3)))
  155. cv2.waitKey(1)
  156. cv2.destroyAllWindows()
  157. cv2.waitKey(1)
  158. return True
  159. except Exception as e:
  160. print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
  161. return False
  162. def check_file(file):
  163. # Search/download file (if necessary) and return path
  164. file = str(file) # convert to str()
  165. if Path(file).is_file() or file == '': # exists
  166. return file
  167. elif file.startswith(('http:/', 'https:/')): # download
  168. url = str(Path(file)).replace(':/', '://') # Pathlib turns :// -> :/
  169. file = Path(urllib.parse.unquote(file)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
  170. print(f'Downloading {url} to {file}...')
  171. torch.hub.download_url_to_file(url, file)
  172. assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
  173. return file
  174. else: # search
  175. files = glob.glob('./**/' + file, recursive=True) # find file
  176. assert len(files), f'File not found: {file}' # assert file was found
  177. assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
  178. return files[0] # return file
def check_dataset(data, autodownload=True):
    """Check that a dataset's 'val' paths exist, running its 'download' recipe if they do not.

    `data` is the dataset dict (keys read here: 'path', 'train', 'val', 'test', 'download'). It is modified
    in place: non-empty 'train'/'val'/'test' entries (str or list of str) are prefixed with data['path'].

    The 'download' value `s` is interpreted as:
      - a URL starting with 'http' and ending in '.zip': downloaded with torch.hub into the current
        directory, then unzipped into the parent of data['path'] (or '..' if data has no 'path' key).
        The zip is removed after a successful unzip.
      - a string starting with 'bash ': run through os.system();
      - anything else: executed as Python source via exec() with the dict bound to the name `yaml`.

    A failed download is only reported ('Dataset autodownload failure'), not raised.

    Raises:
        Exception: 'Dataset not found.' if val paths are missing and no download is attempted
            (no 'download' entry, or autodownload=False).
    """
    path = Path(data.get('path', ''))  # optional 'path' field
    # NOTE(review): Path objects are always truthy (Path('') == Path('.')), so this branch always runs.
    # Without a 'path' field, prefixing with '.' leaves the entries effectively unchanged.
    if path:
        for k in 'train', 'val', 'test':
            if data.get(k):  # prepend path
                data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]

    train, val, test, s = [data.get(x) for x in ('train', 'val', 'test', 'download')]
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and autodownload:  # download script
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    print(f'Downloading {s} ...')
                    torch.hub.download_url_to_file(s, f)
                    root = path.parent if 'path' in data else '..'  # unzip directory i.e. '../'
                    Path(root).mkdir(parents=True, exist_ok=True)  # create root
                    # os.system returns the shell exit status (0 on success)
                    r = os.system(f'unzip -q {f} -d {root} && rm {f}')  # unzip
                elif s.startswith('bash '):  # bash script
                    print(f'Running {s} ...')
                    r = os.system(s)
                else:  # python script
                    # SECURITY NOTE(review): executes arbitrary code taken from the dataset description;
                    # only use with trusted dataset YAML files.
                    r = exec(s, {'yaml': data})  # return None
                print('Dataset autodownload %s\n' % ('success' if r in (0, None) else 'failure'))  # print result
            else:
                raise Exception('Dataset not found.')
  207. def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
  208. # Multi-threaded file download and unzip function
  209. def download_one(url, dir):
  210. # Download 1 file
  211. f = dir / Path(url).name # filename
  212. if not f.exists():
  213. print(f'Downloading {url} to {f}...')
  214. if curl:
  215. os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -") # curl download, retry and resume on fail
  216. else:
  217. torch.hub.download_url_to_file(url, f, progress=True) # torch download
  218. if unzip and f.suffix in ('.zip', '.gz'):
  219. print(f'Unzipping {f}...')
  220. if f.suffix == '.zip':
  221. s = f'unzip -qo {f} -d {dir}' # unzip -quiet -overwrite
  222. elif f.suffix == '.gz':
  223. s = f'tar xfz {f} --directory {f.parent}' # unzip
  224. if delete: # delete zip file after unzip
  225. s += f' && rm {f}'
  226. os.system(s)
  227. dir = Path(dir)
  228. dir.mkdir(parents=True, exist_ok=True) # make directory
  229. if threads > 1:
  230. pool = ThreadPool(threads)
  231. pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
  232. pool.close()
  233. pool.join()
  234. else:
  235. for u in tuple(url) if isinstance(url, str) else url:
  236. download_one(u, dir)
  237. def make_divisible(x, divisor):
  238. # Returns x evenly divisible by divisor
  239. return math.ceil(x / divisor) * divisor
  240. def clean_str(s):
  241. # Cleans a string by replacing special characters with underscore _
  242. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  243. def one_cycle(y1=0.0, y2=1.0, steps=100):
  244. # lambda function for sinusoidal ramp from y1 to y2
  245. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  246. def colorstr(*input):
  247. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  248. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  249. colors = {'black': '\033[30m', # basic colors
  250. 'red': '\033[31m',
  251. 'green': '\033[32m',
  252. 'yellow': '\033[33m',
  253. 'blue': '\033[34m',
  254. 'magenta': '\033[35m',
  255. 'cyan': '\033[36m',
  256. 'white': '\033[37m',
  257. 'bright_black': '\033[90m', # bright colors
  258. 'bright_red': '\033[91m',
  259. 'bright_green': '\033[92m',
  260. 'bright_yellow': '\033[93m',
  261. 'bright_blue': '\033[94m',
  262. 'bright_magenta': '\033[95m',
  263. 'bright_cyan': '\033[96m',
  264. 'bright_white': '\033[97m',
  265. 'end': '\033[0m', # misc
  266. 'bold': '\033[1m',
  267. 'underline': '\033[4m'}
  268. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  269. def labels_to_class_weights(labels, nc=80):
  270. # Get class weights (inverse frequency) from training labels
  271. if labels[0] is None: # no labels loaded
  272. return torch.Tensor()
  273. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  274. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  275. weights = np.bincount(classes, minlength=nc) # occurrences per class
  276. # Prepend gridpoint count (for uCE training)
  277. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  278. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  279. weights[weights == 0] = 1 # replace empty bins with 1
  280. weights = 1 / weights # number of targets per class
  281. weights /= weights.sum() # normalize
  282. return torch.from_numpy(weights)
  283. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  284. # Produces image weights based on class_weights and image contents
  285. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  286. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  287. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  288. return image_weights
  289. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  290. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  291. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  292. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  293. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  294. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  295. x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  296. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  297. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  298. return x
  299. def xyxy2xywh(x):
  300. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  301. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  302. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  303. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  304. y[:, 2] = x[:, 2] - x[:, 0] # width
  305. y[:, 3] = x[:, 3] - x[:, 1] # height
  306. return y
  307. def xywh2xyxy(x):
  308. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  309. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  310. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  311. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  312. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  313. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  314. return y
  315. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  316. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  317. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  318. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  319. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  320. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  321. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  322. return y
  323. def xyxy2xywhn(x, w=640, h=640):
  324. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
  325. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  326. y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
  327. y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
  328. y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
  329. y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
  330. return y
  331. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  332. # Convert normalized segments into pixel segments, shape (n,2)
  333. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  334. y[:, 0] = w * x[:, 0] + padw # top left x
  335. y[:, 1] = h * x[:, 1] + padh # top left y
  336. return y
  337. def segment2box(segment, width=640, height=640):
  338. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  339. x, y = segment.T # segment xy
  340. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  341. x, y, = x[inside], y[inside]
  342. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  343. def segments2boxes(segments):
  344. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  345. boxes = []
  346. for s in segments:
  347. x, y = s.T # segment xy
  348. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  349. return xyxy2xywh(np.array(boxes)) # cls, xywh
  350. def resample_segments(segments, n=1000):
  351. # Up-sample an (n,2) segment
  352. for i, s in enumerate(segments):
  353. x = np.linspace(0, len(s) - 1, n)
  354. xp = np.arange(len(s))
  355. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  356. return segments
  357. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  358. # Rescale coords (xyxy) from img1_shape to img0_shape
  359. if ratio_pad is None: # calculate from img0_shape
  360. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  361. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  362. else:
  363. gain = ratio_pad[0][0]
  364. pad = ratio_pad[1]
  365. coords[:, [0, 2]] -= pad[0] # x padding
  366. coords[:, [1, 3]] -= pad[1] # y padding
  367. coords[:, :4] /= gain
  368. clip_coords(coords, img0_shape)
  369. return coords
  370. def clip_coords(boxes, img_shape):
  371. # Clip bounding xyxy bounding boxes to image shape (height, width)
  372. boxes[:, 0].clamp_(0, img_shape[1]) # x1
  373. boxes[:, 1].clamp_(0, img_shape[0]) # y1
  374. boxes[:, 2].clamp_(0, img_shape[1]) # x2
  375. boxes[:, 3].clamp_(0, img_shape[0]) # y2
def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    """Return the IoU (or GIoU / DIoU / CIoU) of box1 against box2.

    Args:
        box1: a box given as 4 coordinates (box1[0..3]).
        box2: n boxes as an (n, 4) tensor; it is transposed so box2[k] is the k-th coordinate of all boxes.
        x1y1x2y2: True for corner form (x1, y1, x2, y2); False for center form (x, y, w, h).
        GIoU, DIoU, CIoU: select a generalized variant. If several are set, DIoU takes precedence over
            CIoU, and either of them over GIoU (see the branch order below).
        eps: small constant added to heights and denominators to avoid division by zero.

    Returns:
        Tensor of IoU-style values, one per box of box2 (broadcast against box1's coordinates).
    """
    box2 = box2.T

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area (clamp(0) makes non-overlapping boxes contribute 0)
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area (eps on heights keeps the CIoU aspect-ratio term w / h finite)
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if GIoU or DIoU or CIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                # v: aspect-ratio consistency term
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                # alpha is treated as a constant weight (no gradient flows through it)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + eps  # convex area
            return iou - (c_area - union) / c_area  # GIoU
    else:
        return iou  # IoU
  415. def box_iou(box1, box2):
  416. # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
  417. """
  418. Return intersection-over-union (Jaccard index) of boxes.
  419. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
  420. Arguments:
  421. box1 (Tensor[N, 4])
  422. box2 (Tensor[M, 4])
  423. Returns:
  424. iou (Tensor[N, M]): the NxM matrix containing the pairwise
  425. IoU values for every element in boxes1 and boxes2
  426. """
  427. def box_area(box):
  428. # box = 4xn
  429. return (box[2] - box[0]) * (box[3] - box[1])
  430. area1 = box_area(box1.T)
  431. area2 = box_area(box2.T)
  432. # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
  433. inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
  434. return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
  435. def wh_iou(wh1, wh2):
  436. # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
  437. wh1 = wh1[:, None] # [N,1,2]
  438. wh2 = wh2[None] # [1,M,2]
  439. inter = torch.min(wh1, wh2).prod(2) # [N,M]
  440. return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Args:
        prediction: raw predictions indexed as prediction[image, box, field], with fields
            [x, y, w, h, obj_conf, cls_conf_0, ..., cls_conf_{nc-1}] (nc = prediction.shape[2] - 5).
        conf_thres: candidates need obj_conf > conf_thres, and then obj_conf * cls_conf > conf_thres.
        iou_thres: IoU threshold passed to torchvision.ops.nms().
        classes: optional iterable of class indices to keep (others are dropped before NMS).
        agnostic: if True, NMS runs across all classes; otherwise boxes are offset per class so that
            only same-class boxes suppress each other.
        multi_label: allow one output row per (box, class) above threshold (only when nc > 1).
        labels: optional per-image apriori labels [cls, box(4)] appended as conf-1.0 candidates
            (autolabelling). Their box columns go through the same xywh->xyxy conversion as predictions.
        max_det: maximum number of detections kept per image.

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
         If time_limit is exceeded, remaining images keep an empty (0, 6) tensor.
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        # Shifting each box by cls * max_wh separates classes spatially, so a single nms() call
        # only suppresses overlaps within the same class (offset is 0 when agnostic)
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
  517. def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
  518. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  519. x = torch.load(f, map_location=torch.device('cpu'))
  520. if x.get('ema'):
  521. x['model'] = x['ema'] # replace model with ema
  522. for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates': # keys
  523. x[k] = None
  524. x['epoch'] = -1
  525. x['model'].half() # to FP16
  526. for p in x['model'].parameters():
  527. p.requires_grad = False
  528. torch.save(x, s or f)
  529. mb = os.path.getsize(s or f) / 1E6 # filesize
  530. print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
    """Record one hyperparameter-evolution result and save the best hyperparameters found so far.

    The steps are:
      - append `results` followed by the values of `hyp` as one row of evolve.txt;
      - de-duplicate the rows and sort them by fitness, best first;
      - rewrite evolve.txt;
      - write the best row's hyperparameters to `yaml_file`.

    If `bucket` is set, evolve.txt is first pulled from gs://{bucket} when the remote copy is larger than
    the local one. At the end, evolve.txt and `yaml_file` are uploaded there (both via the gsutil CLI).

    Args:
        hyp: dict of hyperparameters; updated in place with the best row's values.
        results: tuple of metrics. The column slicing below assumes exactly 7 values
            (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3).
        yaml_file: output path for the best hyperparameters.
        bucket: optional Google Cloud Storage bucket name.
    """
    a = '%10s' * len(hyp) % tuple(hyp.keys())  # hyperparam keys
    b = '%10.3g' * len(hyp) % tuple(hyp.values())  # hyperparam values
    c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
    print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))

    if bucket:
        url = 'gs://%s/evolve.txt' % bucket
        if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
            os.system('gsutil cp %s .' % url)  # download evolve.txt if larger than local

    with open('evolve.txt', 'a') as f:  # append result
        f.write(c + b + '\n')
    x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # load unique rows
    x = x[np.argsort(-fitness(x))]  # sort (descending fitness, best row first)
    np.savetxt('evolve.txt', x, '%10.3g')  # save sort by fitness

    # Save yaml
    # Row layout is [7 result columns, hyp values...], so hyperparameter i lives in column i + 7
    for i, k in enumerate(hyp.keys()):
        hyp[k] = float(x[0, i + 7])
    with open(yaml_file, 'w') as f:
        results = tuple(x[0, :7])
        c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
        f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
        yaml.safe_dump(hyp, f, sort_keys=False)

    if bucket:
        os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket))  # upload
def apply_classifier(x, model, img, im0):
    """Filter detections with a second-stage classifier.

    For each image with detections, every box is squared, enlarged (side * 1.3 + 30 px) and rescaled
    from the network input size (img.shape[2:]) to the original image im0[i]. The box is then cropped,
    resized to 224x224 and classified by `model`. Detections whose detector class (column 5) differs
    from the classifier's argmax class are removed.

    Args:
        x: list of per-image detection tensors with class index in column 5 (presumably the
            [xyxy, conf, cls] output of non_max_suppression); None/empty entries are skipped.
            Modified in place and returned.
        model: classifier called on an (N, 3, 224, 224) float tensor scaled to [0, 1]; argmax over
            dim 1 is taken as the predicted class.
        img: batched network input; only img.shape[2:] (its spatial size) is used.
        im0: original image (NumPy array) or list of images, indexed per detection list.
            NOTE(review): assumed BGR, per the inline comments.
    """
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                # cv2.imwrite('test%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW (3x224x224)
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255.0  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x
  583. def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):
  584. # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
  585. xyxy = torch.tensor(xyxy).view(-1, 4)
  586. b = xyxy2xywh(xyxy) # boxes
  587. if square:
  588. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
  589. b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
  590. xyxy = xywh2xyxy(b).long()
  591. clip_coords(xyxy, im.shape)
  592. crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
  593. if save:
  594. cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop)
  595. return crop
  596. def increment_path(path, exist_ok=False, sep='', mkdir=False):
  597. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  598. path = Path(path) # os-agnostic
  599. if path.exists() and not exist_ok:
  600. suffix = path.suffix
  601. path = path.with_suffix('')
  602. dirs = glob.glob(f"{path}{sep}*") # similar paths
  603. matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
  604. i = [int(m.groups()[0]) for m in matches if m] # indices
  605. n = max(i) + 1 if i else 2 # increment number
  606. path = Path(f"{path}{sep}{n}{suffix}") # update path
  607. dir = path if path.suffix == '' else path.parent # directory
  608. if not dir.exists() and mkdir:
  609. dir.mkdir(parents=True, exist_ok=True) # make directory
  610. return path