Pavement distress detection code based on Yolov7

import math
from copy import copy
from pathlib import Path

import numpy as np
import pandas as pd
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import DeformConv2d
from PIL import Image
from torch.cuda import amp

from utils.datasets import letterbox
from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh
from utils.plots import color_list, plot_one_box
from utils.torch_utils import time_synchronized

##### basic #####

def autopad(k, p=None):  # kernel, padding
    # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class MP(nn.Module):
    def __init__(self, k=2):
        super(MP, self).__init__()
        self.m = nn.MaxPool2d(kernel_size=k, stride=k)

    def forward(self, x):
        return self.m(x)


class SP(nn.Module):
    def __init__(self, k=3, s=1):
        super(SP, self).__init__()
        self.m = nn.MaxPool2d(kernel_size=k, stride=s, padding=k // 2)

    def forward(self, x):
        return self.m(x)


class ReOrg(nn.Module):
    def __init__(self):
        super(ReOrg, self).__init__()

    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        return torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)


class Concat(nn.Module):
    def __init__(self, dimension=1):
        super(Concat, self).__init__()
        self.d = dimension

    def forward(self, x):
        return torch.cat(x, self.d)


class Chuncat(nn.Module):
    def __init__(self, dimension=1):
        super(Chuncat, self).__init__()
        self.d = dimension

    def forward(self, x):
        x1 = []
        x2 = []
        for xi in x:
            xi1, xi2 = xi.chunk(2, self.d)
            x1.append(xi1)
            x2.append(xi2)
        return torch.cat(x1 + x2, self.d)


class Shortcut(nn.Module):
    def __init__(self, dimension=0):
        super(Shortcut, self).__init__()
        self.d = dimension

    def forward(self, x):
        return x[0] + x[1]


class Foldcut(nn.Module):
    def __init__(self, dimension=0):
        super(Foldcut, self).__init__()
        self.d = dimension

    def forward(self, x):
        x1, x2 = x.chunk(2, self.d)
        return x1 + x2


class Conv(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def fuseforward(self, x):
        return self.act(self.conv(x))


class RobustConv(nn.Module):
    # Robust convolution (use high kernel size 7-11 for: downsampling and other layers). Train for 300 - 450 epochs.
    def __init__(self, c1, c2, k=7, s=1, p=None, g=1, act=True, layer_scale_init_value=1e-6):  # ch_in, ch_out, kernel, stride, padding, groups
        super(RobustConv, self).__init__()
        self.conv_dw = Conv(c1, c1, k=k, s=s, p=p, g=c1, act=act)
        self.conv1x1 = nn.Conv2d(c1, c2, 1, 1, 0, groups=1, bias=True)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(c2)) if layer_scale_init_value > 0 else None

    def forward(self, x):
        x = x.to(memory_format=torch.channels_last)
        x = self.conv1x1(self.conv_dw(x))
        if self.gamma is not None:
            x = x.mul(self.gamma.reshape(1, -1, 1, 1))
        return x


class RobustConv2(nn.Module):
    # Robust convolution 2 (use [32, 5, 2] or [32, 7, 4] or [32, 11, 8] for one of the paths in CSP).
    def __init__(self, c1, c2, k=7, s=4, p=None, g=1, act=True, layer_scale_init_value=1e-6):  # ch_in, ch_out, kernel, stride, padding, groups
        super(RobustConv2, self).__init__()
        self.conv_strided = Conv(c1, c1, k=k, s=s, p=p, g=c1, act=act)
        self.conv_deconv = nn.ConvTranspose2d(in_channels=c1, out_channels=c2, kernel_size=s, stride=s,
                                              padding=0, bias=True, dilation=1, groups=1)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(c2)) if layer_scale_init_value > 0 else None

    def forward(self, x):
        x = self.conv_deconv(self.conv_strided(x))
        if self.gamma is not None:
            x = x.mul(self.gamma.reshape(1, -1, 1, 1))
        return x


def DWConv(c1, c2, k=1, s=1, act=True):
    # Depthwise convolution
    return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)


class GhostConv(nn.Module):
    # Ghost Convolution https://github.com/huawei-noah/ghostnet
    def __init__(self, c1, c2, k=1, s=1, g=1, act=True):  # ch_in, ch_out, kernel, stride, groups
        super(GhostConv, self).__init__()
        c_ = c2 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, k, s, None, g, act)
        self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)

    def forward(self, x):
        y = self.cv1(x)
        return torch.cat([y, self.cv2(y)], 1)


class Stem(nn.Module):
    # Stem
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Stem, self).__init__()
        c_ = int(c2 / 2)  # hidden channels
        self.cv1 = Conv(c1, c_, 3, 2)
        self.cv2 = Conv(c_, c_, 1, 1)
        self.cv3 = Conv(c_, c_, 3, 2)
        self.pool = torch.nn.MaxPool2d(2, stride=2)
        self.cv4 = Conv(2 * c_, c2, 1, 1)

    def forward(self, x):
        x = self.cv1(x)
        return self.cv4(torch.cat((self.cv3(self.cv2(x)), self.pool(x)), dim=1))

class DownC(nn.Module):
    # Downsampling block: strided conv path + maxpool path, concatenated
    def __init__(self, c1, c2, n=1, k=2):
        super(DownC, self).__init__()
        c_ = int(c1)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2 // 2, 3, k)
        self.cv3 = Conv(c1, c2 // 2, 1, 1)
        self.mp = nn.MaxPool2d(kernel_size=k, stride=k)

    def forward(self, x):
        return torch.cat((self.cv2(self.cv1(x)), self.cv3(self.mp(x))), dim=1)

class SPP(nn.Module):
    # Spatial pyramid pooling layer used in YOLOv3-SPP
    def __init__(self, c1, c2, k=(5, 9, 13)):
        super(SPP, self).__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
        self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])

    def forward(self, x):
        x = self.cv1(x)
        return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))


class Bottleneck(nn.Module):
    # Darknet bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super(Bottleneck, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class Res(nn.Module):
    # ResNet bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super(Res, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c_, 3, 1, g=g)
        self.cv3 = Conv(c_, c2, 1, 1)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.cv3(self.cv2(self.cv1(x))) if self.add else self.cv3(self.cv2(self.cv1(x)))


class ResX(Res):
    # ResNeXt bottleneck (grouped 3x3, g=32 by default)
    def __init__(self, c1, c2, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__(c1, c2, shortcut, g, e)


class Ghost(nn.Module):
    # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
    def __init__(self, c1, c2, k=3, s=1):  # ch_in, ch_out, kernel, stride
        super(Ghost, self).__init__()
        c_ = c2 // 2
        self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1),  # pw
                                  DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(),  # dw
                                  GhostConv(c_, c2, 1, 1, act=False))  # pw-linear
        self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
                                      Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()

    def forward(self, x):
        return self.conv(x) + self.shortcut(x)

##### end of basic #####
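
# Usage sketch for the basic blocks above (illustrative only; the `_example_*`
# helper name is ours, not part of the network code, and assumes only torch
# plus the classes defined in this file). `autopad` gives 'same' padding, so
# Conv keeps spatial size when s=1 and halves it when s=2; ReOrg is a
# space-to-depth rearrangement.
def _example_basic_blocks():
    x = torch.randn(1, 32, 64, 64)
    assert Conv(32, 64, 3, 1)(x).shape == (1, 64, 64, 64)  # 'same' padding, size kept
    assert Conv(32, 64, 3, 2)(x).shape == (1, 64, 32, 32)  # strided downsample
    assert ReOrg()(x).shape == (1, 128, 32, 32)            # 4x channels, half resolution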

##### cspnet #####

class SPPCSPC(nn.Module):
    # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)):
        super(SPPCSPC, self).__init__()
        c_ = int(2 * c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(c_, c_, 3, 1)
        self.cv4 = Conv(c_, c_, 1, 1)
        self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
        self.cv5 = Conv(4 * c_, c_, 1, 1)
        self.cv6 = Conv(c_, c_, 3, 1)
        self.cv7 = Conv(2 * c_, c2, 1, 1)

    def forward(self, x):
        x1 = self.cv4(self.cv3(self.cv1(x)))
        y1 = self.cv6(self.cv5(torch.cat([x1] + [m(x1) for m in self.m], 1)))
        y2 = self.cv2(x)
        return self.cv7(torch.cat((y1, y2), dim=1))


class GhostSPPCSPC(SPPCSPC):
    # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)):
        super().__init__(c1, c2, n, shortcut, g, e, k)
        c_ = int(2 * c2 * e)  # hidden channels
        self.cv1 = GhostConv(c1, c_, 1, 1)
        self.cv2 = GhostConv(c1, c_, 1, 1)
        self.cv3 = GhostConv(c_, c_, 3, 1)
        self.cv4 = GhostConv(c_, c_, 1, 1)
        self.cv5 = GhostConv(4 * c_, c_, 1, 1)
        self.cv6 = GhostConv(c_, c_, 3, 1)
        self.cv7 = GhostConv(2 * c_, c2, 1, 1)


class GhostStem(Stem):
    # Stem
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__(c1, c2, k, s, p, g, act)
        c_ = int(c2 / 2)  # hidden channels
        self.cv1 = GhostConv(c1, c_, 3, 2)
        self.cv2 = GhostConv(c_, c_, 1, 1)
        self.cv3 = GhostConv(c_, c_, 3, 2)
        self.cv4 = GhostConv(2 * c_, c2, 1, 1)


class BottleneckCSPA(nn.Module):
    # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(BottleneckCSPA, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1, 1)
        self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        y1 = self.m(self.cv1(x))
        y2 = self.cv2(x)
        return self.cv3(torch.cat((y1, y2), dim=1))


class BottleneckCSPB(nn.Module):
    # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(BottleneckCSPB, self).__init__()
        c_ = int(c2)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1, 1)
        self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        x1 = self.cv1(x)
        y1 = self.m(x1)
        y2 = self.cv2(x1)
        return self.cv3(torch.cat((y1, y2), dim=1))


class BottleneckCSPC(nn.Module):
    # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(BottleneckCSPC, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(c_, c_, 1, 1)
        self.cv4 = Conv(2 * c_, c2, 1, 1)
        self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        y1 = self.cv3(self.m(self.cv1(x)))
        y2 = self.cv2(x)
        return self.cv4(torch.cat((y1, y2), dim=1))


class ResCSPA(BottleneckCSPA):
    # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=0.5) for _ in range(n)])


class ResCSPB(BottleneckCSPB):
    # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2)  # hidden channels
        self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=0.5) for _ in range(n)])


class ResCSPC(BottleneckCSPC):
    # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=0.5) for _ in range(n)])


class ResXCSPA(ResCSPA):
    # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=1.0) for _ in range(n)])


class ResXCSPB(ResCSPB):
    # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2)  # hidden channels
        self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=1.0) for _ in range(n)])


class ResXCSPC(ResCSPC):
    # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*[Res(c_, c_, shortcut, g, e=1.0) for _ in range(n)])


class GhostCSPA(BottleneckCSPA):
    # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*[Ghost(c_, c_) for _ in range(n)])


class GhostCSPB(BottleneckCSPB):
    # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2)  # hidden channels
        self.m = nn.Sequential(*[Ghost(c_, c_) for _ in range(n)])


class GhostCSPC(BottleneckCSPC):
    # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*[Ghost(c_, c_) for _ in range(n)])

##### end of cspnet #####
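
# Usage sketch for SPPCSPC (illustrative only; helper name is ours). The block
# fuses multi-scale max-pooling (k=5/9/13) inside a CSP split, so spatial size
# is preserved while channels map c1 -> c2.
def _example_sppcspc():
    m = SPPCSPC(64, 128)
    y = m(torch.randn(1, 64, 32, 32))
    assert y.shape == (1, 128, 32, 32)  # same H, W; c1=64 -> c2=128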

##### yolor #####

class ImplicitA(nn.Module):
    def __init__(self, channel, mean=0., std=.02):
        super(ImplicitA, self).__init__()
        self.channel = channel
        self.mean = mean
        self.std = std
        self.implicit = nn.Parameter(torch.zeros(1, channel, 1, 1))
        nn.init.normal_(self.implicit, mean=self.mean, std=self.std)

    def forward(self, x):
        return self.implicit + x


class ImplicitM(nn.Module):
    def __init__(self, channel, mean=1., std=.02):
        super(ImplicitM, self).__init__()
        self.channel = channel
        self.mean = mean
        self.std = std
        self.implicit = nn.Parameter(torch.ones(1, channel, 1, 1))
        nn.init.normal_(self.implicit, mean=self.mean, std=self.std)

    def forward(self, x):
        return self.implicit * x

##### end of yolor #####
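
# Usage sketch for the YOLOR implicit-knowledge layers (illustrative only;
# helper name is ours). ImplicitA adds and ImplicitM multiplies a learned
# per-channel tensor, so both are shape-preserving and broadcast over batch
# and spatial dimensions.
def _example_implicit():
    x = torch.randn(1, 16, 8, 8)
    assert ImplicitA(16)(x).shape == x.shape  # learned additive prior
    assert ImplicitM(16)(x).shape == x.shape  # learned multiplicative prior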

##### repvgg #####

class RepConv(nn.Module):
    # Re-parameterizable convolution (RepVGG-style)
    # https://arxiv.org/abs/2101.03697
    def __init__(self, c1, c2, k=3, s=1, p=None, g=1, act=True, deploy=False):
        super(RepConv, self).__init__()
        self.deploy = deploy
        self.groups = g
        self.in_channels = c1
        self.out_channels = c2
        assert k == 3
        assert autopad(k, p) == 1
        padding_11 = autopad(k, p) - k // 2
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
        if deploy:
            self.rbr_reparam = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=True)
        else:
            self.rbr_identity = (nn.BatchNorm2d(num_features=c1) if c2 == c1 and s == 1 else None)
            self.rbr_dense = nn.Sequential(
                nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False),
                nn.BatchNorm2d(num_features=c2),
            )
            self.rbr_1x1 = nn.Sequential(
                nn.Conv2d(c1, c2, 1, s, padding_11, groups=g, bias=False),
                nn.BatchNorm2d(num_features=c2),
            )

    def forward(self, inputs):
        if hasattr(self, "rbr_reparam"):
            return self.act(self.rbr_reparam(inputs))
        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)
        return self.act(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)

    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return (
            kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid,
            bias3x3 + bias1x1 + biasid,
        )

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            return nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch[0].weight
            running_mean = branch[1].running_mean
            running_var = branch[1].running_var
            gamma = branch[1].weight
            beta = branch[1].bias
            eps = branch[1].eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, "id_tensor"):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros(
                    (self.in_channels, input_dim, 3, 3), dtype=np.float32
                )
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def repvgg_convert(self):
        kernel, bias = self.get_equivalent_kernel_bias()
        return (
            kernel.detach().cpu().numpy(),
            bias.detach().cpu().numpy(),
        )

    def fuse_conv_bn(self, conv, bn):
        std = (bn.running_var + bn.eps).sqrt()
        bias = bn.bias - bn.running_mean * bn.weight / std
        t = (bn.weight / std).reshape(-1, 1, 1, 1)
        weights = conv.weight * t
        bn = nn.Identity()
        conv = nn.Conv2d(in_channels=conv.in_channels,
                         out_channels=conv.out_channels,
                         kernel_size=conv.kernel_size,
                         stride=conv.stride,
                         padding=conv.padding,
                         dilation=conv.dilation,
                         groups=conv.groups,
                         bias=True,
                         padding_mode=conv.padding_mode)
        conv.weight = torch.nn.Parameter(weights)
        conv.bias = torch.nn.Parameter(bias)
        return conv

    def fuse_repvgg_block(self):
        if self.deploy:
            return
        print("RepConv.fuse_repvgg_block")

        self.rbr_dense = self.fuse_conv_bn(self.rbr_dense[0], self.rbr_dense[1])
        self.rbr_1x1 = self.fuse_conv_bn(self.rbr_1x1[0], self.rbr_1x1[1])
        rbr_1x1_bias = self.rbr_1x1.bias
        weight_1x1_expanded = torch.nn.functional.pad(self.rbr_1x1.weight, [1, 1, 1, 1])

        # Fuse self.rbr_identity
        if isinstance(self.rbr_identity, (nn.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
            identity_conv_1x1 = nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=self.groups,
                bias=False)
            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.to(self.rbr_1x1.weight.data.device)
            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.squeeze().squeeze()
            identity_conv_1x1.weight.data.fill_(0.0)
            identity_conv_1x1.weight.data.fill_diagonal_(1.0)
            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.unsqueeze(2).unsqueeze(3)

            identity_conv_1x1 = self.fuse_conv_bn(identity_conv_1x1, self.rbr_identity)
            bias_identity_expanded = identity_conv_1x1.bias
            weight_identity_expanded = torch.nn.functional.pad(identity_conv_1x1.weight, [1, 1, 1, 1])
        else:
            bias_identity_expanded = torch.nn.Parameter(torch.zeros_like(rbr_1x1_bias))
            weight_identity_expanded = torch.nn.Parameter(torch.zeros_like(weight_1x1_expanded))

        self.rbr_dense.weight = torch.nn.Parameter(self.rbr_dense.weight + weight_1x1_expanded + weight_identity_expanded)
        self.rbr_dense.bias = torch.nn.Parameter(self.rbr_dense.bias + rbr_1x1_bias + bias_identity_expanded)

        self.rbr_reparam = self.rbr_dense
        self.deploy = True

        if self.rbr_identity is not None:
            del self.rbr_identity
            self.rbr_identity = None

        if self.rbr_1x1 is not None:
            del self.rbr_1x1
            self.rbr_1x1 = None

        if self.rbr_dense is not None:
            del self.rbr_dense
            self.rbr_dense = None

class RepBottleneck(Bottleneck):
    # Darknet bottleneck with a re-parameterizable 3x3
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__(c1, c2, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.cv2 = RepConv(c_, c2, 3, 1, g=g)

class RepBottleneckCSPA(BottleneckCSPA):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*[RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])


class RepBottleneckCSPB(BottleneckCSPB):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2)  # hidden channels
        self.m = nn.Sequential(*[RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])


class RepBottleneckCSPC(BottleneckCSPC):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*[RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])


class RepRes(Res):
    # ResNet bottleneck with a re-parameterizable 3x3
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__(c1, c2, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.cv2 = RepConv(c_, c_, 3, 1, g=g)


class RepResCSPA(ResCSPA):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*[RepRes(c_, c_, shortcut, g, e=0.5) for _ in range(n)])


class RepResCSPB(ResCSPB):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2)  # hidden channels
        self.m = nn.Sequential(*[RepRes(c_, c_, shortcut, g, e=0.5) for _ in range(n)])


class RepResCSPC(ResCSPC):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*[RepRes(c_, c_, shortcut, g, e=0.5) for _ in range(n)])


class RepResX(ResX):
    # ResNeXt bottleneck with a re-parameterizable 3x3
    def __init__(self, c1, c2, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__(c1, c2, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.cv2 = RepConv(c_, c_, 3, 1, g=g)


class RepResXCSPA(ResXCSPA):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*[RepResX(c_, c_, shortcut, g, e=0.5) for _ in range(n)])


class RepResXCSPB(ResXCSPB):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=False, g=32, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2)  # hidden channels
        self.m = nn.Sequential(*[RepResX(c_, c_, shortcut, g, e=0.5) for _ in range(n)])


class RepResXCSPC(ResXCSPC):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=32, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*[RepResX(c_, c_, shortcut, g, e=0.5) for _ in range(n)])

##### end of repvgg #####
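
# Sanity-check sketch for RepConv re-parameterization (illustrative only;
# helper name is ours). After fuse_repvgg_block() the three training-time
# branches (3x3+BN, 1x1+BN, identity BN) collapse into a single 3x3 conv;
# in eval mode the fused module should reproduce the multi-branch output
# up to floating-point error.
def _example_repconv_fuse():
    m = RepConv(32, 32, k=3, s=1).eval()
    x = torch.randn(1, 32, 16, 16)
    with torch.no_grad():
        y_branches = m(x)        # training-time three-branch form
        m.fuse_repvgg_block()    # collapse into one 3x3 conv
        y_deploy = m(x)          # single-branch deploy form
    assert torch.allclose(y_branches, y_deploy, atol=1e-4)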

##### transformer #####

class TransformerLayer(nn.Module):
    # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
    def __init__(self, c, num_heads):
        super().__init__()
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)

    def forward(self, x):
        x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
        x = self.fc2(self.fc1(x)) + x
        return x


class TransformerBlock(nn.Module):
    # Vision Transformer https://arxiv.org/abs/2010.11929
    def __init__(self, c1, c2, num_heads, num_layers):
        super().__init__()
        self.conv = None
        if c1 != c2:
            self.conv = Conv(c1, c2)
        self.linear = nn.Linear(c2, c2)  # learnable position embedding
        self.tr = nn.Sequential(*[TransformerLayer(c2, num_heads) for _ in range(num_layers)])
        self.c2 = c2

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        b, _, w, h = x.shape
        p = x.flatten(2)
        p = p.unsqueeze(0)
        p = p.transpose(0, 3)
        p = p.squeeze(3)
        e = self.linear(p)
        x = p + e
        x = self.tr(x)
        x = x.unsqueeze(3)
        x = x.transpose(0, 3)
        x = x.reshape(b, self.c2, w, h)
        return x

##### end of transformer #####
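
# Usage sketch for TransformerBlock (illustrative only; helper name is ours).
# The block flattens the w*h grid into a token sequence, adds a learned
# position embedding, runs num_layers TransformerLayers, and reshapes back,
# so channels become c2 while spatial size is unchanged.
def _example_transformer_block():
    m = TransformerBlock(32, 64, num_heads=4, num_layers=1)
    y = m(torch.randn(2, 32, 8, 8))
    assert y.shape == (2, 64, 8, 8)  # 8*8 = 64 tokens of width c2=64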

##### yolov5 #####

class Focus(nn.Module):
    # Focus wh information into c-space
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Focus, self).__init__()
        self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
        # self.contract = Contract(gain=2)

    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
        # return self.conv(self.contract(x))


class SPPF(nn.Module):
    # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
    def __init__(self, c1, c2, k=5):  # equivalent to SPP(k=(5, 9, 13))
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_ * 4, c2, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        x = self.cv1(x)
        y1 = self.m(x)
        y2 = self.m(y1)
        return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))


class Contract(nn.Module):
    # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
    def __init__(self, gain=2):
        super().__init__()
        self.gain = gain

    def forward(self, x):
        N, C, H, W = x.size()  # assert (H % s == 0) and (W % s == 0), 'Indivisible gain'
        s = self.gain
        x = x.view(N, C, H // s, s, W // s, s)  # x(1,64,40,2,40,2)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # x(1,2,2,64,40,40)
        return x.view(N, C * s * s, H // s, W // s)  # x(1,256,40,40)


class Expand(nn.Module):
    # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
    def __init__(self, gain=2):
        super().__init__()
        self.gain = gain

    def forward(self, x):
        N, C, H, W = x.size()  # assert C % s ** 2 == 0, 'Indivisible gain'
        s = self.gain
        x = x.view(N, s, s, C // s ** 2, H, W)  # x(1,2,2,16,80,80)
        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # x(1,16,80,2,80,2)
        return x.view(N, C // s ** 2, H * s, W * s)  # x(1,16,160,160)


class NMS(nn.Module):
    # Non-Maximum Suppression (NMS) module
    conf = 0.25  # confidence threshold
    iou = 0.45  # IoU threshold
    classes = None  # (optional list) filter by class

    def __init__(self):
        super(NMS, self).__init__()

    def forward(self, x):
        return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)

class autoShape(nn.Module):
    # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
    conf = 0.25  # NMS confidence threshold
    iou = 0.45  # NMS IoU threshold
    classes = None  # (optional list) filter by class

    def __init__(self, model):
        super(autoShape, self).__init__()
        self.model = model.eval()

    def autoshape(self):
        print('autoShape already enabled, skipping... ')  # model already converted to model.autoshape()
        return self

    @torch.no_grad()
    def forward(self, imgs, size=640, augment=False, profile=False):
        # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
        #   filename:  imgs = 'data/samples/zidane.jpg'
        #   URI:            = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg'
        #   OpenCV:         = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
        #   PIL:            = Image.open('image.jpg')  # HWC x(640,1280,3)
        #   numpy:          = np.zeros((640,1280,3))  # HWC
        #   torch:          = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
        #   multiple:       = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
        t = [time_synchronized()]
        p = next(self.model.parameters())  # for device and type
        if isinstance(imgs, torch.Tensor):  # torch
            with amp.autocast(enabled=p.device.type != 'cpu'):
                return self.model(imgs.to(p.device).type_as(p), augment, profile)  # inference

        # Pre-process
        n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs])  # number of images, list of images
        shape0, shape1, files = [], [], []  # image and inference shapes, filenames
        for i, im in enumerate(imgs):
            f = f'image{i}'  # filename
            if isinstance(im, str):  # filename or uri
                im, f = np.asarray(Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im)), im
            elif isinstance(im, Image.Image):  # PIL Image
                im, f = np.asarray(im), getattr(im, 'filename', f) or f
            files.append(Path(f).with_suffix('.jpg').name)
            if im.shape[0] < 5:  # image in CHW
                im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
            im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3)  # enforce 3ch input
            s = im.shape[:2]  # HWC
            shape0.append(s)  # image shape
            g = (size / max(s))  # gain
            shape1.append([y * g for y in s])
            imgs[i] = im  # update
        shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)]  # inference shape
        x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs]  # pad
        x = np.stack(x, 0) if n > 1 else x[0][None]  # stack
        x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
        x = torch.from_numpy(x).to(p.device).type_as(p) / 255.  # uint8 to fp16/32
        t.append(time_synchronized())

        with amp.autocast(enabled=p.device.type != 'cpu'):
            # Inference
            y = self.model(x, augment, profile)[0]  # forward
            t.append(time_synchronized())

            # Post-process
            y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)  # NMS
            for i in range(n):
                scale_coords(shape1, y[i][:, :4], shape0[i])
            t.append(time_synchronized())
            return Detections(imgs, y, files, t, self.names, x.shape)

class Detections:
    # detections class for YOLOv5 inference results
    def __init__(self, imgs, pred, files, times=None, names=None, shape=None):
        super(Detections, self).__init__()
        d = pred[0].device  # device
        gn = [torch.tensor([*[im.shape[i] for i in [1, 0, 1, 0]], 1., 1.], device=d) for im in imgs]  # normalizations
        self.imgs = imgs  # list of images as numpy arrays
        self.pred = pred  # list of tensors pred[0] = (xyxy, conf, cls)
        self.names = names  # class names
        self.files = files  # image filenames
        self.xyxy = pred  # xyxy pixels
        self.xywh = [xyxy2xywh(x) for x in pred]  # xywh pixels
        self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)]  # xyxy normalized
        self.xywhn = [x / g for x, g in zip(self.xywh, gn)]  # xywh normalized
        self.n = len(self.pred)  # number of images (batch size)
        self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) if times is not None else (0., 0., 0.)  # timings (ms)
        self.s = shape  # inference BCHW shape

    def display(self, pprint=False, show=False, save=False, render=False, save_dir=''):
        colors = color_list()
        for i, (img, pred) in enumerate(zip(self.imgs, self.pred)):
            s = f'image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} '
            if pred is not None:
                for c in pred[:, -1].unique():
                    n = (pred[:, -1] == c).sum()  # detections per class
                    s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string
                if show or save or render:
                    for *box, conf, cls in pred:  # xyxy, confidence, class
                        label = f'{self.names[int(cls)]} {conf:.2f}'
                        plot_one_box(box, img, label=label, color=colors[int(cls) % 10])
            img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img  # from np
            if pprint:
                print(s.rstrip(', '))
            if show:
                img.show(self.files[i])  # show
            if save:
                f = self.files[i]
                img.save(Path(save_dir) / f)  # save
                print(f"{'Saved' * (i == 0)} {f}", end=',' if i < self.n - 1 else f' to {save_dir}\n')
            if render:
                self.imgs[i] = np.asarray(img)

    def print(self):
        self.display(pprint=True)  # print results
        print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' % self.t)

    def show(self):
        self.display(show=True)  # show results

    def save(self, save_dir='runs/hub/exp'):
        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/hub/exp')  # increment save_dir
        Path(save_dir).mkdir(parents=True, exist_ok=True)
        self.display(save=True, save_dir=save_dir)  # save results

    def render(self):
        self.display(render=True)  # render results
        return self.imgs

    def pandas(self):
        # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
        new = copy(self)  # return copy
        ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'  # xyxy columns
        cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name'  # xywh columns
        for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
            a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)]  # update
            setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
        return new
    def tolist(self):
        # return a list of Detections objects, i.e. 'for result in results.tolist():'
        x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], names=self.names, shape=self.s) for i in range(self.n)]
        for d in x:
            for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
                setattr(d, k, getattr(d, k)[0])  # pop out of list
        return x
    def __len__(self):
        return self.n


class Classify(nn.Module):
    # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Classify, self).__init__()
        self.aap = nn.AdaptiveAvgPool2d(1)  # to x(b,c1,1,1)
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g)  # to x(b,c2,1,1)
        self.flat = nn.Flatten()

    def forward(self, x):
        z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1)  # cat if list
        return self.flat(self.conv(z))  # flatten to x(b,c2)

##### end of yolov5 #####
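
# Sketch of why SPPF(k=5) is equivalent to SPP(k=(5, 9, 13)) (illustrative
# only; helper name is ours). Chaining a stride-1 5x5 max-pool twice covers a
# 9x9 window and three times a 13x13 window, so the two pooling pyramids
# produce identical tensors while SPPF reuses intermediate results.
def _example_sppf_equivalence():
    x = torch.randn(1, 8, 16, 16)
    p5, p9, p13 = nn.MaxPool2d(5, 1, 2), nn.MaxPool2d(9, 1, 4), nn.MaxPool2d(13, 1, 6)
    y1 = p5(x)
    y2 = p5(y1)
    spp = torch.cat([x, p5(x), p9(x), p13(x)], 1)   # SPP-style parallel pools
    sppf = torch.cat([x, y1, y2, p5(y2)], 1)        # SPPF-style chained pools
    assert torch.equal(spp, sppf)  # max of maxes selects the same values exactly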

##### orepa #####

def transI_fusebn(kernel, bn):
    gamma = bn.weight
    std = (bn.running_var + bn.eps).sqrt()
    return kernel * ((gamma / std).reshape(-1, 1, 1, 1)), bn.bias - bn.running_mean * gamma / std


class ConvBN(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, deploy=False, nonlinear=None):
        super().__init__()
        if nonlinear is None:
            self.nonlinear = nn.Identity()
        else:
            self.nonlinear = nonlinear
        if deploy:
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                  stride=stride, padding=padding, dilation=dilation, groups=groups, bias=True)
        else:
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                  stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False)
            self.bn = nn.BatchNorm2d(num_features=out_channels)

    def forward(self, x):
        if hasattr(self, 'bn'):
            return self.nonlinear(self.bn(self.conv(x)))
        else:
            return self.nonlinear(self.conv(x))

    def switch_to_deploy(self):
        kernel, bias = transI_fusebn(self.conv.weight, self.bn)
        conv = nn.Conv2d(in_channels=self.conv.in_channels, out_channels=self.conv.out_channels,
                         kernel_size=self.conv.kernel_size, stride=self.conv.stride,
                         padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups, bias=True)
        conv.weight.data = kernel
        conv.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('conv')
        self.__delattr__('bn')
        self.conv = conv


class OREPA_3x3_RepConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1,
                 internal_channels_1x1_3x3=None,
                 deploy=False, nonlinear=None, single_init=False):
        super(OREPA_3x3_RepConv, self).__init__()
        self.deploy = deploy
        if nonlinear is None:
            self.nonlinear = nn.Identity()
        else:
            self.nonlinear = nonlinear
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        assert padding == kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.branch_counter = 0

        self.weight_rbr_origin = nn.Parameter(torch.Tensor(out_channels, int(in_channels / self.groups), kernel_size, kernel_size))
        nn.init.kaiming_uniform_(self.weight_rbr_origin, a=math.sqrt(1.0))
        self.branch_counter += 1

        if groups < out_channels:
            self.weight_rbr_avg_conv = nn.Parameter(torch.Tensor(out_channels, int(in_channels / self.groups), 1, 1))
            self.weight_rbr_pfir_conv = nn.Parameter(torch.Tensor(out_channels, int(in_channels / self.groups), 1, 1))
            nn.init.kaiming_uniform_(self.weight_rbr_avg_conv, a=1.0)
            nn.init.kaiming_uniform_(self.weight_rbr_pfir_conv, a=1.0)
            self.register_buffer('weight_rbr_avg_avg', torch.ones(kernel_size, kernel_size).mul(1.0 / kernel_size / kernel_size))
            self.branch_counter += 1
        else:
            raise NotImplementedError
        self.branch_counter += 1

        if internal_channels_1x1_3x3 is None:
            internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels  # For mobilenet, it is better to have 2X internal channels

        if internal_channels_1x1_3x3 == in_channels:
            self.weight_rbr_1x1_kxk_idconv1 = nn.Parameter(torch.zeros(in_channels, int(in_channels / self.groups), 1, 1))
            id_value = np.zeros((in_channels, int(in_channels / self.groups), 1, 1))
            for i in range(in_channels):
                id_value[i, i % int(in_channels / self.groups), 0, 0] = 1
            id_tensor = torch.from_numpy(id_value).type_as(self.weight_rbr_1x1_kxk_idconv1)
            self.register_buffer('id_tensor', id_tensor)
        else:
            self.weight_rbr_1x1_kxk_conv1 = nn.Parameter(torch.Tensor(internal_channels_1x1_3x3, int(in_channels / self.groups), 1, 1))
            nn.init.kaiming_uniform_(self.weight_rbr_1x1_kxk_conv1, a=math.sqrt(1.0))
        self.weight_rbr_1x1_kxk_conv2 = nn.Parameter(torch.Tensor(out_channels, int(internal_channels_1x1_3x3 / self.groups), kernel_size, kernel_size))
        nn.init.kaiming_uniform_(self.weight_rbr_1x1_kxk_conv2, a=math.sqrt(1.0))
        self.branch_counter += 1

        expand_ratio = 8
        self.weight_rbr_gconv_dw = nn.Parameter(torch.Tensor(in_channels * expand_ratio, 1, kernel_size, kernel_size))
        self.weight_rbr_gconv_pw = nn.Parameter(torch.Tensor(out_channels, in_channels * expand_ratio, 1, 1))
        nn.init.kaiming_uniform_(self.weight_rbr_gconv_dw, a=math.sqrt(1.0))
        nn.init.kaiming_uniform_(self.weight_rbr_gconv_pw, a=math.sqrt(1.0))
        self.branch_counter += 1

        if out_channels == in_channels and stride == 1:
            self.branch_counter += 1

        self.vector = nn.Parameter(torch.Tensor(self.branch_counter, self.out_channels))
        self.bn = nn.BatchNorm2d(out_channels)

        self.fre_init()
        nn.init.constant_(self.vector[0, :], 0.25)  # origin
        nn.init.constant_(self.vector[1, :], 0.25)  # avg
        nn.init.constant_(self.vector[2, :], 0.0)   # prior
        nn.init.constant_(self.vector[3, :], 0.5)   # 1x1_kxk
        nn.init.constant_(self.vector[4, :], 0.5)   # dws_conv

    def fre_init(self):
        prior_tensor = torch.Tensor(self.out_channels, self.kernel_size, self.kernel_size)
        half_fg = self.out_channels / 2
        for i in range(self.out_channels):
            for h in range(3):
                for w in range(3):
                    if i < half_fg:
                        prior_tensor[i, h, w] = math.cos(math.pi * (h + 0.5) * (i + 1) / 3)
                    else:
                        prior_tensor[i, h, w] = math.cos(math.pi * (w + 0.5) * (i + 1 - half_fg) / 3)
        self.register_buffer('weight_rbr_prior', prior_tensor)

    def weight_gen(self):
        weight_rbr_origin = torch.einsum('oihw,o->oihw', self.weight_rbr_origin, self.vector[0, :])
        weight_rbr_avg = torch.einsum('oihw,o->oihw', torch.einsum('oihw,hw->oihw', self.weight_rbr_avg_conv, self.weight_rbr_avg_avg), self.vector[1, :])
        weight_rbr_pfir = torch.einsum('oihw,o->oihw', torch.einsum('oihw,ohw->oihw', self.weight_rbr_pfir_conv, self.weight_rbr_prior), self.vector[2, :])

        weight_rbr_1x1_kxk_conv1 = None
        if hasattr(self, 'weight_rbr_1x1_kxk_idconv1'):
            weight_rbr_1x1_kxk_conv1 = (self.weight_rbr_1x1_kxk_idconv1 + self.id_tensor).squeeze()
        elif hasattr(self, 'weight_rbr_1x1_kxk_conv1'):
            weight_rbr_1x1_kxk_conv1 = self.weight_rbr_1x1_kxk_conv1.squeeze()
        else:
            raise NotImplementedError
        weight_rbr_1x1_kxk_conv2 = self.weight_rbr_1x1_kxk_conv2

        if self.groups > 1:
            g = self.groups
            t, ig = weight_rbr_1x1_kxk_conv1.size()
            o, tg, h, w = weight_rbr_1x1_kxk_conv2.size()
            weight_rbr_1x1_kxk_conv1 = weight_rbr_1x1_kxk_conv1.view(g, int(t / g), ig)
            weight_rbr_1x1_kxk_conv2 = weight_rbr_1x1_kxk_conv2.view(g, int(o / g), tg, h, w)
            weight_rbr_1x1_kxk = torch.einsum('gti,gothw->goihw', weight_rbr_1x1_kxk_conv1, weight_rbr_1x1_kxk_conv2).view(o, ig, h, w)
        else:
            weight_rbr_1x1_kxk = torch.einsum('ti,othw->oihw', weight_rbr_1x1_kxk_conv1, weight_rbr_1x1_kxk_conv2)
        weight_rbr_1x1_kxk = torch.einsum('oihw,o->oihw', weight_rbr_1x1_kxk, self.vector[3, :])

        weight_rbr_gconv = self.dwsc2full(self.weight_rbr_gconv_dw, self.weight_rbr_gconv_pw, self.in_channels)
        weight_rbr_gconv = torch.einsum('oihw,o->oihw', weight_rbr_gconv, self.vector[4, :])

        weight = weight_rbr_origin + weight_rbr_avg + weight_rbr_1x1_kxk + weight_rbr_pfir + weight_rbr_gconv
        return weight

    def dwsc2full(self, weight_dw, weight_pw, groups):
        t, ig, h, w = weight_dw.size()
        o, _, _, _ = weight_pw.size()
        tg = int(t / groups)
        i = int(ig * groups)
        weight_dw = weight_dw.view(groups, tg, ig, h, w)
        weight_pw = weight_pw.squeeze().view(o, groups, tg)
        weight_dsc = torch.einsum('gtihw,ogt->ogihw', weight_dw, weight_pw)
        return weight_dsc.view(o, i, h, w)

    def forward(self, inputs):
        weight = self.weight_gen()
        out = F.conv2d(inputs, weight, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
        return self.nonlinear(self.bn(out))


class RepConv_OREPA(nn.Module):
    def __init__(self, c1, c2, k=3, s=1, padding=1, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False, nonlinear=nn.SiLU()):
        super(RepConv_OREPA, self).__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = c1
        self.out_channels = c2
        self.padding = padding
        self.dilation = dilation
        assert k == 3
        assert padding == 1
        padding_11 = padding - k // 2
        if nonlinear is None:
            self.nonlinearity = nn.Identity()
        else:
            self.nonlinearity = nonlinear
        if use_se:
            self.se = SEBlock(self.out_channels, internal_neurons=self.out_channels // 16)  # SEBlock must be defined elsewhere when use_se=True
        else:
            self.se = nn.Identity()
        if deploy:
            self.rbr_reparam = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=k, stride=s,
                                         padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
        else:
            self.rbr_identity = nn.BatchNorm2d(num_features=self.in_channels) if self.out_channels == self.in_channels and s == 1 else None
            self.rbr_dense = OREPA_3x3_RepConv(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=k, stride=s, padding=padding, groups=groups, dilation=1)
            self.rbr_1x1 = ConvBN(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=1, stride=s, padding=padding_11, groups=groups, dilation=1)
            print('RepVGG Block, identity = ', self.rbr_identity)

    def forward(self, inputs):
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.se(self.rbr_reparam(inputs)))
        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)
        out1 = self.rbr_dense(inputs)
        out2 = self.rbr_1x1(inputs)
        out3 = id_out
        out = out1 + out2 + out3
        return self.nonlinearity(self.se(out))

    # Optional. This improves the accuracy and facilitates quantization.
    #   1. Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight.
    #   2. Use like this:
    #        loss = criterion(...)
    #        for every RepVGGBlock blk:
    #            loss += weight_decay_coefficient * 0.5 * blk.get_custom_L2()
    #        optimizer.zero_grad()
    #        loss.backward()
    # Not used for OREPA
    def get_custom_L2(self):
        K3 = self.rbr_dense.weight_gen()
        K1 = self.rbr_1x1.conv.weight
        t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
        t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
        l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum()  # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them.
        eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1  # The equivalent resultant central point of 3x3 kernel.
        l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum()  # Normalize for an L2 coefficient comparable to regular L2.
        return l2_loss_eq_kernel + l2_loss_circle

    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if not isinstance(branch, nn.BatchNorm2d):
            if isinstance(branch, OREPA_3x3_RepConv):
                kernel = branch.weight_gen()
            elif isinstance(branch, ConvBN):
                kernel = branch.conv.weight
            else:
                raise NotImplementedError
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        print("RepConv_OREPA.switch_to_deploy")
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.in_channels, out_channels=self.rbr_dense.out_channels,
                                     kernel_size=self.rbr_dense.kernel_size, stride=self.rbr_dense.stride,
                                     padding=self.rbr_dense.padding, dilation=self.rbr_dense.dilation, groups=self.rbr_dense.groups, bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('rbr_dense')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')

##### end of orepa #####
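
# Sanity-check sketch for BN folding via transI_fusebn (illustrative only;
# helper name is ours). ConvBN.switch_to_deploy folds the BatchNorm into the
# conv weights and bias; in eval mode the fused single conv should match the
# original conv+BN pair up to floating-point error.
def _example_convbn_deploy():
    m = ConvBN(8, 16, 3, padding=1).eval()
    x = torch.randn(1, 8, 10, 10)
    with torch.no_grad():
        y_ref = m(x)            # conv + BN
        m.switch_to_deploy()    # fold BN into the conv
        y_fused = m(x)          # single fused conv
    assert torch.allclose(y_ref, y_fused, atol=1e-5)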
##### swin transformer #####

class WindowAttention(nn.Module):

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        try:
            x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        except RuntimeError:  # dtype mismatch under AMP: cast attn to half to match v
            x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Mlp(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
def window_partition(x, window_size):
    B, H, W, C = x.shape
    assert H % window_size == 0 and W % window_size == 0, 'feature map H and W must be divisible by window_size'
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
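
# --- illustrative sketch (not part of the original file) ---------------------
# window_partition and window_reverse are exact inverses whenever H and W are
# divisible by window_size; both are pure reshapes, so the round trip is
# lossless. A quick shape check:
def _demo_window_roundtrip():
    x = torch.randn(2, 16, 16, 8)                     # B, H, W, C
    windows = window_partition(x, 8)                  # nW*B = 2*2*2 = 8 windows
    assert windows.shape == (8, 8, 8, 8)              # nW*B, ws, ws, C
    y = window_reverse(windows, 8, 16, 16)
    assert torch.equal(x, y)                          # lossless round trip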
class SwinTransformerLayer(nn.Module):

    def __init__(self, dim, num_heads, window_size=8, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.SiLU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        # if min(self.input_resolution) <= self.window_size:
        #     # if window size is larger than input resolution, we don't partition windows
        #     self.shift_size = 0
        #     self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        # NOTE: a DropPath implementation (e.g. timm.models.layers.DropPath) must be
        # in scope if drop_path > 0; with the default drop_path=0. it is never used
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def create_mask(self, H, W):
        # calculate attention mask for SW-MSA
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x):
        # reshape x[b c h w] to x[b l c]
        _, _, H_, W_ = x.shape
        Padding = False
        if min(H_, W_) < self.window_size or H_ % self.window_size != 0 or W_ % self.window_size != 0:
            Padding = True
            pad_r = (self.window_size - W_ % self.window_size) % self.window_size
            pad_b = (self.window_size - H_ % self.window_size) % self.window_size
            x = F.pad(x, (0, pad_r, 0, pad_b))

        B, C, H, W = x.shape
        L = H * W
        x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)  # b, L, c

        # create mask from init to forward
        if self.shift_size > 0:
            attn_mask = self.create_mask(H, W).to(x.device)
        else:
            attn_mask = None

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        x = x.permute(0, 2, 1).contiguous().view(-1, C, H, W)  # b c h w

        if Padding:
            x = x[:, :, :H_, :W_]  # reverse padding
        return x
class SwinTransformerBlock(nn.Module):
    def __init__(self, c1, c2, num_heads, num_layers, window_size=8):
        super().__init__()
        self.conv = None
        if c1 != c2:
            self.conv = Conv(c1, c2)

        # remove input_resolution
        self.blocks = nn.Sequential(*[SwinTransformerLayer(dim=c2, num_heads=num_heads, window_size=window_size,
                                                           shift_size=0 if (i % 2 == 0) else window_size // 2)
                                      for i in range(num_layers)])

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        x = self.blocks(x)
        return x


class STCSPA(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(STCSPA, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformerBlock(c_, c_, num_heads, n)
        #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        y1 = self.m(self.cv1(x))
        y2 = self.cv2(x)
        return self.cv3(torch.cat((y1, y2), dim=1))


class STCSPB(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(STCSPB, self).__init__()
        c_ = int(c2)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformerBlock(c_, c_, num_heads, n)
        #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        x1 = self.cv1(x)
        y1 = self.m(x1)
        y2 = self.cv2(x1)
        return self.cv3(torch.cat((y1, y2), dim=1))


class STCSPC(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(STCSPC, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(c_, c_, 1, 1)
        self.cv4 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformerBlock(c_, c_, num_heads, n)
        #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        y1 = self.cv3(self.m(self.cv1(x)))
        y2 = self.cv2(x)
        return self.cv4(torch.cat((y1, y2), dim=1))

##### end of swin transformer #####
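
# --- illustrative sketch (not part of the original file) ---------------------
# The ST-CSP blocks are drop-in CSP stages whose inner bottlenecks are Swin
# layers. num_heads is derived as c_ // 32, so the hidden width should be a
# multiple of 32, and with n=1 only an unshifted (mask-free) layer is built.
# A minimal smoke test; the channel/spatial sizes are arbitrary assumptions
# chosen so 32 is divisible by the default window_size=8:
def _demo_stcsp_blocks():
    x = torch.randn(1, 128, 32, 32)
    for blk in (STCSPA(128, 128, n=1), STCSPB(128, 128, n=1), STCSPC(128, 128, n=1)):
        y = blk(x)
        assert y.shape == (1, 128, 32, 32)   # channels and spatial size preserved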
##### swin transformer v2 #####

class WindowAttention_v2(nn.Module):

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
                 pretrained_window_size=[0, 0]):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.pretrained_window_size = pretrained_window_size
        self.num_heads = num_heads

        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)

        # mlp to generate continuous relative position bias
        self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(512, num_heads, bias=False))

        # get relative_coords_table
        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
        relative_coords_table = torch.stack(
            torch.meshgrid([relative_coords_h,
                            relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
        if pretrained_window_size[0] > 0:
            relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
        else:
            relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
        relative_coords_table *= 8  # normalize to -8, 8
        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
            torch.abs(relative_coords_table) + 1.0) / np.log2(8)

        self.register_buffer("relative_coords_table", relative_coords_table)

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(dim))
            self.v_bias = nn.Parameter(torch.zeros(dim))
        else:
            self.q_bias = None
            self.v_bias = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            # k gets a zero bias: only q and v are biased in Swin v2
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # cosine attention
        attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
        logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
        attn = attn * logit_scale

        relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
        relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        try:
            x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        except RuntimeError:  # dtype mismatch under AMP: cast attn to half to match v
            x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, ' \
               f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
class Mlp_v2(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition_v2(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse_v2(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
class SwinTransformerLayer_v2(nn.Module):

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.SiLU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
        super().__init__()
        self.dim = dim
        #self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        #if min(self.input_resolution) <= self.window_size:
        #    # if window size is larger than input resolution, we don't partition windows
        #    self.shift_size = 0
        #    self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention_v2(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
            pretrained_window_size=(pretrained_window_size, pretrained_window_size))

        # NOTE: a DropPath implementation (e.g. timm.models.layers.DropPath) must be
        # in scope if drop_path > 0; with the default drop_path=0. it is never used
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp_v2(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def create_mask(self, H, W):
        # calculate attention mask for SW-MSA
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x):
        # reshape x[b c h w] to x[b l c]
        _, _, H_, W_ = x.shape
        Padding = False
        if min(H_, W_) < self.window_size or H_ % self.window_size != 0 or W_ % self.window_size != 0:
            Padding = True
            pad_r = (self.window_size - W_ % self.window_size) % self.window_size
            pad_b = (self.window_size - H_ % self.window_size) % self.window_size
            x = F.pad(x, (0, pad_r, 0, pad_b))

        B, C, H, W = x.shape
        L = H * W
        x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)  # b, L, c

        # create mask from init to forward
        if self.shift_size > 0:
            attn_mask = self.create_mask(H, W).to(x.device)
        else:
            attn_mask = None

        shortcut = x
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition_v2(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse_v2(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # post-normalization (Swin v2): norm is applied after attention and MLP
        x = shortcut + self.drop_path(self.norm1(x))
        # FFN
        x = x + self.drop_path(self.norm2(self.mlp(x)))

        x = x.permute(0, 2, 1).contiguous().view(-1, C, H, W)  # b c h w

        if Padding:
            x = x[:, :, :H_, :W_]  # reverse padding
        return x

    def extra_repr(self) -> str:
        # input_resolution was removed from this port, so it is not reported here
        return f"dim={self.dim}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self, H, W):
        # input_resolution was removed from this port, so the feature-map size
        # must be passed in explicitly
        flops = 0
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
class SwinTransformer2Block(nn.Module):
    def __init__(self, c1, c2, num_heads, num_layers, window_size=7):
        super().__init__()
        self.conv = None
        if c1 != c2:
            self.conv = Conv(c1, c2)

        # remove input_resolution
        self.blocks = nn.Sequential(*[SwinTransformerLayer_v2(dim=c2, num_heads=num_heads, window_size=window_size,
                                                              shift_size=0 if (i % 2 == 0) else window_size // 2)
                                      for i in range(num_layers)])

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        x = self.blocks(x)
        return x


class ST2CSPA(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(ST2CSPA, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformer2Block(c_, c_, num_heads, n)
        #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        y1 = self.m(self.cv1(x))
        y2 = self.cv2(x)
        return self.cv3(torch.cat((y1, y2), dim=1))


class ST2CSPB(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(ST2CSPB, self).__init__()
        c_ = int(c2)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformer2Block(c_, c_, num_heads, n)
        #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        x1 = self.cv1(x)
        y1 = self.m(x1)
        y2 = self.cv2(x1)
        return self.cv3(torch.cat((y1, y2), dim=1))


class ST2CSPC(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(ST2CSPC, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(c_, c_, 1, 1)
        self.cv4 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformer2Block(c_, c_, num_heads, n)
        #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        y1 = self.cv3(self.m(self.cv1(x)))
        y2 = self.cv2(x)
        return self.cv4(torch.cat((y1, y2), dim=1))

##### end of swin transformer v2 #####
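
# --- illustrative sketch (not part of the original file) ---------------------
# The v2 blocks mirror the v1 ST-CSP stages but use cosine attention with a
# learned logit scale, an MLP-generated continuous relative position bias, and
# post-normalization, with a default window_size of 7 instead of 8. A minimal
# smoke test; the channel/spatial sizes are arbitrary assumptions chosen so 28
# is divisible by window_size=7:
def _demo_st2csp_blocks():
    x = torch.randn(1, 128, 28, 28)
    for blk in (ST2CSPA(128, 128, n=1), ST2CSPB(128, 128, n=1), ST2CSPC(128, 128, n=1)):
        assert blk(x).shape == (1, 128, 28, 28)   # channels and spatial size preserved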