基于 YOLOv7 的路面病害检测代码
Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

273 lines
11KB

  1. import numpy as np
  2. import random
  3. import torch
  4. import torch.nn as nn
  5. from models.common import Conv, DWConv
  6. from utils.google_utils import attempt_download
  7. class CrossConv(nn.Module):
  8. # Cross Convolution Downsample
  9. def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
  10. # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
  11. super(CrossConv, self).__init__()
  12. c_ = int(c2 * e) # hidden channels
  13. self.cv1 = Conv(c1, c_, (1, k), (1, s))
  14. self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
  15. self.add = shortcut and c1 == c2
  16. def forward(self, x):
  17. return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
  18. class Sum(nn.Module):
  19. # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
  20. def __init__(self, n, weight=False): # n: number of inputs
  21. super(Sum, self).__init__()
  22. self.weight = weight # apply weights boolean
  23. self.iter = range(n - 1) # iter object
  24. if weight:
  25. self.w = nn.Parameter(-torch.arange(1., n) / 2, requires_grad=True) # layer weights
  26. def forward(self, x):
  27. y = x[0] # no weight
  28. if self.weight:
  29. w = torch.sigmoid(self.w) * 2
  30. for i in self.iter:
  31. y = y + x[i + 1] * w[i]
  32. else:
  33. for i in self.iter:
  34. y = y + x[i + 1]
  35. return y
  36. class MixConv2d(nn.Module):
  37. # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
  38. def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
  39. super(MixConv2d, self).__init__()
  40. groups = len(k)
  41. if equal_ch: # equal c_ per group
  42. i = torch.linspace(0, groups - 1E-6, c2).floor() # c2 indices
  43. c_ = [(i == g).sum() for g in range(groups)] # intermediate channels
  44. else: # equal weight.numel() per group
  45. b = [c2] + [0] * groups
  46. a = np.eye(groups + 1, groups, k=-1)
  47. a -= np.roll(a, 1, axis=1)
  48. a *= np.array(k) ** 2
  49. a[0] = 1
  50. c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
  51. self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
  52. self.bn = nn.BatchNorm2d(c2)
  53. self.act = nn.LeakyReLU(0.1, inplace=True)
  54. def forward(self, x):
  55. return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
  56. class Ensemble(nn.ModuleList):
  57. # Ensemble of models
  58. def __init__(self):
  59. super(Ensemble, self).__init__()
  60. def forward(self, x, augment=False):
  61. y = []
  62. for module in self:
  63. y.append(module(x, augment)[0])
  64. # y = torch.stack(y).max(0)[0] # max ensemble
  65. # y = torch.stack(y).mean(0) # mean ensemble
  66. y = torch.cat(y, 1) # nms ensemble
  67. return y, None # inference, train output
  68. class ORT_NMS(torch.autograd.Function):
  69. '''ONNX-Runtime NMS operation'''
  70. @staticmethod
  71. def forward(ctx,
  72. boxes,
  73. scores,
  74. max_output_boxes_per_class=torch.tensor([100]),
  75. iou_threshold=torch.tensor([0.45]),
  76. score_threshold=torch.tensor([0.25])):
  77. device = boxes.device
  78. batch = scores.shape[0]
  79. num_det = random.randint(0, 100)
  80. batches = torch.randint(0, batch, (num_det,)).sort()[0].to(device)
  81. idxs = torch.arange(100, 100 + num_det).to(device)
  82. zeros = torch.zeros((num_det,), dtype=torch.int64).to(device)
  83. selected_indices = torch.cat([batches[None], zeros[None], idxs[None]], 0).T.contiguous()
  84. selected_indices = selected_indices.to(torch.int64)
  85. return selected_indices
  86. @staticmethod
  87. def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
  88. return g.op("NonMaxSuppression", boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)
  89. class TRT_NMS(torch.autograd.Function):
  90. '''TensorRT NMS operation'''
  91. @staticmethod
  92. def forward(
  93. ctx,
  94. boxes,
  95. scores,
  96. background_class=-1,
  97. box_coding=1,
  98. iou_threshold=0.45,
  99. max_output_boxes=100,
  100. plugin_version="1",
  101. score_activation=0,
  102. score_threshold=0.25,
  103. ):
  104. batch_size, num_boxes, num_classes = scores.shape
  105. num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
  106. det_boxes = torch.randn(batch_size, max_output_boxes, 4)
  107. det_scores = torch.randn(batch_size, max_output_boxes)
  108. det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
  109. return num_det, det_boxes, det_scores, det_classes
  110. @staticmethod
  111. def symbolic(g,
  112. boxes,
  113. scores,
  114. background_class=-1,
  115. box_coding=1,
  116. iou_threshold=0.45,
  117. max_output_boxes=100,
  118. plugin_version="1",
  119. score_activation=0,
  120. score_threshold=0.25):
  121. out = g.op("TRT::EfficientNMS_TRT",
  122. boxes,
  123. scores,
  124. background_class_i=background_class,
  125. box_coding_i=box_coding,
  126. iou_threshold_f=iou_threshold,
  127. max_output_boxes_i=max_output_boxes,
  128. plugin_version_s=plugin_version,
  129. score_activation_i=score_activation,
  130. score_threshold_f=score_threshold,
  131. outputs=4)
  132. nums, boxes, scores, classes = out
  133. return nums, boxes, scores, classes
  134. class ONNX_ORT(nn.Module):
  135. '''onnx module with ONNX-Runtime NMS operation.'''
  136. def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, device=None, n_classes=80):
  137. super().__init__()
  138. self.device = device if device else torch.device("cpu")
  139. self.max_obj = torch.tensor([max_obj]).to(device)
  140. self.iou_threshold = torch.tensor([iou_thres]).to(device)
  141. self.score_threshold = torch.tensor([score_thres]).to(device)
  142. self.max_wh = max_wh # if max_wh != 0 : non-agnostic else : agnostic
  143. self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
  144. dtype=torch.float32,
  145. device=self.device)
  146. self.n_classes=n_classes
  147. def forward(self, x):
  148. boxes = x[:, :, :4]
  149. conf = x[:, :, 4:5]
  150. scores = x[:, :, 5:]
  151. if self.n_classes == 1:
  152. scores = conf # for models with one class, cls_loss is 0 and cls_conf is always 0.5,
  153. # so there is no need to multiplicate.
  154. else:
  155. scores *= conf # conf = obj_conf * cls_conf
  156. boxes @= self.convert_matrix
  157. max_score, category_id = scores.max(2, keepdim=True)
  158. dis = category_id.float() * self.max_wh
  159. nmsbox = boxes + dis
  160. max_score_tp = max_score.transpose(1, 2).contiguous()
  161. selected_indices = ORT_NMS.apply(nmsbox, max_score_tp, self.max_obj, self.iou_threshold, self.score_threshold)
  162. X, Y = selected_indices[:, 0], selected_indices[:, 2]
  163. selected_boxes = boxes[X, Y, :]
  164. selected_categories = category_id[X, Y, :].float()
  165. selected_scores = max_score[X, Y, :]
  166. X = X.unsqueeze(1).float()
  167. return torch.cat([X, selected_boxes, selected_categories, selected_scores], 1)
  168. class ONNX_TRT(nn.Module):
  169. '''onnx module with TensorRT NMS operation.'''
  170. def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None, n_classes=80):
  171. super().__init__()
  172. assert max_wh is None
  173. self.device = device if device else torch.device('cpu')
  174. self.background_class = -1,
  175. self.box_coding = 1,
  176. self.iou_threshold = iou_thres
  177. self.max_obj = max_obj
  178. self.plugin_version = '1'
  179. self.score_activation = 0
  180. self.score_threshold = score_thres
  181. self.n_classes=n_classes
  182. def forward(self, x):
  183. boxes = x[:, :, :4]
  184. conf = x[:, :, 4:5]
  185. scores = x[:, :, 5:]
  186. if self.n_classes == 1:
  187. scores = conf # for models with one class, cls_loss is 0 and cls_conf is always 0.5,
  188. # so there is no need to multiplicate.
  189. else:
  190. scores *= conf # conf = obj_conf * cls_conf
  191. num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(boxes, scores, self.background_class, self.box_coding,
  192. self.iou_threshold, self.max_obj,
  193. self.plugin_version, self.score_activation,
  194. self.score_threshold)
  195. return num_det, det_boxes, det_scores, det_classes
  196. class End2End(nn.Module):
  197. '''export onnx or tensorrt model with NMS operation.'''
  198. def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None, n_classes=80):
  199. super().__init__()
  200. device = device if device else torch.device('cpu')
  201. assert isinstance(max_wh,(int)) or max_wh is None
  202. self.model = model.to(device)
  203. self.model.model[-1].end2end = True
  204. self.patch_model = ONNX_TRT if max_wh is None else ONNX_ORT
  205. self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device, n_classes)
  206. self.end2end.eval()
  207. def forward(self, x):
  208. x = self.model(x)
  209. x = self.end2end(x)
  210. return x
  211. def attempt_load(weights, map_location=None):
  212. # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
  213. model = Ensemble()
  214. for w in weights if isinstance(weights, list) else [weights]:
  215. attempt_download(w)
  216. ckpt = torch.load(w, map_location=map_location) # load
  217. model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model
  218. # Compatibility updates
  219. for m in model.modules():
  220. if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
  221. m.inplace = True # pytorch 1.7.0 compatibility
  222. elif type(m) is nn.Upsample:
  223. m.recompute_scale_factor = None # torch 1.11.0 compatibility
  224. elif type(m) is Conv:
  225. m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
  226. if len(model) == 1:
  227. return model[-1] # return model
  228. else:
  229. print('Ensemble created with %s\n' % weights)
  230. for k in ['names', 'stride']:
  231. setattr(model, k, getattr(model[-1], k))
  232. return model # return ensemble