# YOLOv5 general utils

import glob
import logging
import math
import os
import platform
import random
import re
import subprocess
import time
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import torch
import torchvision
import yaml

from utils.google_utils import gsutil_getsize
from utils.metrics import fitness
from utils.torch_utils import init_torch_seeds

# Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads

def set_logging(rank=-1):
    logging.basicConfig(
        format="%(message)s",
        level=logging.INFO if rank in [-1, 0] else logging.WARN)


def init_seeds(seed=0):
    # Initialize random number generator (RNG) seeds
    random.seed(seed)
    np.random.seed(seed)
    init_torch_seeds(seed)

def get_latest_run(search_dir='.'):
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''


def isdocker():
    # Is environment a Docker container?
    return Path('/workspace').exists()  # or Path('/.dockerenv').exists()

def emojis(str=''):
    # Return platform-dependent emoji-safe version of string
    return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str


def file_size(file):
    # Return file size in MB
    return Path(file).stat().st_size / 1e6

def check_online():
    # Check internet connectivity
    import socket
    try:
        socket.create_connection(("1.1.1.1", 443), timeout=5)  # check host accessibility
        return True
    except OSError:
        return False

def check_git_status():
    # Recommend 'git pull' if code is out of date
    print(colorstr('github: '), end='')
    try:
        assert Path('.git').exists(), 'skipping check (not a git repository)'
        assert not isdocker(), 'skipping check (Docker image)'
        assert check_online(), 'skipping check (offline)'

        cmd = 'git fetch && git config --get remote.origin.url'
        # re.sub removes only a literal '.git' suffix; str.rstrip('.git') would strip any
        # trailing '.', 'g', 'i' or 't' characters and can mangle the URL
        url = re.sub(r'\.git$', '', subprocess.check_output(cmd, shell=True).decode().strip())  # github repo url
        branch = subprocess.check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
        n = int(subprocess.check_output(f'git rev-list {branch}..origin/master --count', shell=True))  # commits behind
        if n > 0:
            s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
                f"Use 'git pull' to update or 'git clone {url}' to download latest."
        else:
            s = f'up to date with {url} ✅'
        print(emojis(s))  # emoji-safe
    except Exception as e:
        print(e)

def check_requirements(requirements='requirements.txt', exclude=()):
    # Check installed dependencies meet requirements (pass *.txt file or list of packages)
    import pkg_resources as pkg
    prefix = colorstr('red', 'bold', 'requirements:')
    if isinstance(requirements, (str, Path)):  # requirements.txt file
        file = Path(requirements)
        if not file.exists():
            print(f"{prefix} {file.resolve()} not found, check failed.")
            return
        requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
    else:  # list or tuple of packages
        requirements = [x for x in requirements if x not in exclude]

    n = 0  # number of packages updated
    for r in requirements:
        try:
            pkg.require(r)
        except Exception as e:  # DistributionNotFound or VersionConflict if requirements not met
            n += 1
            print(f"{prefix} {e.req} not found and is required by YOLOv5, attempting auto-update...")
            print(subprocess.check_output(f"pip install {e.req}", shell=True).decode())

    if n:  # if packages updated
        source = file.resolve() if 'file' in locals() else requirements
        s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
        print(emojis(s))  # emoji-safe

def check_img_size(img_size, s=32):
    # Verify img_size is a multiple of stride s
    new_size = make_divisible(img_size, int(s))  # ceil gs-multiple
    if new_size != img_size:
        print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
    return new_size


def check_imshow():
    # Check if environment supports image displays
    try:
        assert not isdocker(), 'cv2.imshow() is disabled in Docker environments'
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
        return False

def check_file(file):
    # Search for file if not found
    if Path(file).is_file() or file == '':
        return file
    else:
        files = glob.glob('./**/' + file, recursive=True)  # find file
        assert len(files), f'File Not Found: {file}'  # assert file was found
        assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique
        return files[0]  # return file

def check_dataset(dict):
    # Download dataset if not found locally
    val, s = dict.get('val'), dict.get('download')
    if val and len(val):
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and len(s):  # download script
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    print(f'Downloading {s} ...')
                    torch.hub.download_url_to_file(s, f)
                    r = os.system(f'unzip -q {f} -d ../ && rm {f}')  # unzip
                elif s.startswith('bash '):  # bash script
                    print(f'Running {s} ...')
                    r = os.system(s)
                else:  # python script
                    r = exec(s)  # returns None
                print('Dataset autodownload %s\n' % ('success' if r in (0, None) else 'failure'))  # print result
            else:
                raise Exception('Dataset not found.')

def download(url, dir='.', multi_thread=False):
    # Multi-threaded file download and unzip function
    def download_one(url, dir):
        # Download 1 file
        f = dir / Path(url).name  # filename
        if not f.exists():
            print(f'Downloading {url} to {f}...')
            torch.hub.download_url_to_file(url, f, progress=True)  # download
        if f.suffix in ('.zip', '.gz'):
            print(f'Unzipping {f}...')
            if f.suffix == '.zip':
                os.system(f'unzip -qo {f} -d {dir} && rm {f}')  # unzip -quiet -overwrite
            elif f.suffix == '.gz':
                os.system(f'tar xfz {f} --directory {f.parent} && rm {f}')  # unzip

    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    if multi_thread:
        pool = ThreadPool(8)
        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # 8 threads
        pool.close()
        pool.join()  # imap is lazy; close + join ensure the downloads actually run to completion
    else:
        for u in [url] if isinstance(url, str) else url:  # wrap a bare str so we don't iterate its characters
            download_one(u, dir)

def make_divisible(x, divisor):
    # Returns x evenly divisible by divisor
    return math.ceil(x / divisor) * divisor

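
# Usage sketch (illustrative, not part of the original file): round an image size up
# to the nearest stride multiple.
#   make_divisible(97, 32)  # -> 128, since ceil(97 / 32) * 32 = 4 * 32 = 128
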
def clean_str(s):
    # Cleans a string by replacing special characters with underscore _
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)


def one_cycle(y1=0.0, y2=1.0, steps=100):
    # lambda function for sinusoidal ramp from y1 to y2
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1

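
# Usage sketch (illustrative, not part of the original file): one_cycle is intended
# as a learning-rate multiplier for a LambdaLR scheduler.
#   lf = one_cycle(1.0, 0.2, epochs)  # cosine ramp, lf(0) == 1.0 and lf(epochs) == 0.2
#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
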
def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {'black': '\033[30m',  # basic colors
              'red': '\033[31m',
              'green': '\033[32m',
              'yellow': '\033[33m',
              'blue': '\033[34m',
              'magenta': '\033[35m',
              'cyan': '\033[36m',
              'white': '\033[37m',
              'bright_black': '\033[90m',  # bright colors
              'bright_red': '\033[91m',
              'bright_green': '\033[92m',
              'bright_yellow': '\033[93m',
              'bright_blue': '\033[94m',
              'bright_magenta': '\033[95m',
              'bright_cyan': '\033[96m',
              'bright_white': '\033[97m',
              'end': '\033[0m',  # misc
              'bold': '\033[1m',
              'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']

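
# Usage sketch (illustrative, not part of the original file): any number of the
# color/style keys above may precede the string; a bare string defaults to blue + bold.
#   print(colorstr('hello'))                      # blue bold 'hello'
#   print(colorstr('red', 'underline', 'error'))  # red underlined 'error'
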
def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(int)  # labels = [class xywh]; plain int avoids the deprecated np.int alias
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # number of targets per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)


def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class_weights and image contents
    class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
    return image_weights

def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
         64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
    return x

def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y

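
# Worked example (illustrative, not part of the original file): the two conversions
# are inverses of each other.
#   b = np.array([[40., 45., 60., 55.]])  # xyxy: a 20x10 box centered at (50, 50)
#   xyxy2xywh(b)                          # -> [[50., 50., 20., 10.]]
#   xywh2xyxy(xyxy2xywh(b))               # -> [[40., 45., 60., 55.]]
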
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # top left x
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh  # top left y
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw  # bottom right x
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh  # bottom right y
    return y


def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    # Convert normalized segments into pixel segments, shape (n,2)
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * x[:, 0] + padw  # top left x
    y[:, 1] = h * x[:, 1] + padh  # top left y
    return y


def segment2box(segment, width=640, height=640):
    # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4))  # xyxy

def segments2boxes(segments):
    # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
    return xyxy2xywh(np.array(boxes))  # cls, xywh


def resample_segments(segments, n=1000):
    # Up-sample an (n,2) segment
    for i, s in enumerate(segments):
        x = np.linspace(0, len(s) - 1, n)
        xp = np.arange(len(s))
        segments[i] = np.concatenate([np.interp(x, xp, s[:, j]) for j in range(2)]).reshape(2, -1).T  # segment xy
    return segments

def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords


def clip_coords(boxes, img_shape):
    # Clip xyxy bounding boxes to image shape (height, width)
    boxes[:, 0].clamp_(0, img_shape[1])  # x1
    boxes[:, 1].clamp_(0, img_shape[0])  # y1
    boxes[:, 2].clamp_(0, img_shape[1])  # x2
    boxes[:, 3].clamp_(0, img_shape[0])  # y2

def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
    box2 = box2.T

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union area
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if GIoU or DIoU or CIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + eps  # convex area
            return iou - (c_area - union) / c_area  # GIoU
    else:
        return iou  # IoU

def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """

    def box_area(box):
        # box = 4xn
        return (box[2] - box[0]) * (box[3] - box[1])

    area1 = box_area(box1.T)
    area2 = box_area(box2.T)

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)

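
# Worked example (illustrative, not part of the original file): box_iou broadcasts
# to an NxM matrix, unlike bbox_iou, which compares one box against n boxes.
#   a = torch.tensor([[0., 0., 10., 10.]])
#   b = torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])
#   box_iou(a, b)  # -> tensor([[1.0000, 0.1429]]), since 25 / (100 + 100 - 25) = 1/7
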
def wh_iou(wh1, wh2):
    # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
    wh1 = wh1[:, None]  # [N,1,2]
    wh2 = wh2[None]  # [1,M,2]
    inter = torch.min(wh1, wh2).prod(2)  # [N,M]
    return inter / (wh1.prod(2) + wh2.prod(2) - inter)  # iou = inter / (area1 + area2 - inter)

def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=()):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
        list of detections, one (n,6) tensor per image [xyxy, conf, cls]
    """
    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output

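
# Usage sketch (illustrative, not part of the original file; model, img and im0 are
# assumed to come from the surrounding detection script):
#   pred = model(img)[0]  # raw (batch, n, nc + 5) predictions
#   for det in non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45):
#       if len(det):
#           det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
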
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")

def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
    # Print mutation results to evolve.txt (for use with train.py --evolve)
    a = '%10s' * len(hyp) % tuple(hyp.keys())  # hyperparam keys
    b = '%10.3g' * len(hyp) % tuple(hyp.values())  # hyperparam values
    c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
    print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))

    if bucket:
        url = 'gs://%s/evolve.txt' % bucket
        if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
            os.system('gsutil cp %s .' % url)  # download evolve.txt if larger than local

    with open('evolve.txt', 'a') as f:  # append result
        f.write(c + b + '\n')
    x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # load unique rows
    x = x[np.argsort(-fitness(x))]  # sort
    np.savetxt('evolve.txt', x, '%10.3g')  # save sorted by fitness

    # Save yaml
    for i, k in enumerate(hyp.keys()):
        hyp[k] = float(x[0, i + 7])
    with open(yaml_file, 'w') as f:
        results = tuple(x[0, :7])
        c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
        f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
        yaml.safe_dump(hyp, f, sort_keys=False)

    if bucket:
        os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket))  # upload

def apply_classifier(x, model, img, im0):
    # Apply a second stage classifier to yolo outputs
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                # cv2.imwrite('test%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW (3x224x224)
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255.0  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x

def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False):
    # Save an image crop as {file} with crop size multiplied by {gain} and padded by {pad} pixels
    xyxy = torch.tensor(xyxy).view(-1, 4)
    b = xyxy2xywh(xyxy)  # boxes
    if square:
        b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # attempt rectangle to square
    b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
    xyxy = xywh2xyxy(b).long()
    clip_coords(xyxy, im.shape)
    crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2])]
    cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop if BGR else crop[..., ::-1])

def increment_path(path, exist_ok=False, sep='', mkdir=False):
    # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        suffix = path.suffix
        path = path.with_suffix('')
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
        matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]  # indices
        n = max(i) + 1 if i else 2  # increment number
        path = Path(f"{path}{sep}{n}{suffix}")  # update path
    dir = path if path.suffix == '' else path.parent  # directory
    if not dir.exists() and mkdir:
        dir.mkdir(parents=True, exist_ok=True)  # make directory
    return path

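
# Usage sketch (illustrative, not part of the original file):
#   increment_path('runs/exp')                 # -> runs/exp2, runs/exp3, ... if runs/exp exists
#   increment_path('runs/exp', sep='_')        # -> runs/exp_2, runs/exp_3, ...
#   increment_path('runs/exp', exist_ok=True)  # always returns runs/exp unchanged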