# YOLOv5 general utils

import glob
import logging
import math
import os
import platform
import random
import re
import subprocess
import time
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import pkg_resources as pkg
import torch
import torchvision
import yaml

from utils.google_utils import gsutil_getsize
from utils.metrics import fitness
from utils.torch_utils import init_torch_seeds

# Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads

def set_logging(rank=-1, verbose=True):
    logging.basicConfig(
        format="%(message)s",
        level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)

def init_seeds(seed=0):
    # Initialize random number generator (RNG) seeds
    random.seed(seed)
    np.random.seed(seed)
    init_torch_seeds(seed)

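# Usage sketch (editor's illustration, not part of the original file):
#   init_seeds(0)  # seed python, numpy and torch RNGs for a reproducible run
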
def get_latest_run(search_dir='.'):
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''

def is_docker():
    # Is environment a Docker container?
    return Path('/workspace').exists()  # or Path('/.dockerenv').exists()


def is_colab():
    # Is environment a Google Colab instance?
    try:
        import google.colab
        return True
    except Exception:
        return False

def is_pip():
    # Is file in a pip package?
    return 'site-packages' in Path(__file__).absolute().parts


def emojis(s=''):
    # Return platform-dependent emoji-safe version of string
    return s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s


def file_size(file):
    # Return file size in MB
    return Path(file).stat().st_size / 1e6

def check_online():
    # Check internet connectivity
    import socket
    try:
        socket.create_connection(("1.1.1.1", 443), timeout=5)  # check host accessibility
        return True
    except OSError:
        return False

def check_git_status():
    # Recommend 'git pull' if code is out of date
    print(colorstr('github: '), end='')
    try:
        assert Path('.git').exists(), 'skipping check (not a git repository)'
        assert not is_docker(), 'skipping check (Docker image)'
        assert check_online(), 'skipping check (offline)'

        cmd = 'git fetch && git config --get remote.origin.url'
        url = re.sub(r'\.git$', '', subprocess.check_output(cmd, shell=True).decode().strip())  # github repo url
        branch = subprocess.check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
        n = int(subprocess.check_output(f'git rev-list {branch}..origin/master --count', shell=True))  # commits behind
        if n > 0:
            s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
                f"Use 'git pull' to update or 'git clone {url}' to download latest."
        else:
            s = f'up to date with {url} ✅'
        print(emojis(s))  # emoji-safe
    except Exception as e:
        print(e)

def check_python(minimum='3.7.0', required=True):
    # Check current python version vs. required python version
    current = platform.python_version()
    result = pkg.parse_version(current) >= pkg.parse_version(minimum)
    if required:
        assert result, f'Python {minimum} required by YOLOv5, but Python {current} is currently installed'
    return result

def check_requirements(requirements='requirements.txt', exclude=()):
    # Check installed dependencies meet requirements (pass *.txt file or list of packages)
    prefix = colorstr('red', 'bold', 'requirements:')
    check_python()  # check python version
    if isinstance(requirements, (str, Path)):  # requirements.txt file
        file = Path(requirements)
        if not file.exists():
            print(f"{prefix} {file.resolve()} not found, check failed.")
            return
        requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
    else:  # list or tuple of packages
        requirements = [x for x in requirements if x not in exclude]

    n = 0  # number of package updates
    for r in requirements:
        try:
            pkg.require(r)
        except Exception:  # DistributionNotFound or VersionConflict if requirements not met
            n += 1
            print(f"{prefix} {r} not found and is required by YOLOv5, attempting auto-update...")
            try:
                print(subprocess.check_output(f"pip install '{r}'", shell=True).decode())
            except Exception as e:
                print(f'{prefix} {e}')

    if n:  # if packages updated
        source = file.resolve() if 'file' in locals() else requirements
        s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
        print(emojis(s))  # emoji-safe

def check_img_size(img_size, s=32):
    # Verify img_size is a multiple of stride s
    new_size = make_divisible(img_size, int(s))  # ceil gs-multiple
    if new_size != img_size:
        print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
    return new_size

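# Example (editor's illustration, not in the original source):
#   check_img_size(641, s=32)  # warns and returns 672, the next multiple of 32
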
def check_imshow():
    # Check if environment supports image displays
    try:
        assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
        assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
        return False

def check_file(file):
    # Search/download file (if necessary) and return path
    file = str(file)  # convert to str()
    if Path(file).is_file() or file == '':  # exists
        return file
    elif file.startswith(('http://', 'https://')):  # download
        url, file = file, Path(file).name
        print(f'Downloading {url} to {file}...')
        torch.hub.download_url_to_file(url, file)
        assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}'  # check
        return file
    else:  # search
        files = glob.glob('./**/' + file, recursive=True)  # find file
        assert len(files), f'File not found: {file}'  # assert file was found
        assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique
        return files[0]  # return file

def check_dataset(data):
    # Download dataset if not found locally
    val, s = data.get('val'), data.get('download')
    if val and len(val):
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and len(s):  # download script
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    print(f'Downloading {s} ...')
                    torch.hub.download_url_to_file(s, f)
                    r = os.system(f'unzip -q {f} -d ../ && rm {f}')  # unzip
                elif s.startswith('bash '):  # bash script
                    print(f'Running {s} ...')
                    r = os.system(s)
                else:  # python script
                    r = exec(s)  # returns None
                print('Dataset autodownload %s\n' % ('success' if r in (0, None) else 'failure'))  # print result
            else:
                raise Exception('Dataset not found.')

def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
    # Multi-threaded file download and unzip function
    def download_one(url, dir):
        # Download 1 file
        f = dir / Path(url).name  # filename
        if not f.exists():
            print(f'Downloading {url} to {f}...')
            if curl:
                os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -")  # curl download, retry and resume on fail
            else:
                torch.hub.download_url_to_file(url, f, progress=True)  # torch download
        if unzip and f.suffix in ('.zip', '.gz'):
            print(f'Unzipping {f}...')
            if f.suffix == '.zip':
                s = f'unzip -qo {f} -d {dir}'  # unzip -quiet -overwrite
            elif f.suffix == '.gz':
                s = f'tar xfz {f} --directory {f.parent}'  # unzip
            if delete:  # delete archive after unzip
                s += f' && rm {f}'
            os.system(s)

    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    if threads > 1:
        pool = ThreadPool(threads)
        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multi-threaded
        pool.close()
        pool.join()
    else:
        for u in [url] if isinstance(url, str) else url:
            download_one(u, dir)

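# Usage sketch (editor's illustration; the URL shown is an assumption, not taken from this file):
#   download('https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip',
#            dir='../datasets', unzip=True, threads=1)
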
def make_divisible(x, divisor):
    # Returns x rounded up to the nearest multiple of divisor
    return math.ceil(x / divisor) * divisor

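# Example (editor's illustration, not in the original source):
#   >>> make_divisible(97, 32)
#   128
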
def clean_str(s):
    # Cleans a string by replacing special characters with underscore _
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)

def one_cycle(y1=0.0, y2=1.0, steps=100):
    # lambda function for sinusoidal ramp from y1 to y2
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1

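# Usage sketch (editor's illustration; `optimizer`, `hyp` and `epochs` are assumed to exist):
#   lf = one_cycle(1, hyp['lrf'], epochs)  # cosine ramp 1.0 -> hyp['lrf'] over `epochs` steps
#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
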
def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {'black': '\033[30m',  # basic colors
              'red': '\033[31m',
              'green': '\033[32m',
              'yellow': '\033[33m',
              'blue': '\033[34m',
              'magenta': '\033[35m',
              'cyan': '\033[36m',
              'white': '\033[37m',
              'bright_black': '\033[90m',  # bright colors
              'bright_red': '\033[91m',
              'bright_green': '\033[92m',
              'bright_yellow': '\033[93m',
              'bright_blue': '\033[94m',
              'bright_magenta': '\033[95m',
              'bright_cyan': '\033[96m',
              'bright_white': '\033[97m',
              'end': '\033[0m',  # misc
              'bold': '\033[1m',
              'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']

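# Example (editor's illustration, not in the original source):
#   print(colorstr('red', 'bold', 'error:'), 'something failed')  # red bold prefix, reset afterwards
#   print(colorstr('hello'))  # a single argument defaults to blue + bold
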
def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(int)  # labels = [class xywh]
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # number of targets per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)

def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class_weights and image contents
    class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
    return image_weights

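# Usage sketch (editor's illustration; `dataset`, `nc` and `cw` names are assumptions
# mirroring train.py-style weighted image sampling):
#   iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # one weight per image
#   dataset.indices = random.choices(range(len(dataset.labels)), weights=iw, k=len(dataset.labels))
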
def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
         64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
    return x

def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y

def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y

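# Example (editor's illustration, not in the original source): the two conversions are inverses.
#   >>> xywh2xyxy(np.array([[10., 10., 4., 6.]]))
#   array([[ 8.,  7., 12., 13.]])
#   >>> xyxy2xywh(np.array([[8., 7., 12., 13.]]))
#   array([[10., 10.,  4.,  6.]])
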
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # top left x
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh  # top left y
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw  # bottom right x
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh  # bottom right y
    return y

def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    # Convert normalized segments into pixel segments, shape (n,2)
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * x[:, 0] + padw  # top left x
    y[:, 1] = h * x[:, 1] + padh  # top left y
    return y

def segment2box(segment, width=640, height=640):
    # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4))  # xyxy

def segments2boxes(segments):
    # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
    return xyxy2xywh(np.array(boxes))  # cls, xywh

def resample_segments(segments, n=1000):
    # Up-sample an (n,2) segment
    for i, s in enumerate(segments):
        x = np.linspace(0, len(s) - 1, n)
        xp = np.arange(len(s))
        segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T  # segment xy
    return segments

def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords

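# Usage sketch (editor's illustration; `img` is the letterboxed network input and `im0` the
# original frame, names assumed from typical detect-style code):
#   det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
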
def clip_coords(boxes, img_shape):
    # Clip xyxy bounding boxes to image shape (height, width)
    boxes[:, 0].clamp_(0, img_shape[1])  # x1
    boxes[:, 1].clamp_(0, img_shape[0])  # y1
    boxes[:, 2].clamp_(0, img_shape[1])  # x2
    boxes[:, 3].clamp_(0, img_shape[0])  # y2

def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
    box2 = box2.T

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union area
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if GIoU or DIoU or CIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + eps  # convex area
            return iou - (c_area - union) / c_area  # GIoU
    else:
        return iou  # IoU

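# Example (editor's illustration, not in the original source):
#   >>> bbox_iou(torch.tensor([0., 0., 2., 2.]), torch.tensor([[1., 1., 3., 3.]]))
#   tensor([0.1429])
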
def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """

    def box_area(box):
        # box = 4xn
        return (box[2] - box[0]) * (box[3] - box[1])

    area1 = box_area(box1.T)
    area2 = box_area(box2.T)

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)

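# Example (editor's illustration, not in the original source):
#   >>> box_iou(torch.tensor([[0., 0., 2., 2.]]),
#   ...         torch.tensor([[1., 1., 3., 3.], [0., 0., 2., 2.]]))
#   tensor([[0.1429, 1.0000]])
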
def wh_iou(wh1, wh2):
    # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
    wh1 = wh1[:, None]  # [N,1,2]
    wh2 = wh2[None]  # [1,M,2]
    inter = torch.min(wh1, wh2).prod(2)  # [N,M]
    return inter / (wh1.prod(2) + wh2.prod(2) - inter)  # iou = inter / (area1 + area2 - inter)

def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
        list of detections, one (n,6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output

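# Usage sketch (editor's illustration; `model` and `img` are assumed):
#   pred = model(img)[0]                                   # raw (batch, n, nc + 5) inference output
#   pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
#   for det in pred:                                       # one (n,6) [xyxy, conf, cls] tensor per image
#       ...
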
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")

def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
    # Print mutation results to evolve.txt (for use with train.py --evolve)
    a = '%10s' * len(hyp) % tuple(hyp.keys())  # hyperparam keys
    b = '%10.3g' * len(hyp) % tuple(hyp.values())  # hyperparam values
    c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
    print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))

    if bucket:
        url = 'gs://%s/evolve.txt' % bucket
        if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
            os.system('gsutil cp %s .' % url)  # download evolve.txt if larger than local

    with open('evolve.txt', 'a') as f:  # append result
        f.write(c + b + '\n')
    x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # load unique rows
    x = x[np.argsort(-fitness(x))]  # sort
    np.savetxt('evolve.txt', x, '%10.3g')  # save sorted by fitness

    # Save yaml
    for i, k in enumerate(hyp.keys()):
        hyp[k] = float(x[0, i + 7])
    with open(yaml_file, 'w') as f:
        results = tuple(x[0, :7])
        c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
        f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
        yaml.safe_dump(hyp, f, sort_keys=False)

    if bucket:
        os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket))  # upload

def apply_classifier(x, model, img, im0):
    # Apply a second stage classifier to yolo outputs
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                # cv2.imwrite('test%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW (3x224x224)
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255.0  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x

def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):
    # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
    xyxy = torch.tensor(xyxy).view(-1, 4)
    b = xyxy2xywh(xyxy)  # boxes
    if square:
        b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # attempt rectangle to square
    b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
    xyxy = xywh2xyxy(b).long()
    clip_coords(xyxy, im.shape)
    crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
    if save:
        cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop)
    return crop

def increment_path(path, exist_ok=False, sep='', mkdir=False):
    # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        suffix = path.suffix
        path = path.with_suffix('')
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
        matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]  # indices
        n = max(i) + 1 if i else 2  # increment number
        path = Path(f"{path}{sep}{n}{suffix}")  # update path
    dir = path if path.suffix == '' else path.parent  # directory
    if not dir.exists() and mkdir:
        dir.mkdir(parents=True, exist_ok=True)  # make directory
    return path

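# Example (editor's illustration, not in the original source):
#   increment_path('runs/exp')                       # -> runs/exp2 if runs/exp exists, else runs/exp
#   increment_path('runs/exp', sep='_', mkdir=True)  # -> runs/exp_2, creating the directory
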