Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

general.py 24KB

il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
YOLOv5 Segmentation Dataloader Updates (#2188) * Update C3 module * Update C3 module * Update C3 module * Update C3 module * update * update * update * update * update * update * update * update * update * updates * updates * updates * updates * updates * updates * updates * updates * updates * updates * update * update * update * update * updates * updates * updates * updates * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update datasets * update * update * update * update attempt_downlaod() * merge * merge * update * update * update * update * update * update * update * update * update * update * parameterize eps * comments * gs-multiple * update * max_nms implemented * Create one_cycle() function * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * GitHub API rate limit fix * update * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * astuple * epochs * update * update * ComputeLoss() * update * update * update * update * update * update * update * update * update * update * update * merge * merge * merge * merge * update * update * update * update * commit=tag == tags[-1] * Update cudnn.benchmark * update * update * update * updates * updates * updates * updates * updates * updates * updates * update * update * update * update * update * mosaic9 * update * update * update * update * update * update * institute cache versioning * only display on existing cache * reverse cache exists booleans
il y a 3 ans
YOLOv5 Segmentation Dataloader Updates (#2188) * Update C3 module * Update C3 module * Update C3 module * Update C3 module * update * update * update * update * update * update * update * update * update * updates * updates * updates * updates * updates * updates * updates * updates * updates * updates * update * update * update * update * updates * updates * updates * updates * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update datasets * update * update * update * update attempt_downlaod() * merge * merge * update * update * update * update * update * update * update * update * update * update * parameterize eps * comments * gs-multiple * update * max_nms implemented * Create one_cycle() function * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * GitHub API rate limit fix * update * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * ComputeLoss * astuple * epochs * update * update * ComputeLoss() * update * update * update * update * update * update * update * update * update * update * update * merge * merge * merge * merge * update * update * update * update * commit=tag == tags[-1] * Update cudnn.benchmark * update * update * update * updates * updates * updates * updates * updates * updates * updates * update * update * update * update * update * mosaic9 * update * update * update * update * update * update * institute cache versioning * only display on existing cache * reverse cache exists booleans
il y a 3 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
il y a 4 ans
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598
  1. # YOLOv5 general utils
  2. import glob
  3. import logging
  4. import math
  5. import os
  6. import platform
  7. import random
  8. import re
  9. import subprocess
  10. import time
  11. from pathlib import Path
  12. import cv2
  13. import numpy as np
  14. import torch
  15. import torchvision
  16. import yaml
  17. from utils.google_utils import gsutil_getsize
  18. from utils.metrics import fitness
  19. from utils.torch_utils import init_torch_seeds
  20. # Settings
  21. torch.set_printoptions(linewidth=320, precision=5, profile='long')
  22. np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
  23. cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
  24. os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
  25. def set_logging(rank=-1):
  26. logging.basicConfig(
  27. format="%(message)s",
  28. level=logging.INFO if rank in [-1, 0] else logging.WARN)
  29. def init_seeds(seed=0):
  30. # Initialize random number generator (RNG) seeds
  31. random.seed(seed)
  32. np.random.seed(seed)
  33. init_torch_seeds(seed)
  34. def get_latest_run(search_dir='.'):
  35. # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
  36. last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
  37. return max(last_list, key=os.path.getctime) if last_list else ''
  38. def isdocker():
  39. # Is environment a Docker container
  40. return Path('/workspace').exists() # or Path('/.dockerenv').exists()
  41. def emojis(str=''):
  42. # Return platform-dependent emoji-safe version of string
  43. return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
  44. def check_online():
  45. # Check internet connectivity
  46. import socket
  47. try:
  48. socket.create_connection(("1.1.1.1", 443), 5) # check host accesability
  49. return True
  50. except OSError:
  51. return False
  52. def check_git_status():
  53. # Recommend 'git pull' if code is out of date
  54. print(colorstr('github: '), end='')
  55. try:
  56. assert Path('.git').exists(), 'skipping check (not a git repository)'
  57. assert not isdocker(), 'skipping check (Docker image)'
  58. assert check_online(), 'skipping check (offline)'
  59. cmd = 'git fetch && git config --get remote.origin.url'
  60. url = subprocess.check_output(cmd, shell=True).decode().strip().rstrip('.git') # github repo url
  61. branch = subprocess.check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
  62. n = int(subprocess.check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind
  63. if n > 0:
  64. s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
  65. f"Use 'git pull' to update or 'git clone {url}' to download latest."
  66. else:
  67. s = f'up to date with {url} ✅'
  68. print(emojis(s)) # emoji-safe
  69. except Exception as e:
  70. print(e)
  71. def check_requirements(file='requirements.txt', exclude=()):
  72. # Check installed dependencies meet requirements
  73. import pkg_resources as pkg
  74. prefix = colorstr('red', 'bold', 'requirements:')
  75. file = Path(file)
  76. if not file.exists():
  77. print(f"{prefix} {file.resolve()} not found, check failed.")
  78. return
  79. n = 0 # number of packages updates
  80. requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
  81. for r in requirements:
  82. try:
  83. pkg.require(r)
  84. except Exception as e: # DistributionNotFound or VersionConflict if requirements not met
  85. n += 1
  86. print(f"{prefix} {e.req} not found and is required by YOLOv5, attempting auto-update...")
  87. print(subprocess.check_output(f"pip install '{e.req}'", shell=True).decode())
  88. if n: # if packages updated
  89. s = f"{prefix} {n} package{'s' * (n > 1)} updated per {file.resolve()}\n" \
  90. f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
  91. print(emojis(s)) # emoji-safe
  92. def check_img_size(img_size, s=32):
  93. # Verify img_size is a multiple of stride s
  94. new_size = make_divisible(img_size, int(s)) # ceil gs-multiple
  95. if new_size != img_size:
  96. print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
  97. return new_size
  98. def check_imshow():
  99. # Check if environment supports image displays
  100. try:
  101. assert not isdocker(), 'cv2.imshow() is disabled in Docker environments'
  102. cv2.imshow('test', np.zeros((1, 1, 3)))
  103. cv2.waitKey(1)
  104. cv2.destroyAllWindows()
  105. cv2.waitKey(1)
  106. return True
  107. except Exception as e:
  108. print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
  109. return False
  110. def check_file(file):
  111. # Search for file if not found
  112. if os.path.isfile(file) or file == '':
  113. return file
  114. else:
  115. files = glob.glob('./**/' + file, recursive=True) # find file
  116. assert len(files), 'File Not Found: %s' % file # assert file was found
  117. assert len(files) == 1, "Multiple files match '%s', specify exact path: %s" % (file, files) # assert unique
  118. return files[0] # return file
  119. def check_dataset(dict):
  120. # Download dataset if not found locally
  121. val, s = dict.get('val'), dict.get('download')
  122. if val and len(val):
  123. val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
  124. if not all(x.exists() for x in val):
  125. print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
  126. if s and len(s): # download script
  127. print('Downloading %s ...' % s)
  128. if s.startswith('http') and s.endswith('.zip'): # URL
  129. f = Path(s).name # filename
  130. torch.hub.download_url_to_file(s, f)
  131. r = os.system('unzip -q %s -d ../ && rm %s' % (f, f)) # unzip
  132. else: # bash script
  133. r = os.system(s)
  134. print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure')) # analyze return value
  135. else:
  136. raise Exception('Dataset not found.')
  137. def make_divisible(x, divisor):
  138. # Returns x evenly divisible by divisor
  139. return math.ceil(x / divisor) * divisor
  140. def clean_str(s):
  141. # Cleans a string by replacing special characters with underscore _
  142. return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
  143. def one_cycle(y1=0.0, y2=1.0, steps=100):
  144. # lambda function for sinusoidal ramp from y1 to y2
  145. return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
  146. def colorstr(*input):
  147. # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
  148. *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
  149. colors = {'black': '\033[30m', # basic colors
  150. 'red': '\033[31m',
  151. 'green': '\033[32m',
  152. 'yellow': '\033[33m',
  153. 'blue': '\033[34m',
  154. 'magenta': '\033[35m',
  155. 'cyan': '\033[36m',
  156. 'white': '\033[37m',
  157. 'bright_black': '\033[90m', # bright colors
  158. 'bright_red': '\033[91m',
  159. 'bright_green': '\033[92m',
  160. 'bright_yellow': '\033[93m',
  161. 'bright_blue': '\033[94m',
  162. 'bright_magenta': '\033[95m',
  163. 'bright_cyan': '\033[96m',
  164. 'bright_white': '\033[97m',
  165. 'end': '\033[0m', # misc
  166. 'bold': '\033[1m',
  167. 'underline': '\033[4m'}
  168. return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
  169. def labels_to_class_weights(labels, nc=80):
  170. # Get class weights (inverse frequency) from training labels
  171. if labels[0] is None: # no labels loaded
  172. return torch.Tensor()
  173. labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
  174. classes = labels[:, 0].astype(np.int) # labels = [class xywh]
  175. weights = np.bincount(classes, minlength=nc) # occurrences per class
  176. # Prepend gridpoint count (for uCE training)
  177. # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
  178. # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
  179. weights[weights == 0] = 1 # replace empty bins with 1
  180. weights = 1 / weights # number of targets per class
  181. weights /= weights.sum() # normalize
  182. return torch.from_numpy(weights)
  183. def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
  184. # Produces image weights based on class_weights and image contents
  185. class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
  186. image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
  187. # index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
  188. return image_weights
  189. def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
  190. # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
  191. # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
  192. # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
  193. # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
  194. # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
  195. x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
  196. 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
  197. 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  198. return x
  199. def xyxy2xywh(x):
  200. # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
  201. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  202. y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
  203. y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
  204. y[:, 2] = x[:, 2] - x[:, 0] # width
  205. y[:, 3] = x[:, 3] - x[:, 1] # height
  206. return y
  207. def xywh2xyxy(x):
  208. # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  209. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  210. y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
  211. y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
  212. y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
  213. y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
  214. return y
  215. def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
  216. # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
  217. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  218. y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x
  219. y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y
  220. y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x
  221. y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y
  222. return y
  223. def xyn2xy(x, w=640, h=640, padw=0, padh=0):
  224. # Convert normalized segments into pixel segments, shape (n,2)
  225. y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
  226. y[:, 0] = w * x[:, 0] + padw # top left x
  227. y[:, 1] = h * x[:, 1] + padh # top left y
  228. return y
  229. def segment2box(segment, width=640, height=640):
  230. # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
  231. x, y = segment.T # segment xy
  232. inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
  233. x, y, = x[inside], y[inside]
  234. return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # cls, xyxy
  235. def segments2boxes(segments):
  236. # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
  237. boxes = []
  238. for s in segments:
  239. x, y = s.T # segment xy
  240. boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
  241. return xyxy2xywh(np.array(boxes)) # cls, xywh
  242. def resample_segments(segments, n=1000):
  243. # Up-sample an (n,2) segment
  244. for i, s in enumerate(segments):
  245. x = np.linspace(0, len(s) - 1, n)
  246. xp = np.arange(len(s))
  247. segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
  248. return segments
  249. def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  250. # Rescale coords (xyxy) from img1_shape to img0_shape
  251. if ratio_pad is None: # calculate from img0_shape
  252. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  253. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  254. else:
  255. gain = ratio_pad[0][0]
  256. pad = ratio_pad[1]
  257. coords[:, [0, 2]] -= pad[0] # x padding
  258. coords[:, [1, 3]] -= pad[1] # y padding
  259. coords[:, :4] /= gain
  260. clip_coords(coords, img0_shape)
  261. return coords
  262. def clip_coords(boxes, img_shape):
  263. # Clip bounding xyxy bounding boxes to image shape (height, width)
  264. boxes[:, 0].clamp_(0, img_shape[1]) # x1
  265. boxes[:, 1].clamp_(0, img_shape[0]) # y1
  266. boxes[:, 2].clamp_(0, img_shape[1]) # x2
  267. boxes[:, 3].clamp_(0, img_shape[0]) # y2
  268. def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
  269. # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
  270. box2 = box2.T
  271. # Get the coordinates of bounding boxes
  272. if x1y1x2y2: # x1, y1, x2, y2 = box1
  273. b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
  274. b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
  275. else: # transform from xywh to xyxy
  276. b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
  277. b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
  278. b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
  279. b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
  280. # Intersection area
  281. inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
  282. (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
  283. # Union Area
  284. w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
  285. w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
  286. union = w1 * h1 + w2 * h2 - inter + eps
  287. iou = inter / union
  288. if GIoU or DIoU or CIoU:
  289. cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
  290. ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
  291. if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
  292. c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
  293. rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
  294. (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center distance squared
  295. if DIoU:
  296. return iou - rho2 / c2 # DIoU
  297. elif CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
  298. v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
  299. with torch.no_grad():
  300. alpha = v / (v - iou + (1 + eps))
  301. return iou - (rho2 / c2 + v * alpha) # CIoU
  302. else: # GIoU https://arxiv.org/pdf/1902.09630.pdf
  303. c_area = cw * ch + eps # convex area
  304. return iou - (c_area - union) / c_area # GIoU
  305. else:
  306. return iou # IoU
  307. def box_iou(box1, box2):
  308. # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
  309. """
  310. Return intersection-over-union (Jaccard index) of boxes.
  311. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
  312. Arguments:
  313. box1 (Tensor[N, 4])
  314. box2 (Tensor[M, 4])
  315. Returns:
  316. iou (Tensor[N, M]): the NxM matrix containing the pairwise
  317. IoU values for every element in boxes1 and boxes2
  318. """
  319. def box_area(box):
  320. # box = 4xn
  321. return (box[2] - box[0]) * (box[3] - box[1])
  322. area1 = box_area(box1.T)
  323. area2 = box_area(box2.T)
  324. # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
  325. inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
  326. return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
  327. def wh_iou(wh1, wh2):
  328. # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
  329. wh1 = wh1[:, None] # [N,1,2]
  330. wh2 = wh2[None] # [1,M,2]
  331. inter = torch.min(wh1, wh2).prod(2) # [N,M]
  332. return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
  333. def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
  334. labels=()):
  335. """Runs Non-Maximum Suppression (NMS) on inference results
  336. Returns:
  337. list of detections, on (n,6) tensor per image [xyxy, conf, cls]
  338. """
  339. nc = prediction.shape[2] - 5 # number of classes
  340. xc = prediction[..., 4] > conf_thres # candidates
  341. # Settings
  342. min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
  343. max_det = 300 # maximum number of detections per image
  344. max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
  345. time_limit = 10.0 # seconds to quit after
  346. redundant = True # require redundant detections
  347. multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
  348. merge = False # use merge-NMS
  349. t = time.time()
  350. output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
  351. for xi, x in enumerate(prediction): # image index, image inference
  352. # Apply constraints
  353. # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
  354. x = x[xc[xi]] # confidence
  355. # Cat apriori labels if autolabelling
  356. if labels and len(labels[xi]):
  357. l = labels[xi]
  358. v = torch.zeros((len(l), nc + 5), device=x.device)
  359. v[:, :4] = l[:, 1:5] # box
  360. v[:, 4] = 1.0 # conf
  361. v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls
  362. x = torch.cat((x, v), 0)
  363. # If none remain process next image
  364. if not x.shape[0]:
  365. continue
  366. # Compute conf
  367. x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
  368. # Box (center x, center y, width, height) to (x1, y1, x2, y2)
  369. box = xywh2xyxy(x[:, :4])
  370. # Detections matrix nx6 (xyxy, conf, cls)
  371. if multi_label:
  372. i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
  373. x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
  374. else: # best class only
  375. conf, j = x[:, 5:].max(1, keepdim=True)
  376. x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
  377. # Filter by class
  378. if classes is not None:
  379. x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
  380. # Apply finite constraint
  381. # if not torch.isfinite(x).all():
  382. # x = x[torch.isfinite(x).all(1)]
  383. # Check shape
  384. n = x.shape[0] # number of boxes
  385. if not n: # no boxes
  386. continue
  387. elif n > max_nms: # excess boxes
  388. x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
  389. # Batched NMS
  390. c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
  391. boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
  392. i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
  393. if i.shape[0] > max_det: # limit detections
  394. i = i[:max_det]
  395. if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
  396. # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
  397. iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
  398. weights = iou * scores[None] # box weights
  399. x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
  400. if redundant:
  401. i = i[iou.sum(1) > 1] # require redundancy
  402. output[xi] = x[i]
  403. if (time.time() - t) > time_limit:
  404. print(f'WARNING: NMS time limit {time_limit}s exceeded')
  405. break # time limit exceeded
  406. return output
  407. def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
  408. # Strip optimizer from 'f' to finalize training, optionally save as 's'
  409. x = torch.load(f, map_location=torch.device('cpu'))
  410. if x.get('ema'):
  411. x['model'] = x['ema'] # replace model with ema
  412. for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates': # keys
  413. x[k] = None
  414. x['epoch'] = -1
  415. x['model'].half() # to FP16
  416. for p in x['model'].parameters():
  417. p.requires_grad = False
  418. torch.save(x, s or f)
  419. mb = os.path.getsize(s or f) / 1E6 # filesize
  420. print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
  421. def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
  422. # Print mutation results to evolve.txt (for use with train.py --evolve)
  423. a = '%10s' * len(hyp) % tuple(hyp.keys()) # hyperparam keys
  424. b = '%10.3g' * len(hyp) % tuple(hyp.values()) # hyperparam values
  425. c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
  426. print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))
  427. if bucket:
  428. url = 'gs://%s/evolve.txt' % bucket
  429. if gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):
  430. os.system('gsutil cp %s .' % url) # download evolve.txt if larger than local
  431. with open('evolve.txt', 'a') as f: # append result
  432. f.write(c + b + '\n')
  433. x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0) # load unique rows
  434. x = x[np.argsort(-fitness(x))] # sort
  435. np.savetxt('evolve.txt', x, '%10.3g') # save sort by fitness
  436. # Save yaml
  437. for i, k in enumerate(hyp.keys()):
  438. hyp[k] = float(x[0, i + 7])
  439. with open(yaml_file, 'w') as f:
  440. results = tuple(x[0, :7])
  441. c = '%10.4g' * len(results) % results # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)
  442. f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')
  443. yaml.dump(hyp, f, sort_keys=False)
  444. if bucket:
  445. os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket)) # upload
  446. def apply_classifier(x, model, img, im0):
  447. # applies a second stage classifier to yolo outputs
  448. im0 = [im0] if isinstance(im0, np.ndarray) else im0
  449. for i, d in enumerate(x): # per image
  450. if d is not None and len(d):
  451. d = d.clone()
  452. # Reshape and pad cutouts
  453. b = xyxy2xywh(d[:, :4]) # boxes
  454. b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
  455. b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
  456. d[:, :4] = xywh2xyxy(b).long()
  457. # Rescale boxes from img_size to im0 size
  458. scale_coords(img.shape[2:], d[:, :4], im0[i].shape)
  459. # Classes
  460. pred_cls1 = d[:, 5].long()
  461. ims = []
  462. for j, a in enumerate(d): # per item
  463. cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
  464. im = cv2.resize(cutout, (224, 224)) # BGR
  465. # cv2.imwrite('test%i.jpg' % j, cutout)
  466. im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
  467. im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
  468. im /= 255.0 # 0 - 255 to 0.0 - 1.0
  469. ims.append(im)
  470. pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
  471. x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
  472. return x
  473. def increment_path(path, exist_ok=True, sep=''):
  474. # Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc.
  475. path = Path(path) # os-agnostic
  476. if (path.exists() and exist_ok) or (not path.exists()):
  477. return str(path)
  478. else:
  479. dirs = glob.glob(f"{path}{sep}*") # similar paths
  480. matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
  481. i = [int(m.groups()[0]) for m in matches if m] # indices
  482. n = max(i) + 1 if i else 2 # increment number
  483. return f"{path}{sep}{n}" # update path