You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

700 lines
29KB

  1. # YOLOv5 general utils
  2. import contextlib
  3. import glob
  4. import logging
  5. import os
  6. import platform
  7. import random
  8. import re
  9. import signal
  10. import time
  11. import urllib
  12. from itertools import repeat
  13. from multiprocessing.pool import ThreadPool
  14. from pathlib import Path
  15. from subprocess import check_output
  16. import cv2
  17. import math
  18. import numpy as np
  19. import pandas as pd
  20. import pkg_resources as pkg
  21. import torch
  22. import torchvision
  23. import yaml
  24. from utils.downloads import gsutil_getsize
  25. from utils.metrics import box_iou, fitness
  26. from utils.torch_utils import init_torch_seeds
  27. # Settings
  28. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  29. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  30. pd.options.display.max_columns = 10
  31. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  32. os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
  33. class timeout(contextlib.ContextDecorator):
  34. # Usage: @timeout(seconds) decorator or 'with timeout(seconds):' context manager
  35. def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
  36. self.seconds = int(seconds)
  37. self.timeout_message = timeout_msg
  38. self.suppress = bool(suppress_timeout_errors)
  39. def _timeout_handler(self, signum, frame):
  40. raise TimeoutError(self.timeout_message)
  41. def __enter__(self):
  42. signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
  43. signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
  44. def __exit__(self, exc_type, exc_val, exc_tb):
  45. signal.alarm(0) # Cancel SIGALRM if it's scheduled
  46. if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
  47. return True
  48. def set_logging(rank=-1, verbose=True):
  49. logging.basicConfig(
  50. format="%(message)s",
  51. level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)
  52. def init_seeds(seed=0):
  53. # Initialize random number generator (RNG) seeds
  54. random.seed(seed)
  55. np.random.seed(seed)
  56. init_torch_seeds(seed)
  57. def get_latest_run(search_dir='.'):
  58. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  59. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  60. return max(last_list, key=os.path.getctime) if last_list else ''
  61. def is_docker():
  62. # Is environment a Docker container?
  63. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  64. def is_colab():
  65. # Is environment a Google Colab instance?
  66. try:
  67. import google.colab
  68. return True
  69. except Exception as e:
  70. return False
  71. def is_pip():
  72. # Is file in a pip package?
  73. return 'site-packages' in Path(__file__).absolute().parts
  74. def emojis(str=''):
  75. # Return platform-dependent emoji-safe version of string
  76. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  77. def file_size(file):
  78. # Return file size in MB
  79. return Path(file).stat().st_size / 1e6
  80. def check_online():
  81. # Check internet connectivity
  82. import socket
  83. try:
  84. socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
  85. return True
  86. except OSError:
  87. return False
  88. def check_git_status(err_msg=', for updates see https://github.com/ultralytics/yolov5'):
  89. # Recommend 'git pull' if code is out of date
  90. print(colorstr('github: '), end='')
  91. try:
  92. assert Path('.git').exists(), 'skipping check (not a git repository)'
  93. assert not is_docker(), 'skipping check (Docker image)'
  94. assert check_online(), 'skipping check (offline)'
  95. cmd = 'git fetch && git config --get remote.origin.url'
  96. url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch
  97. branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  98. n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  99. if n > 0:
  100. s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
  101. f"Use 'git pull' to update or 'git clone {url}' to download latest."
  102. else:
  103. s = f'up to date with {url} ✅'
  104. print(emojis(s)) # emoji-safe
  105. except Exception as e:
  106. print(f'{e}{err_msg}')
  107. def check_python(minimum='3.6.2'):
  108. # Check current python version vs. required python version
  109. check_version(platform.python_version(), minimum, name='Python ')
  110. def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False):
  111. # Check version vs. required version
  112. current, minimum = (pkg.parse_version(x) for x in (current, minimum))
  113. result = (current == minimum) if pinned else (current >= minimum)
  114. assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
  115. def check_requirements(requirements='requirements.txt', exclude=()):
  116. # Check installed dependencies meet requirements (pass *.txt file or list of packages)
  117. prefix = colorstr('red', 'bold', 'requirements:')
  118. check_python() # check python version
  119. if isinstance(requirements, (str, Path)): # requirements.txt file
  120. file = Path(requirements)
  121. if not file.exists():
  122. print(f"{prefix} {file.resolve()} not found, check failed.")
  123. return
  124. requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
  125. else: # list or tuple of packages
  126. requirements = [x for x in requirements if x not in exclude]
  127. n = 0 # number of packages updates
  128. for r in requirements:
  129. try:
  130. pkg.require(r)
  131. except Exception as e: # DistributionNotFound or VersionConflict if requirements not met
  132. print(f"{prefix} {r} not found and is required by YOLOv5, attempting auto-update...")
  133. try:
  134. assert check_online(), f"'pip install {r}' skipped (offline)"
  135. print(check_output(f"pip install '{r}'", shell=True).decode())
  136. n += 1
  137. except Exception as e:
  138. print(f'{prefix} {e}')
  139. if n: # if packages updated
  140. source = file.resolve() if 'file' in locals() else requirements
  141. s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
  142. f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
  143. print(emojis(s)) # emoji-safe
  144. def check_img_size(img_size, s=32, floor=0):
  145. # Verify img_size is a multiple of stride s
  146. new_size = max(make_divisible(img_size, int(s)), floor) # ceil gs-multiple
  147. if new_size != img_size:
  148. print(f'WARNING: --img-size {img_size} must be multiple of max stride {s}, updating to {new_size}')
  149. return new_size
  150. def check_imshow():
  151. # Check if environment supports image displays
  152. try:
  153. assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
  154. assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
  155. cv2.imshow('test', np.zeros((1, 1, 3)))
  156. cv2.waitKey(1)
  157. cv2.destroyAllWindows()
  158. cv2.waitKey(1)
  159. return True
  160. except Exception as e:
  161. print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
  162. return False
  163. def check_file(file):
  164. # Search/download file (if necessary) and return path
  165. file = str(file) # convert to str()
  166. if Path(file).is_file() or file == '': # exists
  167. return file
  168. elif file.startswith(('http:/', 'https:/')): # download
  169. url = str(Path(file)).replace(':/', '://') # Pathlib turns :// -> :/
  170. file = Path(urllib.parse.unquote(file)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
  171. print(f'Downloading {url} to {file}...')
  172. torch.hub.download_url_to_file(url, file)
  173. assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
  174. return file
  175. else: # search
  176. files = glob.glob('./**/' + file, recursive=True) # find file
  177. assert len(files), f'File not found: {file}' # assert file was found
  178. assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
  179. return files[0] # return file
  180. def check_dataset(data, autodownload=True):
  181. # Download and/or unzip dataset if not found locally
  182. # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip
  183. # Download (optional)
  184. extract_dir = ''
  185. if isinstance(data, (str, Path)) and str(data).endswith('.zip'): # i.e. gs://bucket/dir/coco128.zip
  186. download(data, dir='../datasets', unzip=True, delete=False, curl=False, threads=1)
  187. data = next((Path('../datasets') / Path(data).stem).rglob('*.yaml'))
  188. extract_dir, autodownload = data.parent, False
  189. # Read yaml (optional)
  190. if isinstance(data, (str, Path)):
  191. with open(data, encoding='ascii', errors='ignore') as f:
  192. data = yaml.safe_load(f) # dictionary
  193. # Parse yaml
  194. path = extract_dir or Path(data.get('path') or '') # optional 'path' default to '.'
  195. for k in 'train', 'val', 'test':
  196. if data.get(k): # prepend path
  197. data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]
  198. assert 'nc' in data, "Dataset 'nc' key missing."
  199. if 'names' not in data:
  200. data['names'] = [f'class{i}' for i in range(data['nc'])] # assign class names if missing
  201. train, val, test, s = [data.get(x) for x in ('train', 'val', 'test', 'download')]
  202. if val:
  203. val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
  204. if not all(x.exists() for x in val):
  205. print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
  206. if s and autodownload: # download script
  207. if s.startswith('http') and s.endswith('.zip'): # URL
  208. f = Path(s).name # filename
  209. print(f'Downloading {s} ...')
  210. torch.hub.download_url_to_file(s, f)
  211. root = path.parent if 'path' in data else '..' # unzip directory i.e. '../'
  212. Path(root).mkdir(parents=True, exist_ok=True) # create root
  213. r = os.system(f'unzip -q {f} -d {root} && rm {f}') # unzip
  214. elif s.startswith('bash '): # bash script
  215. print(f'Running {s} ...')
  216. r = os.system(s)
  217. else: # python script
  218. r = exec(s, {'yaml': data}) # return None
  219. print('Dataset autodownload %s\n' % ('success' if r in (0, None) else 'failure')) # print result
  220. else:
  221. raise Exception('Dataset not found.')
  222. return data # dictionary
  223. def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
  224. # Multi-threaded file download and unzip function, used in data.yaml for autodownload
  225. def download_one(url, dir):
  226. # Download 1 file
  227. f = dir / Path(url).name # filename
  228. if Path(url).is_file(): # exists in current path
  229. Path(url).rename(f) # move to dir
  230. elif not f.exists():
  231. print(f'Downloading {url} to {f}...')
  232. if curl:
  233. os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -") # curl download, retry and resume on fail
  234. else:
  235. torch.hub.download_url_to_file(url, f, progress=True) # torch download
  236. if unzip and f.suffix in ('.zip', '.gz'):
  237. print(f'Unzipping {f}...')
  238. if f.suffix == '.zip':
  239. s = f'unzip -qo {f} -d {dir}' # unzip -quiet -overwrite
  240. elif f.suffix == '.gz':
  241. s = f'tar xfz {f} --directory {f.parent}' # unzip
  242. if delete: # delete zip file after unzip
  243. s += f' && rm {f}'
  244. os.system(s)
  245. dir = Path(dir)
  246. dir.mkdir(parents=True, exist_ok=True) # make directory
  247. if threads > 1:
  248. pool = ThreadPool(threads)
  249. pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
  250. pool.close()
  251. pool.join()
  252. else:
  253. for u in [url] if isinstance(url, (str, Path)) else url:
  254. download_one(u, dir)
  255. def make_divisible(x, divisor):
  256. # Returns x evenly divisible by divisor
  257. return math.ceil(x / divisor) * divisor
  258. def clean_str(s):
  259. # Cleans a string by replacing special characters with underscore _
  260. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  261. def one_cycle(y1=0.0, y2=1.0, steps=100):
  262. # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
  263. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  264. def colorstr(*input):
  265. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  266. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  267. colors = {'black': '\033[30m', # basic colors
  268. 'red': '\033[31m',
  269. 'green': '\033[32m',
  270. 'yellow': '\033[33m',
  271. 'blue': '\033[34m',
  272. 'magenta': '\033[35m',
  273. 'cyan': '\033[36m',
  274. 'white': '\033[37m',
  275. 'bright_black': '\033[90m', # bright colors
  276. 'bright_red': '\033[91m',
  277. 'bright_green': '\033[92m',
  278. 'bright_yellow': '\033[93m',
  279. 'bright_blue': '\033[94m',
  280. 'bright_magenta': '\033[95m',
  281. 'bright_cyan': '\033[96m',
  282. 'bright_white': '\033[97m',
  283. 'end': '\033[0m', # misc
  284. 'bold': '\033[1m',
  285. 'underline': '\033[4m'}
  286. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  287. def labels_to_class_weights(labels, nc=80):
  288. # Get class weights (inverse frequency) from training labels
  289. if labels[0] is None: # no labels loaded
  290. return torch.Tensor()
  291. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  292. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  293. weights = np.bincount(classes, minlength=nc) # occurrences per class
  294. # Prepend gridpoint count (for uCE training)
  295. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  296. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  297. weights[weights == 0] = 1 # replace empty bins with 1
  298. weights = 1 / weights # number of targets per class
  299. weights /= weights.sum() # normalize
  300. return torch.from_numpy(weights)
  301. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  302. # Produces image weights based on class_weights and image contents
  303. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  304. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  305. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  306. return image_weights
  307. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  308. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  309. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  310. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  311. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  312. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  313. x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  314. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  315. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  316. return x
  317. def xyxy2xywh(x):
  318. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  319. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  320. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  321. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  322. y[:, 2] = x[:, 2] - x[:, 0] # width
  323. y[:, 3] = x[:, 3] - x[:, 1] # height
  324. return y
  325. def xywh2xyxy(x):
  326. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  327. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  328. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  329. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  330. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  331. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  332. return y
  333. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  334. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  335. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  336. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  337. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  338. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  339. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  340. return y
  341. def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
  342. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
  343. if clip:
  344. clip_coords(x, (h - eps, w - eps)) # warning: inplace clip
  345. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  346. y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
  347. y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
  348. y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
  349. y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
  350. return y
  351. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  352. # Convert normalized segments into pixel segments, shape (n,2)
  353. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  354. y[:, 0] = w * x[:, 0] + padw # top left x
  355. y[:, 1] = h * x[:, 1] + padh # top left y
  356. return y
  357. def segment2box(segment, width=640, height=640):
  358. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  359. x, y = segment.T # segment xy
  360. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  361. x, y, = x[inside], y[inside]
  362. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  363. def segments2boxes(segments):
  364. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  365. boxes = []
  366. for s in segments:
  367. x, y = s.T # segment xy
  368. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  369. return xyxy2xywh(np.array(boxes)) # cls, xywh
  370. def resample_segments(segments, n=1000):
  371. # Up-sample an (n,2) segment
  372. for i, s in enumerate(segments):
  373. x = np.linspace(0, len(s) - 1, n)
  374. xp = np.arange(len(s))
  375. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  376. return segments
  377. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  378. # Rescale coords (xyxy) from img1_shape to img0_shape
  379. if ratio_pad is None: # calculate from img0_shape
  380. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  381. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  382. else:
  383. gain = ratio_pad[0][0]
  384. pad = ratio_pad[1]
  385. coords[:, [0, 2]] -= pad[0] # x padding
  386. coords[:, [1, 3]] -= pad[1] # y padding
  387. coords[:, :4] /= gain
  388. clip_coords(coords, img0_shape)
  389. return coords
  390. def clip_coords(boxes, shape):
  391. # Clip bounding xyxy bounding boxes to image shape (height, width)
  392. if isinstance(boxes, torch.Tensor): # faster individually
  393. boxes[:, 0].clamp_(0, shape[1]) # x1
  394. boxes[:, 1].clamp_(0, shape[0]) # y1
  395. boxes[:, 2].clamp_(0, shape[1]) # x2
  396. boxes[:, 3].clamp_(0, shape[0]) # y2
  397. else: # np.array (faster grouped)
  398. boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
  399. boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results.

    Args:
        prediction: tensor indexed [image, box, field], fields [x, y, w, h, obj_conf, cls_conf...];
            nc = prediction.shape[2] - 5 class scores. Boxes are converted with xywh2xyxy.
        conf_thres: candidates need obj_conf > conf_thres; kept detections need
            obj_conf * cls_conf > conf_thres.
        iou_thres: IoU threshold passed to torchvision.ops.nms.
        classes: optional iterable of class indices to keep.
        agnostic: if True, all classes are suppressed together (no per-class box offset).
        multi_label: allow several classes per box (forced off when nc == 1).
        labels: optional per-image tensors of a-priori labels, rows [cls, x, y, w, h], appended
            to the candidates with confidence 1.0 (autolabelling).
        max_det: maximum number of detections kept per image.

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    # NOTE: list repetition means images without detections all share this one empty tensor object
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            # one row per (box, class) pair whose score exceeds conf_thres
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        # Shifting each box by class * max_wh keeps boxes of different classes from overlapping,
        # so one class-agnostic NMS call suppresses per class (offset is 0 when agnostic)
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            # remaining images keep the empty default output
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
  476. def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
  477. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  478. x = torch.load(f, map_location=torch.device('cpu'))
  479. if x.get('ema'):
  480. x['model'] = x['ema'] # replace model with ema
  481. for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates': # keys
  482. x[k] = None
  483. x['epoch'] = -1
  484. x['model'].half() # to FP16
  485. for p in x['model'].parameters():
  486. p.requires_grad = False
  487. torch.save(x, s or f)
  488. mb = os.path.getsize(s or f) / 1E6 # filesize
  489. print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
  490. def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
  491. # Print mutation results to evolve.txt (for use with train.py --evolve)
  492. a = '%10s' * len(hyp) % tuple(hyp.keys()) # hyperparam keys
  493. b = '%10.3g' * len(hyp) % tuple(hyp.values()) # hyperparam values
  494. c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
  495. print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))
  496. if bucket:
  497. url = 'gs://%s/evolve.txt' % bucket
  498. if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
  499. os.system('gsutil cp %s .' % url) # download evolve.txt if larger than local
  500. with open('evolve.txt', 'a') as f: # append result
  501. f.write(c + b + '\n')
  502. x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0) # load unique rows
  503. x = x[np.argsort(-fitness(x))] # sort
  504. np.savetxt('evolve.txt', x, '%10.3g') # save sort by fitness
  505. # Save yaml
  506. for i, k in enumerate(hyp.keys()):
  507. hyp[k] = float(x[0, i + 7])
  508. with open(yaml_file, 'w') as f:
  509. results = tuple(x[0, :7])
  510. c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
  511. f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
  512. yaml.safe_dump(hyp, f, sort_keys=False)
  513. if bucket:
  514. os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket)) # upload
  515. def apply_classifier(x, model, img, im0):
  516. # Apply a second stage classifier to yolo outputs
  517. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  518. for i, d in enumerate(x): # per image
  519. if d is not None and len(d):
  520. d = d.clone()
  521. # Reshape and pad cutouts
  522. b = xyxy2xywh(d[:, :4]) # boxes
  523. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  524. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  525. d[:, :4] = xywh2xyxy(b).long()
  526. # Rescale boxes from img_size to im0 size
  527. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  528. # Classes
  529. pred_cls1 = d[:, 5].long()
  530. ims = []
  531. for j, a in enumerate(d): # per item
  532. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  533. im = cv2.resize(cutout, (224, 224)) # BGR
  534. # cv2.imwrite('example%i.jpg' % j, cutout)
  535. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  536. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  537. im /= 255.0 # 0 - 255 to 0.0 - 1.0
  538. ims.append(im)
  539. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  540. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  541. return x
  542. def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):
  543. # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
  544. xyxy = torch.tensor(xyxy).view(-1, 4)
  545. b = xyxy2xywh(xyxy) # boxes
  546. if square:
  547. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
  548. b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
  549. xyxy = xywh2xyxy(b).long()
  550. clip_coords(xyxy, im.shape)
  551. crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
  552. if save:
  553. cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop)
  554. return crop
  555. def increment_path(path, exist_ok=False, sep='', mkdir=False):
  556. # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
  557. path = Path(path) # os-agnostic
  558. if path.exists() and not exist_ok:
  559. suffix = path.suffix
  560. path = path.with_suffix('')
  561. dirs = glob.glob(f"{path}{sep}*") # similar paths
  562. matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
  563. i = [int(m.groups()[0]) for m in matches if m] # indices
  564. n = max(i) + 1 if i else 2 # increment number
  565. path = Path(f"{path}{sep}{n}{suffix}") # update path
  566. dir = path if path.suffix == '' else path.parent # directory
  567. if not dir.exists() and mkdir:
  568. dir.mkdir(parents=True, exist_ok=True) # make directory
  569. return path