YOLOv5 Segmentation Dataloader Updates (#2188): Update C3 module; update datasets; update attempt_download(); merge; parameterize eps; comments; gs-multiple; max_nms implemented; create one_cycle() function; GitHub API rate limit fix; ComputeLoss; astuple; epochs; commit=tag == tags[-1]; update cudnn.benchmark; mosaic9; institute cache versioning; only display on existing cache; reverse cache exists booleans
3 years ago
# General utils

import glob
import logging
import math
import os
import platform
import random
import re
import subprocess
import time
from pathlib import Path

import cv2
import numpy as np
import torch
import torchvision
import yaml

from utils.google_utils import gsutil_getsize
from utils.metrics import fitness
from utils.torch_utils import init_torch_seeds

# Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count() or 1, 8))  # NumExpr max threads (cpu_count() may return None)


def set_logging(rank=-1):
    logging.basicConfig(
        format="%(message)s",
        level=logging.INFO if rank in [-1, 0] else logging.WARN)


def init_seeds(seed=0):
    # Initialize random number generator (RNG) seeds
    random.seed(seed)
    np.random.seed(seed)
    init_torch_seeds(seed)


def get_latest_run(search_dir='.'):
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''


def isdocker():
    # Is environment a Docker container
    return Path('/workspace').exists()  # or Path('/.dockerenv').exists()


def check_online():
    # Check internet connectivity
    import socket
    try:
        socket.create_connection(("1.1.1.1", 443), 5)  # check host accessibility
        return True
    except OSError:
        return False


def check_git_status():
    # Recommend 'git pull' if code is out of date
    print(colorstr('github: '), end='')
    try:
        assert Path('.git').exists(), 'skipping check (not a git repository)'
        assert not isdocker(), 'skipping check (Docker image)'
        assert check_online(), 'skipping check (offline)'

        cmd = 'git fetch && git config --get remote.origin.url'
        url = subprocess.check_output(cmd, shell=True).decode().strip()  # github repo url
        url = url[:-4] if url.endswith('.git') else url  # drop '.git' suffix (rstrip('.git') strips characters, not a suffix)
        branch = subprocess.check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
        n = int(subprocess.check_output(f'git rev-list {branch}..origin/master --count', shell=True))  # commits behind
        if n > 0:
            s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
                f"Use 'git pull' to update or 'git clone {url}' to download latest."
        else:
            s = f'up to date with {url} ✅'
        print(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s)
    except Exception as e:
        print(e)


def check_requirements(file='requirements.txt', exclude=()):
    # Check installed dependencies meet requirements
    import pkg_resources
    with Path(file).open() as f:  # close the file handle once the requirements are parsed
        requirements = [f'{x.name}{x.specifier}' for x in pkg_resources.parse_requirements(f)
                        if x.name not in exclude]

    pkg_resources.require(requirements)  # DistributionNotFound or VersionConflict exception if requirements not met


def check_img_size(img_size, s=32):
    # Verify img_size is a multiple of stride s
    new_size = make_divisible(img_size, int(s))  # ceil gs-multiple
    if new_size != img_size:
        print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
    return new_size


def check_imshow():
    # Check if environment supports image displays
    try:
        assert not isdocker(), 'cv2.imshow() is disabled in Docker environments'
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
        return False


def check_file(file):
    # Search for file if not found
    if os.path.isfile(file) or file == '':
        return file
    else:
        files = glob.glob('./**/' + file, recursive=True)  # find file
        assert len(files), 'File Not Found: %s' % file  # assert file was found
        assert len(files) == 1, "Multiple files match '%s', specify exact path: %s" % (file, files)  # assert unique
        return files[0]  # return file


def check_dataset(data):
    # Download dataset if not found locally ('data' is the parsed dataset yaml; avoids shadowing the builtin dict)
    val, s = data.get('val'), data.get('download')
    if val and len(val):
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and len(s):  # download script
                print('Downloading %s ...' % s)
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    torch.hub.download_url_to_file(s, f)
                    r = os.system('unzip -q %s -d ../ && rm %s' % (f, f))  # unzip
                else:  # bash script
                    r = os.system(s)
                print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure'))  # analyze return value
            else:
                raise Exception('Dataset not found.')


def make_divisible(x, divisor):
    # Returns x rounded up to the nearest multiple of divisor
    return math.ceil(x / divisor) * divisor
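
```python
# Usage sketch, not part of the original module: how rounding an image size up
# to a stride multiple behaves. `round_up_to_multiple` is a hypothetical
# stand-in that repeats the arithmetic of make_divisible() so the snippet runs
# on its own.
import math


def round_up_to_multiple(x, divisor):
    # ceil(x / divisor) * divisor, i.e. the next multiple of divisor that is >= x
    return math.ceil(x / divisor) * divisor


sizes = [round_up_to_multiple(s, 32) for s in (640, 641, 100)]
print(sizes)  # [640, 672, 128]
```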


def clean_str(s):
    # Cleans a string by replacing special characters with underscore _
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)


def one_cycle(y1=0.0, y2=1.0, steps=100):
    # lambda function for sinusoidal ramp from y1 to y2
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
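
```python
# Usage sketch, not part of the original module: evaluating the sinusoidal ramp
# at its start, midpoint and end. `cosine_ramp` is a hypothetical copy of
# one_cycle() so the snippet is self-contained; the 1.0 -> 0.2 range and the
# 300 steps are illustrative values only.
import math


def cosine_ramp(y1=0.0, y2=1.0, steps=100):
    # Half-cosine interpolation: returns y1 at x=0 and y2 at x=steps
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1


ramp = cosine_ramp(1.0, 0.2, steps=300)
start, middle, end = ramp(0), ramp(150), ramp(300)
print(round(start, 3), round(middle, 3), round(end, 3))  # 1.0 0.6 0.2
```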


def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {'black': '\033[30m',  # basic colors
              'red': '\033[31m',
              'green': '\033[32m',
              'yellow': '\033[33m',
              'blue': '\033[34m',
              'magenta': '\033[35m',
              'cyan': '\033[36m',
              'white': '\033[37m',
              'bright_black': '\033[90m',  # bright colors
              'bright_red': '\033[91m',
              'bright_green': '\033[92m',
              'bright_yellow': '\033[93m',
              'bright_blue': '\033[94m',
              'bright_magenta': '\033[95m',
              'bright_cyan': '\033[96m',
              'bright_white': '\033[97m',
              'end': '\033[0m',  # misc
              'bold': '\033[1m',
              'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']


def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(int)  # labels = [class xywh]; np.int alias was removed in NumPy 1.24
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # inverse of the number of targets per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)


def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class_weights and image contents
    class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
    return image_weights


def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
         64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
    return x


def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y
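
```python
# Worked example, not part of the original module: the corner (xyxy) and
# center (xywh) box layouts converted by the two functions above, shown on a
# single box with plain tuples instead of tensors/arrays. The helper names
# `corners_to_center` and `center_to_corners` are hypothetical.
def corners_to_center(box):
    x1, y1, x2, y2 = box
    return ((x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1)  # x center, y center, width, height


def center_to_corners(box):
    x, y, w, h = box
    return (x - w / 2, y - h / 2, x + w / 2, y + h / 2)  # top-left xy, bottom-right xy


xyxy = (40.0, 45.0, 60.0, 55.0)
xywh = corners_to_center(xyxy)
print(xywh)  # (50.0, 50.0, 20.0, 10.0)
print(center_to_corners(xywh) == xyxy)  # True
```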


def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # top left x
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh  # top left y
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw  # bottom right x
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh  # bottom right y
    return y


def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    # Convert normalized segments into pixel segments, shape (n,2)
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * x[:, 0] + padw  # pixel x
    y[:, 1] = h * x[:, 1] + padh  # pixel y
    return y


def segment2box(segment, width=640, height=640):
    # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    # test for remaining points with x.size; any(x) would be False when every remaining x equals 0
    return np.array([x.min(), y.min(), x.max(), y.max()]) if x.size else np.zeros((1, 4))  # xyxy


def segments2boxes(segments):
    # Convert segment labels to box labels, i.e. (xy1, xy2, ...) to (xywh)
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # xyxy
    return xyxy2xywh(np.array(boxes))  # xywh


def resample_segments(segments, n=1000):
    # Up-sample each (m,2) segment to n points by linear interpolation
    for i, s in enumerate(segments):
        x = np.linspace(0, len(s) - 1, n)
        xp = np.arange(len(s))
        segments[i] = np.concatenate([np.interp(x, xp, s[:, j]) for j in range(2)]).reshape(2, -1).T  # segment xy
    return segments


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain  = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords


def clip_coords(boxes, img_shape):
    # Clip xyxy bounding boxes in place to image shape (height, width)
    boxes[:, 0].clamp_(0, img_shape[1])  # x1
    boxes[:, 1].clamp_(0, img_shape[0])  # y1
    boxes[:, 2].clamp_(0, img_shape[1])  # x2
    boxes[:, 3].clamp_(0, img_shape[0])  # y2


def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-9):
    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
    box2 = box2.T

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if GIoU or DIoU or CIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / ((1 + eps) - iou + v)
                return iou - (rho2 / c2 + v * alpha)  # CIoU
        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + eps  # convex area
            return iou - (c_area - union) / c_area  # GIoU
    else:
        return iou  # IoU
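
```python
# Worked arithmetic, not part of the original module: plain IoU and GIoU for one
# pair of xyxy boxes, following the same formulas as bbox_iou() above but with
# Python floats and without the eps stabilizer. `iou_and_giou` is a
# hypothetical helper name used only for this illustration.
def iou_and_giou(a, b):
    # Intersection: overlap along x times overlap along y, clamped at zero
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    iou = inter / union
    # Smallest enclosing ("convex") box, used by the GIoU penalty term
    cw = max(a[2], b[2]) - min(a[0], b[0])
    ch = max(a[3], b[3]) - min(a[1], b[1])
    c_area = cw * ch
    return iou, iou - (c_area - union) / c_area


iou, giou = iou_and_giou((0, 0, 2, 2), (1, 1, 3, 3))
# inter = 1, union = 4 + 4 - 1 = 7, enclosing area = 9
print(round(iou, 4), round(giou, 4))  # 0.1429 -0.0794
```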


def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in box1 and box2
    """

    def box_area(box):
        # box = 4xn
        return (box[2] - box[0]) * (box[3] - box[1])

    area1 = box_area(box1.T)
    area2 = box_area(box2.T)

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)


def wh_iou(wh1, wh2):
    # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
    wh1 = wh1[:, None]  # [N,1,2]
    wh2 = wh2[None]  # [1,M,2]
    inter = torch.min(wh1, wh2).prod(2)  # [N,M]
    return inter / (wh1.prod(2) + wh2.prod(2) - inter)  # iou = inter / (area1 + area2 - inter)


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
    """Performs Non-Maximum Suppression (NMS) on inference results

    Returns:
         detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label = nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
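
```python
# Illustrative sketch, not part of the original module: why the "Batched NMS"
# step above offsets boxes by class * max_wh. Shifting each class into its own
# coordinate region means boxes of different classes never overlap, so a single
# class-agnostic NMS pass acts per class. The simple greedy NMS below is a
# pure-Python stand-in for torchvision.ops.nms(); all helper names are
# hypothetical.
def pair_iou(a, b):
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    return inter / ((a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter)


def greedy_nms(boxes, scores, iou_thres):
    # Visit boxes by descending score; keep a box unless it overlaps a kept box by more than iou_thres
    order = sorted(range(len(boxes)), key=lambda k: scores[k], reverse=True)
    keep = []
    for k in order:
        if all(pair_iou(boxes[k], boxes[j]) <= iou_thres for j in keep):
            keep.append(k)
    return keep


def class_offset_nms(dets, iou_thres=0.45, max_wh=4096):
    # dets: list of (x1, y1, x2, y2, conf, cls); max_wh=0 gives class-agnostic behaviour
    shifted = [tuple(v + d[5] * max_wh for v in d[:4]) for d in dets]
    return greedy_nms(shifted, [d[4] for d in dets], iou_thres)


dets = [
    (10, 10, 50, 50, 0.9, 0),
    (12, 12, 52, 52, 0.8, 0),  # overlaps the first box, same class: suppressed
    (12, 12, 52, 52, 0.7, 1),  # same coordinates, different class: kept
]
per_class = class_offset_nms(dets)
agnostic = class_offset_nms(dets, max_wh=0)
print(per_class, agnostic)  # [0, 2] [0]
```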


def strip_optimizer(f='weights/best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    for key in 'optimizer', 'training_results', 'wandb_id':
        x[key] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    print('Optimizer stripped from %s,%s %.1fMB' % (f, (' saved as %s,' % s) if s else '', mb))


def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
    # Print mutation results to evolve.txt (for use with train.py --evolve)
    a = '%10s' * len(hyp) % tuple(hyp.keys())  # hyperparam keys
    b = '%10.3g' * len(hyp) % tuple(hyp.values())  # hyperparam values
    c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
    print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))

    if bucket:
        url = 'gs://%s/evolve.txt' % bucket
        if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
            os.system('gsutil cp %s .' % url)  # download evolve.txt if larger than local

    with open('evolve.txt', 'a') as f:  # append result
        f.write(c + b + '\n')
    x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # load unique rows
    x = x[np.argsort(-fitness(x))]  # sort
    np.savetxt('evolve.txt', x, '%10.3g')  # save sort by fitness

    # Save yaml
    for i, k in enumerate(hyp.keys()):
        hyp[k] = float(x[0, i + 7])
    with open(yaml_file, 'w') as f:
        results = tuple(x[0, :7])
        c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
        f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
        yaml.dump(hyp, f, sort_keys=False)

    if bucket:
        os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket))  # upload


def apply_classifier(x, model, img, im0):
    # applies a second stage classifier to yolo outputs
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                # cv2.imwrite('test%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x224x224
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255.0  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x


def increment_path(path, exist_ok=True, sep=''):
    # Increment path, e.g. runs/exp --> runs/exp{sep}2, runs/exp{sep}3 etc. (numbering starts at 2)
    path = Path(path)  # os-agnostic
    if (path.exists() and exist_ok) or (not path.exists()):
        return str(path)
    else:
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
        matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]  # indices
        n = max(i) + 1 if i else 2  # increment number
        return f"{path}{sep}{n}"  # update path