# YOLOv5 general utils

import contextlib
import glob
import logging
import math
import os
import platform
import random
import re
import signal
import time
import urllib
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from subprocess import check_output

import cv2
import numpy as np
import pandas as pd
import pkg_resources as pkg
import torch
import torchvision
import yaml

from utils.google_utils import gsutil_getsize
from utils.metrics import box_iou, fitness
from utils.torch_utils import init_torch_seeds

# Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads

class timeout(contextlib.ContextDecorator):
    # Usage: @timeout(seconds) decorator or 'with timeout(seconds):' context manager
    def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
        self.seconds = int(seconds)
        self.timeout_message = timeout_msg
        self.suppress = bool(suppress_timeout_errors)

    def _timeout_handler(self, signum, frame):
        raise TimeoutError(self.timeout_message)

    def __enter__(self):
        signal.signal(signal.SIGALRM, self._timeout_handler)  # set handler for SIGALRM
        signal.alarm(self.seconds)  # start countdown for SIGALRM to be raised

    def __exit__(self, exc_type, exc_val, exc_tb):
        signal.alarm(0)  # cancel SIGALRM if it's scheduled
        if self.suppress and exc_type is TimeoutError:  # suppress TimeoutError
            return True

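# Illustrative usage sketch (not in the original file): SIGALRM-based, so POSIX-only
# and usable only from the main thread:
#   with timeout(3, timeout_msg='operation timed out'):
#       check_git_status()  # TimeoutError raised after 3s is suppressed by default
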
def set_logging(rank=-1, verbose=True):
    logging.basicConfig(
        format="%(message)s",
        level=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)


def init_seeds(seed=0):
    # Initialize random number generator (RNG) seeds
    random.seed(seed)
    np.random.seed(seed)
    init_torch_seeds(seed)

def get_latest_run(search_dir='.'):
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''


def is_docker():
    # Is environment a Docker container?
    return Path('/workspace').exists()  # or Path('/.dockerenv').exists()


def is_colab():
    # Is environment a Google Colab instance?
    try:
        import google.colab
        return True
    except ImportError:
        return False


def is_pip():
    # Is file in a pip package?
    return 'site-packages' in Path(__file__).absolute().parts


def emojis(str=''):
    # Return platform-dependent emoji-safe version of string
    return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str


def file_size(file):
    # Return file size in MB
    return Path(file).stat().st_size / 1e6

def check_online():
    # Check internet connectivity
    import socket
    try:
        socket.create_connection(("1.1.1.1", 443), 5)  # check host accessibility with a 5s timeout
        return True
    except OSError:
        return False

def check_git_status(err_msg=', for updates see https://github.com/ultralytics/yolov5'):
    # Recommend 'git pull' if code is out of date
    print(colorstr('github: '), end='')
    try:
        assert Path('.git').exists(), 'skipping check (not a git repository)'
        assert not is_docker(), 'skipping check (Docker image)'
        assert check_online(), 'skipping check (offline)'

        cmd = 'git fetch && git config --get remote.origin.url'
        url = re.sub(r'\.git$', '', check_output(cmd, shell=True, timeout=5).decode().strip())  # git fetch, drop .git suffix
        branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
        n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True))  # commits behind
        if n > 0:
            s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
                f"Use 'git pull' to update or 'git clone {url}' to download latest."
        else:
            s = f'up to date with {url} ✅'
        print(emojis(s))  # emoji-safe
    except Exception as e:
        print(f'{e}{err_msg}')

def check_python(minimum='3.6.2'):
    # Check current python version vs. required python version
    check_version(platform.python_version(), minimum, name='Python ')


def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False):
    # Check version vs. required version
    current, minimum = (pkg.parse_version(x) for x in (current, minimum))
    result = (current == minimum) if pinned else (current >= minimum)
    assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'

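# Illustrative example (not in the original file): gating a newer API on a minimum
# torch version:
#   check_version(torch.__version__, '1.7.0', name='torch ')  # AssertionError if torch < 1.7.0
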
def check_requirements(requirements='requirements.txt', exclude=()):
    # Check installed dependencies meet requirements (pass *.txt file or list of packages)
    prefix = colorstr('red', 'bold', 'requirements:')
    check_python()  # check python version
    if isinstance(requirements, (str, Path)):  # requirements.txt file
        file = Path(requirements)
        if not file.exists():
            print(f"{prefix} {file.resolve()} not found, check failed.")
            return
        requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
    else:  # list or tuple of packages
        requirements = [x for x in requirements if x not in exclude]

    n = 0  # number of packages updated
    for r in requirements:
        try:
            pkg.require(r)
        except Exception:  # DistributionNotFound or VersionConflict if requirements not met
            print(f"{prefix} {r} not found and is required by YOLOv5, attempting auto-update...")
            try:
                assert check_online(), f"'pip install {r}' skipped (offline)"
                print(check_output(f"pip install '{r}'", shell=True).decode())
                n += 1
            except Exception as e:
                print(f'{prefix} {e}')

    if n:  # if packages updated
        source = file.resolve() if 'file' in locals() else requirements
        s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
        print(emojis(s))  # emoji-safe

def check_img_size(img_size, s=32):
    # Verify img_size is a multiple of stride s
    new_size = make_divisible(img_size, int(s))  # ceil gs-multiple
    if new_size != img_size:
        print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
    return new_size


def check_imshow():
    # Check if environment supports image displays
    try:
        assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
        assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
        return False

def check_file(file):
    # Search/download file (if necessary) and return path
    file = str(file)  # convert to str()
    if Path(file).is_file() or file == '':  # exists
        return file
    elif file.startswith(('http:/', 'https:/')):  # download
        url = str(Path(file)).replace(':/', '://')  # Pathlib turns :// -> :/
        file = Path(urllib.parse.unquote(file)).name.split('?')[0]  # '%2F' to '/', split https://url.com/file.txt?auth
        print(f'Downloading {url} to {file}...')
        torch.hub.download_url_to_file(url, file)
        assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}'  # check
        return file
    else:  # search
        files = glob.glob('./**/' + file, recursive=True)  # find file
        assert len(files), f'File not found: {file}'  # assert file was found
        assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique
        return files[0]  # return file

def check_dataset(data, autodownload=True):
    # Download dataset if not found locally
    path = Path(data.get('path', ''))  # optional 'path' field
    if path:
        for k in 'train', 'val', 'test':
            if data.get(k):  # prepend path
                data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]

    train, val, test, s = [data.get(x) for x in ('train', 'val', 'test', 'download')]
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and autodownload:  # download script
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    print(f'Downloading {s} ...')
                    torch.hub.download_url_to_file(s, f)
                    root = path.parent if 'path' in data else '..'  # unzip directory i.e. '../'
                    Path(root).mkdir(parents=True, exist_ok=True)  # create root
                    r = os.system(f'unzip -q {f} -d {root} && rm {f}')  # unzip
                elif s.startswith('bash '):  # bash script
                    print(f'Running {s} ...')
                    r = os.system(s)
                else:  # python script
                    r = exec(s, {'yaml': data})  # return None
                print('Dataset autodownload %s\n' % ('success' if r in (0, None) else 'failure'))  # print result
            else:
                raise Exception('Dataset not found.')

def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
    # Multi-threaded file download and unzip function
    def download_one(url, dir):
        # Download 1 file
        f = dir / Path(url).name  # filename
        if not f.exists():
            print(f'Downloading {url} to {f}...')
            if curl:
                os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -")  # curl download, retry and resume on fail
            else:
                torch.hub.download_url_to_file(url, f, progress=True)  # torch download
        if unzip and f.suffix in ('.zip', '.gz'):
            print(f'Unzipping {f}...')
            if f.suffix == '.zip':
                s = f'unzip -qo {f} -d {dir}'  # unzip -quiet -overwrite
            elif f.suffix == '.gz':
                s = f'tar xfz {f} --directory {f.parent}'  # unzip
            if delete:  # delete zip file after unzip
                s += f' && rm {f}'
            os.system(s)

    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    if threads > 1:
        pool = ThreadPool(threads)
        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multi-threaded
        pool.close()
        pool.join()
    else:
        for u in [url] if isinstance(url, str) else url:  # wrap single URL in a list (tuple(url) would iterate its characters)
            download_one(u, dir)

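# Illustrative usage sketch (hypothetical URL, not in the original file):
#   download('https://example.com/coco128.zip', dir='../datasets', unzip=True, delete=True)
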
def make_divisible(x, divisor):
    # Return x rounded up to the nearest multiple of divisor
    return math.ceil(x / divisor) * divisor

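# Illustrative example (not in the original file): make_divisible(100, 32) == 128,
# so a requested --img-size of 100 is rounded up to the nearest stride-32 multiple.
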
def clean_str(s):
    # Clean a string by replacing special characters with underscore _
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)


def one_cycle(y1=0.0, y2=1.0, steps=100):
    # lambda function for sinusoidal ramp from y1 to y2
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1

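# Illustrative usage sketch (not in the original file), assuming a hypothetical final
# LR factor lrf=0.2 over 300 epochs:
#   lf = one_cycle(1, 0.2, 300)  # lf(0) == 1.0, lf(300) == 0.2, cosine ramp in between
#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
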
def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {'black': '\033[30m',  # basic colors
              'red': '\033[31m',
              'green': '\033[32m',
              'yellow': '\033[33m',
              'blue': '\033[34m',
              'magenta': '\033[35m',
              'cyan': '\033[36m',
              'white': '\033[37m',
              'bright_black': '\033[90m',  # bright colors
              'bright_red': '\033[91m',
              'bright_green': '\033[92m',
              'bright_yellow': '\033[93m',
              'bright_blue': '\033[94m',
              'bright_magenta': '\033[95m',
              'bright_cyan': '\033[96m',
              'bright_white': '\033[97m',
              'end': '\033[0m',  # misc
              'bold': '\033[1m',
              'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']

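# Illustrative example (not in the original file):
#   colorstr('red', 'bold', 'requirements:')  # '\033[31m\033[1mrequirements:\033[0m'
#   colorstr('hello world')                   # single argument defaults to blue + bold
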
def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(int)  # labels = [class xywh]
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # inverse frequency per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)

def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produce image weights based on class_weights and image contents
    class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
    return image_weights

def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
         64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
    return x

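# Illustrative example (not in the original file): x = coco80_to_coco91_class() gives
# x[0] == 1 and x[11] == 13, i.e. 80-index class 11 maps to paper id 13 because id 12
# is unused in the 91-class list.
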
def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y

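# Illustrative round trip (not in the original file) for a box [x1, y1, x2, y2] = [10, 20, 50, 60]:
#   xyxy2xywh(np.array([[10, 20, 50, 60]]))  # -> [[30, 40, 40, 40]] (center 30,40, size 40x40)
#   xywh2xyxy(np.array([[30, 40, 40, 40]]))  # -> [[10, 20, 50, 60]]
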
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # top left x
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh  # top left y
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw  # bottom right x
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh  # bottom right y
    return y


def xyxy2xywhn(x, w=640, h=640, clip=False):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
    if clip:
        clip_coords(x, (h, w))  # warning: inplace clip
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w  # x center
    y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h  # y center
    y[:, 2] = (x[:, 2] - x[:, 0]) / w  # width
    y[:, 3] = (x[:, 3] - x[:, 1]) / h  # height
    return y

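# Illustrative example (not in the original file): a normalized label [0.5, 0.5, 0.2, 0.4]
# on a 640x480 image:
#   xywhn2xyxy(np.array([[0.5, 0.5, 0.2, 0.4]]), w=640, h=480)  # -> [[256., 144., 384., 336.]]
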
def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    # Convert normalized segments into pixel segments, shape (n,2)
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * x[:, 0] + padw  # top left x
    y[:, 1] = h * x[:, 1] + padh  # top left y
    return y


def segment2box(segment, width=640, height=640):
    # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4))  # xyxy


def segments2boxes(segments):
    # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
    return xyxy2xywh(np.array(boxes))  # cls, xywh


def resample_segments(segments, n=1000):
    # Up-sample an (n,2) segment
    for i, s in enumerate(segments):
        x = np.linspace(0, len(s) - 1, n)
        xp = np.arange(len(s))
        segments[i] = np.concatenate([np.interp(x, xp, s[:, j]) for j in range(2)]).reshape(2, -1).T  # segment xy
    return segments

def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = img1 / img0
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords

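# Illustrative example (not in the original file), with `det` a hypothetical (n,6)
# detections tensor: boxes predicted on a 640x640 letterboxed input are mapped back to
# the original 1280x720 frame, where gain = min(640/720, 640/1280) = 0.5 and pad = (0, 140):
#   scale_coords((640, 640), det[:, :4], (720, 1280))  # in-place rescale of xyxy boxes
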
def clip_coords(boxes, img_shape):
    # Clip xyxy bounding boxes to image shape (height, width)
    if isinstance(boxes, torch.Tensor):
        boxes[:, 0].clamp_(0, img_shape[1])  # x1
        boxes[:, 1].clamp_(0, img_shape[0])  # y1
        boxes[:, 2].clamp_(0, img_shape[1])  # x2
        boxes[:, 3].clamp_(0, img_shape[0])  # y2
    else:  # np.array
        boxes[:, 0].clip(0, img_shape[1], out=boxes[:, 0])  # x1
        boxes[:, 1].clip(0, img_shape[0], out=boxes[:, 1])  # y1
        boxes[:, 2].clip(0, img_shape[1], out=boxes[:, 2])  # x2
        boxes[:, 3].clip(0, img_shape[0], out=boxes[:, 3])  # y2

def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
         list of detections, one (n,6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output

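# Illustrative usage sketch (hypothetical `model` and `img`, not in the original file):
#   pred = model(img)[0]  # raw (batch, n_anchors, nc + 5) predictions
#   det = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, max_det=300)[0]  # image 0, (n,6)
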
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")

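# Illustrative example (hypothetical path, not in the original file): shrink a checkpoint
# by dropping optimizer state and casting weights to FP16:
#   strip_optimizer('runs/train/exp/weights/best.pt')
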
def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
    # Print mutation results to evolve.txt (for use with train.py --evolve)
    a = '%10s' * len(hyp) % tuple(hyp.keys())  # hyperparam keys
    b = '%10.3g' * len(hyp) % tuple(hyp.values())  # hyperparam values
    c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
    print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))

    if bucket:
        url = 'gs://%s/evolve.txt' % bucket
        if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
            os.system('gsutil cp %s .' % url)  # download evolve.txt if larger than local

    with open('evolve.txt', 'a') as f:  # append result
        f.write(c + b + '\n')
    x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # load unique rows
    x = x[np.argsort(-fitness(x))]  # sort
    np.savetxt('evolve.txt', x, '%10.3g')  # save sorted by fitness

    # Save yaml
    for i, k in enumerate(hyp.keys()):
        hyp[k] = float(x[0, i + 7])
    with open(yaml_file, 'w') as f:
        results = tuple(x[0, :7])
        c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
        f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
        yaml.safe_dump(hyp, f, sort_keys=False)

    if bucket:
        os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket))  # upload

def apply_classifier(x, model, img, im0):
    # Apply a second stage classifier to yolo outputs
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                # cv2.imwrite('test%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x224x224
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255.0  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x

def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):
    # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
    xyxy = torch.tensor(xyxy).view(-1, 4)
    b = xyxy2xywh(xyxy)  # boxes
    if square:
        b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # attempt rectangle to square
    b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
    xyxy = xywh2xyxy(b).long()
    clip_coords(xyxy, im.shape)
    crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
    if save:
        cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop)
    return crop

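# Illustrative usage sketch (hypothetical `det` detections and BGR frame `im0`, not in
# the original file):
#   for *xyxy, conf, cls in det:
#       save_one_box(xyxy, im0, file='runs/crops/crop.jpg', BGR=True)
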
def increment_path(path, exist_ok=False, sep='', mkdir=False):
    # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        suffix = path.suffix
        path = path.with_suffix('')
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
        matches = [re.search(rf"{path.stem}{sep}(\d+)", d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]  # indices
        n = max(i) + 1 if i else 2  # increment number
        path = Path(f"{path}{sep}{n}{suffix}")  # update path
    dir = path if path.suffix == '' else path.parent  # directory
    if not dir.exists() and mkdir:
        dir.mkdir(parents=True, exist_ok=True)  # make directory
    return path
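
# Illustrative example (not in the original file): with runs/exp and runs/exp2 already
# on disk:
#   increment_path('runs/exp')                 # -> Path('runs/exp3')
#   increment_path('runs/exp', exist_ok=True)  # -> Path('runs/exp'), unchanged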