# YOLOv5 general utils
import glob
import logging
import math
import os
import platform
import random
import re
import subprocess
import time
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import torch
import torchvision
import yaml

# Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads


def set_logging(rank=-1):
    logging.basicConfig(
        format="%(message)s",
        level=logging.INFO if rank in [-1, 0] else logging.WARN)


def get_latest_run(search_dir='.'):
    # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''


def isdocker():
    # Is environment a Docker container
    return Path('/workspace').exists()  # or Path('/.dockerenv').exists()


def emojis(str=''):
    # Return platform-dependent emoji-safe version of string
    return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str


def check_online():
    # Check internet connectivity
    import socket
    try:
        socket.create_connection(("1.1.1.1", 443), 5)  # check host accessibility
        return True
    except OSError:
        return False


def check_git_status():
    # Recommend 'git pull' if code is out of date
    print(colorstr('github: '), end='')
    try:
        assert Path('.git').exists(), 'skipping check (not a git repository)'
        assert not isdocker(), 'skipping check (Docker image)'
        assert check_online(), 'skipping check (offline)'

        cmd = 'git fetch && git config --get remote.origin.url'
        url = subprocess.check_output(cmd, shell=True).decode().strip().rstrip('.git')  # github repo url
        branch = subprocess.check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
        n = int(subprocess.check_output(f'git rev-list {branch}..origin/master --count', shell=True))  # commits behind
        if n > 0:
            s = f"⚠️ WARNING: code is out of date by {n} commit{'s' * (n > 1)}. " \
                f"Use 'git pull' to update or 'git clone {url}' to download latest."
        else:
            s = f'up to date with {url} ✅'
        print(emojis(s))  # emoji-safe
    except Exception as e:
        print(e)


def check_requirements(requirements='requirements.txt', exclude=()):
    # Check installed dependencies meet requirements (pass *.txt file or list of packages)
    import pkg_resources as pkg
    prefix = colorstr('red', 'bold', 'requirements:')
    if isinstance(requirements, (str, Path)):  # requirements.txt file
        file = Path(requirements)
        if not file.exists():
            print(f"{prefix} {file.resolve()} not found, check failed.")
            return
        requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]
    else:  # list or tuple of packages
        requirements = [x for x in requirements if x not in exclude]

    n = 0  # number of package updates
    for r in requirements:
        try:
            pkg.require(r)
        except Exception as e:  # DistributionNotFound or VersionConflict if requirements not met
            n += 1
            print(f"{prefix} {e.req} not found and is required by YOLOv5, attempting auto-update...")
            print(subprocess.check_output(f"pip install {e.req}", shell=True).decode())

    if n:  # if packages updated
        source = file.resolve() if 'file' in locals() else requirements
        s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
        print(emojis(s))  # emoji-safe


def check_img_size(img_size, s=32):
    # Verify img_size is a multiple of stride s
    new_size = make_divisible(img_size, int(s))  # ceil gs-multiple
    if new_size != img_size:
        print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
    return new_size
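

# Hedged usage sketch (illustrative, not part of the original module): check_img_size
# rounds an arbitrary --img-size up to the nearest multiple of the model stride.
def _example_check_img_size():
    assert check_img_size(640) == 640        # already a multiple of 32, unchanged
    assert check_img_size(641) == 672        # rounded up to the next multiple; warning printed
    assert check_img_size(100, s=64) == 128  # works for any stride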


def check_imshow():
    # Check if environment supports image displays
    try:
        assert not isdocker(), 'cv2.imshow() is disabled in Docker environments'
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
        return False


def check_file(file):
    # Search for file if not found
    if Path(file).is_file() or file == '':
        return file
    else:
        files = glob.glob('./**/' + file, recursive=True)  # find file
        assert len(files), f'File Not Found: {file}'  # assert file was found
        assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique
        return files[0]  # return file


def check_dataset(dict):
    # Download dataset if not found locally
    val, s = dict.get('val'), dict.get('download')
    if val and len(val):
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])
            if s and len(s):  # download script
                print('Downloading %s ...' % s)
                if s.startswith('http') and s.endswith('.zip'):  # URL
                    f = Path(s).name  # filename
                    torch.hub.download_url_to_file(s, f)
                    r = os.system('unzip -q %s -d ../ && rm %s' % (f, f))  # unzip
                else:  # bash script
                    r = os.system(s)
                print('Dataset autodownload %s\n' % ('success' if r == 0 else 'failure'))  # analyze return value
            else:
                raise Exception('Dataset not found.')


def make_divisible(x, divisor):
    # Returns x evenly divisible by divisor
    return math.ceil(x / divisor) * divisor


def clean_str(s):
    # Cleans a string by replacing special characters with underscore _
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)


def one_cycle(y1=0.0, y2=1.0, steps=100):
    # lambda function for sinusoidal ramp from y1 to y2
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
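

# Hedged usage sketch (assumption: 'optimizer' is any standard torch optimizer, not
# defined in this file): one_cycle returns a lambda suitable for LambdaLR, ramping the
# LR multiplier sinusoidally from y1 to y2 over 'steps' epochs.
def _example_one_cycle(optimizer, epochs=300):
    lf = one_cycle(1.0, 0.2, epochs)  # cosine ramp 1.0 -> 0.2
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    return scheduler  # lf(0) == 1.0, lf(epochs / 2) == 0.6, lf(epochs) == 0.2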


def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
    colors = {'black': '\033[30m',  # basic colors
              'red': '\033[31m',
              'green': '\033[32m',
              'yellow': '\033[33m',
              'blue': '\033[34m',
              'magenta': '\033[35m',
              'cyan': '\033[36m',
              'white': '\033[37m',
              'bright_black': '\033[90m',  # bright colors
              'bright_red': '\033[91m',
              'bright_green': '\033[92m',
              'bright_yellow': '\033[93m',
              'bright_blue': '\033[94m',
              'bright_magenta': '\033[95m',
              'bright_cyan': '\033[96m',
              'bright_white': '\033[97m',
              'end': '\033[0m',  # misc
              'bold': '\033[1m',
              'underline': '\033[4m'}
    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']


def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(int)  # labels = [class xywh]; plain int, np.int is removed in modern NumPy
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # number of targets per class
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)


def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class_weights and image contents
    class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
    return image_weights
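

# Hedged usage sketch (illustrative): the image weights produced above can drive
# weighted image resampling during training, mirroring the commented-out line in the
# function body. 'labels' is assumed to be a list of (n, 5) [class, xywh] arrays.
def _example_image_weight_sampling(labels, nc=80):
    cw = labels_to_class_weights(labels, nc).numpy()  # per-class inverse-frequency weights
    iw = labels_to_image_weights(labels, nc, cw)      # one weight per image
    return random.choices(range(len(labels)), weights=iw, k=len(labels))  # resampled indices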


def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
         64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
    return x


def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y
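

# Hedged usage sketch (illustrative): xyxy2xywh and xywh2xyxy are exact inverses,
# so a round trip returns the original boxes.
def _example_box_round_trip():
    boxes = np.array([[10., 20., 50., 80.]])  # x1, y1, x2, y2
    xywh = xyxy2xywh(boxes)                   # [[30., 50., 40., 60.]] center-x, center-y, w, h
    assert np.allclose(xywh2xyxy(xywh), boxes)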


def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # top left x
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh  # top left y
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw  # bottom right x
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh  # bottom right y
    return y


def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    # Convert normalized segments into pixel segments, shape (n,2)
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = w * x[:, 0] + padw  # top left x
    y[:, 1] = h * x[:, 1] + padh  # top left y
    return y


def segment2box(segment, width=640, height=640):
    # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4))  # xyxy


def segments2boxes(segments):
    # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
    return xyxy2xywh(np.array(boxes))  # cls, xywh


def resample_segments(segments, n=1000):
    # Up-sample an (n,2) segment
    for i, s in enumerate(segments):
        x = np.linspace(0, len(s) - 1, n)
        xp = np.arange(len(s))
        segments[i] = np.concatenate([np.interp(x, xp, s[:, j]) for j in range(2)]).reshape(2, -1).T  # segment xy
    return segments


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    # Rescale coords (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # x padding
    coords[:, [1, 3]] -= pad[1]  # y padding
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords


def clip_coords(boxes, img_shape):
    # Clip xyxy bounding boxes to image shape (height, width)
    boxes[:, 0].clamp_(0, img_shape[1])  # x1
    boxes[:, 1].clamp_(0, img_shape[0])  # y1
    boxes[:, 2].clamp_(0, img_shape[1])  # x2
    boxes[:, 3].clamp_(0, img_shape[0])  # y2


def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
    box2 = box2.T

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if GIoU or DIoU or CIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + eps  # convex area
            return iou - (c_area - union) / c_area  # GIoU
    else:
        return iou  # IoU
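

# Hedged usage sketch (illustrative): bbox_iou compares one box against n boxes,
# returning plain IoU by default or the GIoU/DIoU/CIoU variants on request.
def _example_bbox_iou():
    box1 = torch.tensor([0., 0., 10., 10.])   # single xyxy box
    box2 = torch.tensor([[0., 0., 10., 10.],  # identical box   -> IoU ~ 1.0
                         [5., 5., 15., 15.]])  # half-offset box -> IoU ~ 0.143
    iou = bbox_iou(box1, box2)                 # shape (2,)
    ciou = bbox_iou(box1, box2, CIoU=True)     # additionally penalizes center distance and aspect ratio
    return iou, ciou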


def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """

    def box_area(box):
        # box = 4xn
        return (box[2] - box[0]) * (box[3] - box[1])

    area1 = box_area(box1.T)
    area2 = box_area(box2.T)

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)
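

# Hedged usage sketch (illustrative): box_iou is the vectorized N-to-M counterpart,
# producing the full pairwise IoU matrix used by merge-NMS below.
def _example_box_iou():
    a = torch.tensor([[0., 0., 10., 10.], [20., 20., 30., 30.]])  # N = 2
    b = torch.tensor([[0., 0., 10., 10.]])                        # M = 1
    m = box_iou(a, b)                                             # shape (2, 1): [[1.0], [0.0]]
    return m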


def wh_iou(wh1, wh2):
    # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
    wh1 = wh1[:, None]  # [N,1,2]
    wh2 = wh2[None]  # [1,M,2]
    inter = torch.min(wh1, wh2).prod(2)  # [N,M]
    return inter / (wh1.prod(2) + wh2.prod(2) - inter)  # iou = inter / (area1 + area2 - inter)


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=()):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
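

# Hedged usage sketch (assumption: 'model' is a YOLOv5-style detector whose raw
# inference output has shape (batch, n_anchors, nc + 5)): a typical call filters the
# raw grid predictions down to one (n, 6) [xyxy, conf, cls] tensor per image.
def _example_nms(model, img):
    pred = model(img)[0]  # raw predictions (first element of the model's output tuple)
    det = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
    return det[0]  # detections for the first image in the batch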


def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")


def apply_classifier(x, model, img, im0):
    # Applies a second-stage classifier to YOLO outputs
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_coords(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for j, a in enumerate(d):  # per item
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR
                # cv2.imwrite('test%i.jpg' % j, cutout)

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x224x224
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255.0  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x


def increment_path(path, exist_ok=True, sep=''):
    # Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc.
    path = Path(path)  # os-agnostic
    if (path.exists() and exist_ok) or (not path.exists()):
        return str(path)
    else:
        dirs = glob.glob(f"{path}{sep}*")  # similar paths
        matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]  # indices
        n = max(i) + 1 if i else 2  # increment number
        return f"{path}{sep}{n}"  # update path
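

# Hedged usage sketch (illustrative): increment_path keeps successive runs from
# overwriting each other by appending an increasing suffix to an existing path.
def _example_increment_path():
    save_dir = Path(increment_path('runs/exp', exist_ok=False))  # runs/exp, then runs/exp2, runs/exp3, ...
    save_dir.mkdir(parents=True, exist_ok=True)                  # create it so the next call increments
    return save_dir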


def numpy_detect(x, nc=None):
    # anchor_grid holds the anchor box priors (w, h pairs) for the three detection scales
    anchor_grid = [10.0, 13.0, 16.0, 30.0, 33.0, 23.0, 30.0, 61.0, 62.0, 45.0, 59.0, 119.0,
                   116.0, 90.0, 156.0, 198.0, 373.0, 326.0]
    anchor_grid = np.array(anchor_grid).reshape(3, 1, -1, 1, 1, 2)
    stride = np.array([8, 16, 32])
    grid = [make_grid(80, 80), make_grid(40, 40), make_grid(20, 20)]
    z = []
    for i in range(3):
        y = numpy_sigmoid(x[i])
        y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + grid[i]) * stride[i]  # xy
        y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * anchor_grid[i]  # wh
        z.append(y.reshape(1, -1, nc + 5))
    res = np.concatenate(z, 1)
    return res


def numpy_sigmoid(x):
    return 1 / (1 + np.exp(-x))


def make_grid(nx=20, ny=20):
    # Build a (1, 1, ny, nx, 2) grid of cell indices for decoding xy offsets
    xv, yv = np.meshgrid(np.arange(nx), np.arange(ny))
    res = np.stack((xv, yv), 2).reshape(1, 1, ny, nx, 2).astype(np.float32)  # stack is (ny, nx, 2)
    return res
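

# Hedged usage sketch (illustrative, assuming a 640x640 input and the 9-class model
# used by post_process below): numpy_detect decodes the three raw feature maps into a
# single (1, n_anchors, nc + 5) array, mirroring the YOLOv5 Detect layer in pure NumPy.
def _example_numpy_detect(nc=9):
    x = [np.random.randn(1, 3, s, s, nc + 5).astype(np.float32) for s in (80, 40, 20)]
    y = numpy_detect(x, nc)  # shape (1, 3 * (80*80 + 40*40 + 20*20), nc + 5) = (1, 25200, 14)
    return y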


def img_pad(img, size, pad_value=(114, 114, 114)):
    # Letterbox: resize to fit inside 'size' while keeping aspect ratio, then pad with gray
    H, W, _ = img.shape
    r = max(H / size[0], W / size[1])  # resize ratio (original / network)
    img_r = cv2.resize(img, (int(W / r), int(H / r)))
    tb = size[0] - img_r.shape[0]  # total vertical padding
    lr = size[1] - img_r.shape[1]  # total horizontal padding
    top = int(tb / 2)
    bottom = tb - top
    left = int(lr / 2)
    right = lr - left
    pad_image = cv2.copyMakeBorder(img_r, top, bottom, left, right, cv2.BORDER_CONSTANT, value=pad_value)
    return pad_image, (top, left, r)


def scale_back(boxes, padInfos):
    # Map xyxy boxes from padded/resized coordinates back to the original image
    top, left, r = padInfos[0:3]
    boxes[:, 0] = (boxes[:, 0] - left) * r
    boxes[:, 2] = (boxes[:, 2] - left) * r
    boxes[:, 1] = (boxes[:, 1] - top) * r
    boxes[:, 3] = (boxes[:, 3] - top) * r
    return boxes


def get_ms(t1, t0):
    # Elapsed time in milliseconds between two time.time() stamps
    return (t1 - t0) * 1000.0
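

# Hedged usage sketch (illustrative): img_pad letterboxes an image to the network size
# and scale_back undoes the transform, so boxes map back onto the original image.
def _example_letterbox_round_trip():
    img = np.zeros((720, 1280, 3), dtype=np.uint8)  # original HxW
    padded, (top, left, r) = img_pad(img, size=(640, 640, 3))
    assert padded.shape[:2] == (640, 640)
    box = torch.tensor([[left, top, left + 100., top + 50.]])  # a box in padded coordinates
    orig = scale_back(box.clone(), (top, left, r))             # the same box in original coordinates
    return orig  # [[0., 0., 200., 100.]] since r = 1280 / 640 = 2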


def pre_process(im0, device):
    # Letterbox to 640x640, BGR->RGB, HWC->CHW, normalize to [0, 1] and add a batch dim
    img, padInfos = img_pad(im0, size=(640, 640, 3))
    img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
    img = np.ascontiguousarray(img, dtype=np.float32)
    img = torch.from_numpy(img).to(device)
    img /= 255.0  # 0 - 255 to 0.0 - 1.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)
    return img, padInfos


def post_process(pred, padInfos, device, conf_thres=0.4, iou_thres=0.45, nc=9):
    # Decode raw feature maps, run NMS, and rescale boxes to the original image
    pred = [x.data.cpu().numpy() for x in pred]
    pred = numpy_detect(pred, nc)
    pred = torch.tensor(pred).to(device)
    pred = non_max_suppression(pred, conf_thres, iou_thres)
    det = pred[0]  # batch size is 1, so only one image
    if len(det):
        det[:, :4] = scale_back(det[:, :4], padInfos)  # back to original-image coordinates
        det = det.cpu().numpy()
    else:
        det = []
    return det
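

# Hedged end-to-end sketch (assumption: 'model' is the exported 9-class detector these
# helpers were written for, returning the three raw feature maps): pre_process letterboxes
# and normalizes the frame; post_process decodes, NMS-filters and rescales the detections.
def _example_detect(model, frame, device='cpu'):
    img, padInfos = pre_process(frame, device)  # (1, 3, 640, 640) tensor + letterbox info
    pred = model(img)                           # list of three raw feature maps
    det = post_process(pred, padInfos, device, conf_thres=0.4, iou_thres=0.45, nc=9)
    return det                                  # (n, 6) [xyxy, conf, cls] ndarray, or []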


def get_return_data(img, boxes, modelType='code', plate_dilate=(0.1, 0.1)):
    # names: ['greeCode', 'yellowCode', 'redCode', 'hsImage', 'NameImage', 'word']
    # "greeCode", "yellowCode", "redCode", "hsImage", "hs48Image", "NameImage", "phoneNumberImage", "word", "phone"
    #      0    ,      1      ,     2    ,     3    ,      4     ,      5     ,         6         ,    7  ,    8
    ## type: 0 = license plate, 1 = health code or travel card, 2 = other
    results = []
    h, w, c = img.shape
    fx = float(900.0 / h)
    if modelType != 'plate':
        etc = {'type': 0, 'color': 'green', 'nameImage': '', 'phoneNumberImage': '', 'cityImage': '',
               'hsImage': '', 'plateImage': ''}
        cols, cols_scores = [], []
        score_hsImage = 0.0
        for box in boxes:
            x0, y0, x1, y1 = [int(x) for x in box[0:4]]
            x0 = max(0, x0); y0 = max(0, y0); x1 = min(x1, w); y1 = min(y1, h)
            typee = int(box[5]); score = str(box[4])
            if typee == 7 or typee == 8:
                etc['type'] = 1  # city or phone number on the travel card detected -> travel card; the Sukang health code is also labelled 1
            if typee == 9:
                etc['type'] = 2  # 'Sukang code' text image detected -> Sukang health code
            if typee in [3, 4, 5, 6, 7, 8]:
                image_corp = cv2.resize(img[y0:y1, x0:x1], None, fx=fx, fy=fx)
                if typee == 3 or typee == 4:
                    if score_hsImage < float(score):
                        etc['hsImage'] = [image_corp, score]; score_hsImage = float(score)
                if typee == 5: etc['nameImage'] = [image_corp, score]
                if typee == 6: etc['phoneNumberImage'] = [image_corp, score]
                if typee == 7: etc['cityImage'] = [image_corp, score]
                if typee == 8: etc['phoneNumberImage'] = [image_corp, score]
            if typee in [0, 1, 2]:
                if typee == 0: cols.append('green'); cols_scores.append(float(score))
                if typee == 1: cols.append('yellow'); cols_scores.append(float(score))
                if typee == 2: cols.append('red'); cols_scores.append(float(score))
        if len(cols) > 0:
            maxid = cols_scores.index(max(cols_scores))
            etc['color'] = cols[maxid]
    else:
        etc = {'type': 0, 'plateImage': '', 'color': 'green'}
        score_list = [float(b[4]) for b in boxes]
        if len(score_list) > 0:
            maxid = score_list.index(max(score_list))
            box = boxes[maxid]
            x0, y0, x1, y1 = [int(x) for x in box[0:4]]
            per = plate_dilate  # dilate the plate box by (w%, h%) before cropping
            x_delta = int((x1 - x0) * per[0] / 2); y_delta = int((y1 - y0) * per[1] / 2)
            x0 = max(x0 - x_delta, 0); x1 = min(x1 + x_delta, w)
            y0 = max(y0 - y_delta, 0); y1 = min(y1 + y_delta, h)
            image_corp = cv2.resize(img[y0:y1, x0:x1], None, fx=fx, fy=fx)
            etc['plateImage'] = [image_corp, str(max(score_list))]
            etc['type'] = 3
    results.append(etc)  # note: 'results' is collected but only the single 'etc' dict is returned
    return etc