Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

647 lines
26KB

  1. # YOLOv5 general utils
  2. import glob
  3. import logging
  4. import math
  5. import os
  6. import platform
  7. import random
  8. import re
  9. import subprocess
  10. import time
  11. from pathlib import Path
  12. import cv2
  13. import numpy as np
  14. import pandas as pd
  15. import torch
  16. import torchvision
  17. import yaml
  18. from utils.google_utils import gsutil_getsize
  19. from utils.metrics import fitness
  20. from utils.torch_utils import init_torch_seeds
  21. # Settings
  22. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  23. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  24. pd.options.display.max_columns = 10
  25. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  26. os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
  27. def set_logging(rank=-1):
  28. logging.basicConfig(
  29. format="%(message)s",
  30. level=logging.INFO if rank in [-1, 0] else logging.WARN)
  31. def init_seeds(seed=0):
  32. # Initialize random number generator (RNG) seeds
  33. random.seed(seed)
  34. np.random.seed(seed)
  35. init_torch_seeds(seed)
  36. def get_latest_run(search_dir='.'):
  37. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  38. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  39. return max(last_list, key=os.path.getctime) if last_list else ''
  40. def isdocker():
  41. # Is environment a Docker container
  42. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  43. def emojis(str=''):
  44. # Return platform-dependent emoji-safe version of string
  45. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  46. def check_online():
  47. # Check internet connectivity
  48. import socket
  49. try:
  50. socket.create_connection(("1.1.1.1", 443), 5) # check host accesability
  51. return True
  52. except OSError:
  53. return False
  54. def check_git_status():
  55. # Recommend 'git pull' if code is out of date
  56. print(colorstr('github: '), end='')
  57. try:
  58. assert Path('.git').exists(), 'skipping check (not a git repository)'
  59. assert not isdocker(), 'skipping check (Docker image)'
  60. assert check_online(), 'skipping check (offline)'
  61. cmd = 'git fetch && git config --get remote.origin.url'
  62. url = subprocess.check_output(cmd, shell=True).decode().strip().rstrip('.git') # github repo url
  63. branch = subprocess.check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  64. n = int(subprocess.check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  65. if n > 0:
  66. s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
  67. f"Use 'git pull' to update or 'git clone {url}' to download latest."
  68. else:
  69. s = f'up to date with {url} ✅'
  70. print(emojis(s)) # emoji-safe
  71. except Exception as e:
  72. print(e)
  73. def check_requirements(requirements='requirements.txt', exclude=()):
  74. # Check installed dependencies meet requirements (pass *.txt file or list of packages)
  75. import pkg_resources as pkg
  76. prefix = colorstr('red', 'bold', 'requirements:')
  77. if isinstance(requirements, (str, Path)): # requirements.txt file
  78. file = Path(requirements)
  79. if not file.exists():
  80. print(f"{prefix} {file.resolve()} not found, check failed.")
  81. return
  82. requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
  83. else: # list or tuple of packages
  84. requirements = [x for x in requirements if x not in exclude]
  85. n = 0 # number of packages updates
  86. for r in requirements:
  87. try:
  88. pkg.require(r)
  89. except Exception as e: # DistributionNotFound or VersionConflict if requirements not met
  90. n += 1
  91. print(f"{prefix} {e.req} not found and is required by YOLOv5, attempting auto-update...")
  92. print(subprocess.check_output(f"pip install {e.req}", shell=True).decode())
  93. if n: # if packages updated
  94. source = file.resolve() if 'file' in locals() else requirements
  95. s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
  96. f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
  97. print(emojis(s)) # emoji-safe
  98. def check_img_size(img_size, s=32):
  99. # Verify img_size is a multiple of stride s
  100. new_size = make_divisible(img_size, int(s)) # ceil gs-multiple
  101. if new_size != img_size:
  102. print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
  103. return new_size
  104. def check_imshow():
  105. # Check if environment supports image displays
  106. try:
  107. assert not isdocker(), 'cv2.imshow() is disabled in Docker environments'
  108. cv2.imshow('test', np.zeros((1, 1, 3)))
  109. cv2.waitKey(1)
  110. cv2.destroyAllWindows()
  111. cv2.waitKey(1)
  112. return True
  113. except Exception as e:
  114. print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
  115. return False
  116. def check_file(file):
  117. # Search for file if not found
  118. if Path(file).is_file() or file == '':
  119. return file
  120. else:
  121. files = glob.glob('./**/' + file, recursive=True) # find file
  122. assert len(files), f'File Not Found: {file}' # assert file was found
  123. assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
  124. return files[0] # return file
  125. def check_dataset(dict):
  126. # Download dataset if not found locally
  127. val, s = dict.get('val'), dict.get('download')
  128. if val and len(val):
  129. val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
  130. if not all(x.exists() for x in val):
  131. print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
  132. if s and len(s): # download script
  133. print('Downloading %s ...' % s)
  134. if s.startswith('http') and s.endswith('.zip'): # URL
  135. f = Path(s).name # filename
  136. torch.hub.download_url_to_file(s, f)
  137. r = os.system('unzip -q %s -d ../ && rm %s' % (f, f)) # unzip
  138. else: # bash script
  139. r = os.system(s)
  140. print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure')) # analyze return value
  141. else:
  142. raise Exception('Dataset not found.')
  143. def make_divisible(x, divisor):
  144. # Returns x evenly divisible by divisor
  145. return math.ceil(x / divisor) * divisor
  146. def clean_str(s):
  147. # Cleans a string by replacing special characters with underscore _
  148. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  149. def one_cycle(y1=0.0, y2=1.0, steps=100):
  150. # lambda function for sinusoidal ramp from y1 to y2
  151. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  152. def colorstr(*input):
  153. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  154. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  155. colors = {'black': '\033[30m', # basic colors
  156. 'red': '\033[31m',
  157. 'green': '\033[32m',
  158. 'yellow': '\033[33m',
  159. 'blue': '\033[34m',
  160. 'magenta': '\033[35m',
  161. 'cyan': '\033[36m',
  162. 'white': '\033[37m',
  163. 'bright_black': '\033[90m', # bright colors
  164. 'bright_red': '\033[91m',
  165. 'bright_green': '\033[92m',
  166. 'bright_yellow': '\033[93m',
  167. 'bright_blue': '\033[94m',
  168. 'bright_magenta': '\033[95m',
  169. 'bright_cyan': '\033[96m',
  170. 'bright_white': '\033[97m',
  171. 'end': '\033[0m', # misc
  172. 'bold': '\033[1m',
  173. 'underline': '\033[4m'}
  174. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  175. def labels_to_class_weights(labels, nc=80):
  176. # Get class weights (inverse frequency) from training labels
  177. if labels[0] is None: # no labels loaded
  178. return torch.Tensor()
  179. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  180. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  181. weights = np.bincount(classes, minlength=nc) # occurrences per class
  182. # Prepend gridpoint count (for uCE training)
  183. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  184. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  185. weights[weights == 0] = 1 # replace empty bins with 1
  186. weights = 1 / weights # number of targets per class
  187. weights /= weights.sum() # normalize
  188. return torch.from_numpy(weights)
  189. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  190. # Produces image weights based on class_weights and image contents
  191. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  192. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  193. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  194. return image_weights
  195. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  196. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  197. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  198. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  199. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  200. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  201. x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  202. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  203. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  204. return x
  205. def xyxy2xywh(x):
  206. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  207. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  208. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  209. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  210. y[:, 2] = x[:, 2] - x[:, 0] # width
  211. y[:, 3] = x[:, 3] - x[:, 1] # height
  212. return y
  213. def xywh2xyxy(x):
  214. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  215. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  216. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  217. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  218. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  219. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  220. return y
  221. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  222. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  223. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  224. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  225. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  226. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  227. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  228. return y
  229. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  230. # Convert normalized segments into pixel segments, shape (n,2)
  231. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  232. y[:, 0] = w * x[:, 0] + padw # top left x
  233. y[:, 1] = h * x[:, 1] + padh # top left y
  234. return y
  235. def segment2box(segment, width=640, height=640):
  236. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  237. x, y = segment.T # segment xy
  238. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  239. x, y, = x[inside], y[inside]
  240. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
  241. def segments2boxes(segments):
  242. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  243. boxes = []
  244. for s in segments:
  245. x, y = s.T # segment xy
  246. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  247. return xyxy2xywh(np.array(boxes)) # cls, xywh
  248. def resample_segments(segments, n=1000):
  249. # Up-sample an (n,2) segment
  250. for i, s in enumerate(segments):
  251. x = np.linspace(0, len(s) - 1, n)
  252. xp = np.arange(len(s))
  253. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  254. return segments
  255. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  256. # Rescale coords (xyxy) from img1_shape to img0_shape
  257. if ratio_pad is None: # calculate from img0_shape
  258. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  259. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  260. else:
  261. gain = ratio_pad[0][0]
  262. pad = ratio_pad[1]
  263. coords[:, [0, 2]] -= pad[0] # x padding
  264. coords[:, [1, 3]] -= pad[1] # y padding
  265. coords[:, :4] /= gain
  266. clip_coords(coords, img0_shape)
  267. return coords
  268. def clip_coords(boxes, img_shape):
  269. # Clip bounding xyxy bounding boxes to image shape (height, width)
  270. boxes[:, 0].clamp_(0, img_shape[1]) # x1
  271. boxes[:, 1].clamp_(0, img_shape[0]) # y1
  272. boxes[:, 2].clamp_(0, img_shape[1]) # x2
  273. boxes[:, 3].clamp_(0, img_shape[0]) # y2
  274. def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
  275. # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
  276. box2 = box2.T
  277. # Get the coordinates of bounding boxes
  278. if x1y1x2y2: # x1, y1, x2, y2 = box1
  279. b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
  280. b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
  281. else: # transform from xywh to xyxy
  282. b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
  283. b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
  284. b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
  285. b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
  286. # Intersection area
  287. inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
  288. (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
  289. # Union Area
  290. w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
  291. w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
  292. union = w1 * h1 + w2 * h2 - inter + eps
  293. iou = inter / union
  294. if GIoU or DIoU or CIoU:
  295. cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
  296. ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
  297. if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
  298. c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
  299. rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
  300. (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center distance squared
  301. if DIoU:
  302. return iou - rho2 / c2 # DIoU
  303. elif CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
  304. v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
  305. with torch.no_grad():
  306. alpha = v / (v - iou + (1 + eps))
  307. return iou - (rho2 / c2 + v * alpha) # CIoU
  308. else: # GIoU https://arxiv.org/pdf/1902.09630.pdf
  309. c_area = cw * ch + eps # convex area
  310. return iou - (c_area - union) / c_area # GIoU
  311. else:
  312. return iou # IoU
  313. def box_iou(box1, box2):
  314. # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
  315. """
  316. Return intersection-over-union (Jaccard index) of boxes.
  317. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
  318. Arguments:
  319. box1 (Tensor[N, 4])
  320. box2 (Tensor[M, 4])
  321. Returns:
  322. iou (Tensor[N, M]): the NxM matrix containing the pairwise
  323. IoU values for every element in boxes1 and boxes2
  324. """
  325. def box_area(box):
  326. # box = 4xn
  327. return (box[2] - box[0]) * (box[3] - box[1])
  328. area1 = box_area(box1.T)
  329. area2 = box_area(box2.T)
  330. # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
  331. inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
  332. return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
  333. def wh_iou(wh1, wh2):
  334. # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
  335. wh1 = wh1[:, None] # [N,1,2]
  336. wh2 = wh2[None] # [1,M,2]
  337. inter = torch.min(wh1, wh2).prod(2) # [N,M]
  338. return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
  339. def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
  340. labels=()):
  341. """Runs Non-Maximum Suppression (NMS) on inference results
  342. Returns:
  343. list of detections, on (n,6) tensor per image [xyxy, conf, cls]
  344. """
  345. nc = prediction.shape[2] - 5 # number of classes
  346. xc = (prediction[..., 4] > conf_thres) & ( prediction[..., 4] < 1.0000001 ) # candidates
  347. # Settings
  348. min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
  349. max_det = 300 # maximum number of detections per image
  350. max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
  351. time_limit = 10.0 # seconds to quit after
  352. redundant = True # require redundant detections
  353. multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
  354. merge = False # use merge-NMS
  355. t = time.time()
  356. output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
  357. for xi, x in enumerate(prediction): # image index, image inference
  358. # Apply constraints
  359. # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
  360. x = x[xc[xi]] # confidence
  361. # Cat apriori labels if autolabelling
  362. if labels and len(labels[xi]):
  363. l = labels[xi]
  364. v = torch.zeros((len(l), nc + 5), device=x.device)
  365. v[:, :4] = l[:, 1:5] # box
  366. v[:, 4] = 1.0 # conf
  367. v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls
  368. x = torch.cat((x, v), 0)
  369. # If none remain process next image
  370. if not x.shape[0]:
  371. continue
  372. # Compute conf
  373. x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
  374. # Box (center x, center y, width, height) to (x1, y1, x2, y2)
  375. box = xywh2xyxy(x[:, :4])
  376. # Detections matrix nx6 (xyxy, conf, cls)
  377. if multi_label:
  378. i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
  379. x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
  380. else: # best class only
  381. conf, j = x[:, 5:].max(1, keepdim=True)
  382. x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
  383. # Filter by class
  384. if classes is not None:
  385. x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
  386. # Apply finite constraint
  387. # if not torch.isfinite(x).all():
  388. # x = x[torch.isfinite(x).all(1)]
  389. # Check shape
  390. n = x.shape[0] # number of boxes
  391. if not n: # no boxes
  392. continue
  393. elif n > max_nms: # excess boxes
  394. x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
  395. # Batched NMS
  396. c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
  397. boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
  398. i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
  399. if i.shape[0] > max_det: # limit detections
  400. i = i[:max_det]
  401. if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
  402. # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
  403. iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
  404. weights = iou * scores[None] # box weights
  405. x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
  406. if redundant:
  407. i = i[iou.sum(1) > 1] # require redundancy
  408. output[xi] = x[i]
  409. if (time.time() - t) > time_limit:
  410. print(f'WARNING: NMS time limit {time_limit}s exceeded')
  411. break # time limit exceeded
  412. return output
  413. def overlap_box_suppression(prediction, ovlap_thres = 0.6):
  414. """Runs overlap_box_suppression on inference results
  415. delete the box that overlap of boxes bigger than ovlap_thres
  416. Returns:
  417. list of detections, on (n,6) tensor per image [xyxy, conf, cls]
  418. """
  419. def box_iob(box1, box2):
  420. def box_area(box):
  421. return (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])
  422. area1 = box_area(box1) # (N,)
  423. area2 = box_area(box2) # (M,)
  424. # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
  425. lt = torch.max(box1[:, None, :2], box2[:, :2]) # [N,M,2] # N中一个和M个比较;
  426. rb = torch.min(box1[:, None, 2:], box2[:, 2:]) # [N,M,2]
  427. wh = (rb - lt).clamp(min=0) #小于0的为0 clamp 钳;夹钳;
  428. inter = wh[:, :, 0] * wh[:, :, 1]
  429. return torch.squeeze(inter / area1), torch.squeeze(inter / area2)
  430. output = [torch.zeros((0, 6), device=prediction[0].device)] * len(prediction)
  431. for i, x in enumerate(prediction):
  432. keep = [] # 最终保留的结果, 在boxes中对应的索引;
  433. boxes = x[:, 0:4]
  434. scores = x[:, 4]
  435. cls = x[:, 5]
  436. idxs = scores.argsort()
  437. while idxs.numel() > 0:
  438. keep_idx = idxs[-1]
  439. keep_box = boxes[keep_idx][None, ] # [1, 4]
  440. keep.append(keep_idx)
  441. if idxs.size(0) == 1:
  442. break
  443. idxs = idxs[:-1] # 将得分最大框 从索引中删除; 剩余索引对应的框 和 得分最大框 计算iob;
  444. other_boxes = boxes[idxs]
  445. this_cls = cls[keep_idx]
  446. other_cls = cls[idxs]
  447. iobs1, iobs2 = box_iob(keep_box, other_boxes) # 一个框和其余框比较 1XM
  448. idxs = idxs[((iobs1 <= ovlap_thres) & (iobs2 <= ovlap_thres)) | (other_cls != this_cls)]
  449. keep = idxs.new(keep) # Tensor
  450. output[i] = x[keep]
  451. return output
  452. def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
  453. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  454. x = torch.load(f, map_location=torch.device('cpu'))
  455. if x.get('ema'):
  456. x['model'] = x['ema'] # replace model with ema
  457. for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates': # keys
  458. x[k] = None
  459. x['epoch'] = -1
  460. x['model'].half() # to FP16
  461. for p in x['model'].parameters():
  462. p.requires_grad = False
  463. torch.save(x, s or f)
  464. mb = os.path.getsize(s or f) / 1E6 # filesize
  465. print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
  466. def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
  467. # Print mutation results to evolve.txt (for use with train.py --evolve)
  468. a = '%10s' * len(hyp) % tuple(hyp.keys()) # hyperparam keys
  469. b = '%10.3g' * len(hyp) % tuple(hyp.values()) # hyperparam values
  470. c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
  471. print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))
  472. if bucket:
  473. url = 'gs://%s/evolve.txt' % bucket
  474. if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
  475. os.system('gsutil cp %s .' % url) # download evolve.txt if larger than local
  476. with open('evolve.txt', 'a') as f: # append result
  477. f.write(c + b + '\n')
  478. x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0) # load unique rows
  479. x = x[np.argsort(-fitness(x))] # sort
  480. np.savetxt('evolve.txt', x, '%10.3g') # save sort by fitness
  481. # Save yaml
  482. for i, k in enumerate(hyp.keys()):
  483. hyp[k] = float(x[0, i + 7])
  484. with open(yaml_file, 'w') as f:
  485. results = tuple(x[0, :7])
  486. c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
  487. f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
  488. yaml.dump(hyp, f, sort_keys=False)
  489. if bucket:
  490. os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket)) # upload
  491. def apply_classifier(x, model, img, im0):
  492. # applies a second stage classifier to yolo outputs
  493. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  494. for i, d in enumerate(x): # per image
  495. if d is not None and len(d):
  496. d = d.clone()
  497. # Reshape and pad cutouts
  498. b = xyxy2xywh(d[:, :4]) # boxes
  499. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  500. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  501. d[:, :4] = xywh2xyxy(b).long()
  502. # Rescale boxes from img_size to im0 size
  503. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  504. # Classes
  505. pred_cls1 = d[:, 5].long()
  506. ims = []
  507. for j, a in enumerate(d): # per item
  508. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  509. im = cv2.resize(cutout, (224, 224)) # BGR
  510. # cv2.imwrite('test%i.jpg' % j, cutout)
  511. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  512. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  513. im /= 255.0 # 0 - 255 to 0.0 - 1.0
  514. ims.append(im)
  515. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  516. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  517. return x
  518. def increment_path(path, exist_ok=True, sep=''):
  519. # Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc.
  520. path = Path(path) # os-agnostic
  521. if (path.exists() and exist_ok) or (not path.exists()):
  522. return str(path)
  523. else:
  524. dirs = glob.glob(f"{path}{sep}*") # similar paths
  525. matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
  526. i = [int(m.groups()[0]) for m in matches if m] # indices
  527. n = max(i) + 1 if i else 2 # increment number
  528. return f"{path}{sep}{n}" # update path