高速公路违停检测
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

335 lines
12KB

  1. #!/usr/bin/python
  2. # -*- encoding: utf-8 -*-
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. import torchvision
  7. from nets.stdcnet import STDCNet1446, STDCNet813
  8. from modules.bn import InPlaceABNSync as BatchNorm2d
  9. # BatchNorm2d = nn.BatchNorm2d
  class ConvBNReLU(nn.Module):
      """Conv2d (bias-free) -> BatchNorm2d -> ReLU building block.

      Args:
          in_chan: number of input channels.
          out_chan: number of output channels.
          ks: convolution kernel size.
          stride: convolution stride.
          padding: zero padding applied by the convolution.
          *args, **kwargs: accepted for call-site compatibility and ignored.
      """

      def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
          super(ConvBNReLU, self).__init__()
          self.conv = nn.Conv2d(
              in_chan, out_chan,
              kernel_size=ks, stride=stride, padding=padding, bias=False,
          )
          # Plain nn.BatchNorm2d is used here, not the InPlaceABNSync alias
          # imported at the top of the file.
          self.bn = nn.BatchNorm2d(out_chan)
          self.relu = nn.ReLU()
          self.init_weight()

      def forward(self, x):
          return self.relu(self.bn(self.conv(x)))

      def init_weight(self):
          """Kaiming-normal init (a=1) for direct Conv2d children; zero any bias."""
          convs = (child for child in self.children() if isinstance(child, nn.Conv2d))
          for conv in convs:
              nn.init.kaiming_normal_(conv.weight, a=1)
              if conv.bias is not None:
                  nn.init.constant_(conv.bias, 0)
  class BiSeNetOutput(nn.Module):
      """Prediction head: 3x3 ConvBNReLU followed by a bias-free 1x1 conv.

      Args:
          in_chan: channels of the incoming feature map.
          mid_chan: channels produced by the intermediate ConvBNReLU.
          n_classes: number of output channels of the final 1x1 conv.
          *args, **kwargs: accepted for call-site compatibility and ignored.
      """

      def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
          super(BiSeNetOutput, self).__init__()
          self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
          self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
          self.init_weight()

      def forward(self, x):
          hidden = self.conv(x)
          return self.conv_out(hidden)

      def init_weight(self):
          """Kaiming-normal init (a=1) for direct Conv2d children; zero any bias."""
          for child in self.children():
              if not isinstance(child, nn.Conv2d):
                  continue
              nn.init.kaiming_normal_(child.weight, a=1)
              if child.bias is not None:
                  nn.init.constant_(child.bias, 0)

      def get_params(self):
          """Split parameters into weight-decay / no-weight-decay groups.

          Returns:
              (wd_params, nowd_params): weights of every Linear/Conv2d submodule
              go in the first list; their biases and all BatchNorm2d parameters
              go in the second. Parameters owned by other module types are not
              collected.
          """
          decay, no_decay = [], []
          for module in self.modules():
              if isinstance(module, (nn.Linear, nn.Conv2d)):
                  decay.append(module.weight)
                  if module.bias is not None:
                      no_decay.append(module.bias)
              elif isinstance(module, nn.BatchNorm2d):
                  no_decay.extend(module.parameters())
          return decay, no_decay
  class AttentionRefinementModule(nn.Module):
      """Channel attention applied to a ConvBNReLU projection of the input.

      The input is projected to ``out_chan`` channels, globally average-pooled,
      passed through a bias-free 1x1 conv, BatchNorm2d and a sigmoid, and the
      resulting per-channel weights rescale the projected features.

      Args:
          in_chan: channels of the input feature map.
          out_chan: channels of the projected / returned feature map.
          *args, **kwargs: accepted for call-site compatibility and ignored.
      """

      def __init__(self, in_chan, out_chan, *args, **kwargs):
          super(AttentionRefinementModule, self).__init__()
          self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
          self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
          self.bn_atten = nn.BatchNorm2d(out_chan)
          self.sigmoid_atten = nn.Sigmoid()
          self.init_weight()

      def forward(self, x):
          feat = self.conv(x)
          # Pool over the full spatial extent -> one value per channel.
          pooled = F.avg_pool2d(feat, feat.size()[2:])
          weights = self.sigmoid_atten(self.bn_atten(self.conv_atten(pooled)))
          return torch.mul(feat, weights)

      def init_weight(self):
          """Kaiming-normal init (a=1) for direct Conv2d children; zero any bias."""
          for child in self.children():
              if not isinstance(child, nn.Conv2d):
                  continue
              nn.init.kaiming_normal_(child.weight, a=1)
              if child.bias is not None:
                  nn.init.constant_(child.bias, 0)
  class ContextPath(nn.Module):
      """STDC backbone plus attention-refined context path.

      Args:
          backbone: 'STDCNet1446' or 'STDCNet813'. The default value is not a
              supported backbone; it is kept only for signature compatibility.
          pretrain_model: forwarded to the backbone constructor.
          use_conv_last: forwarded to the backbone constructor.
          *args, **kwargs: accepted for call-site compatibility and ignored.

      Raises:
          ValueError: if ``backbone`` is not a supported name.
      """

      # Both supported backbones are wired with the same channel widths below.
      _SUPPORTED_BACKBONES = ('STDCNet1446', 'STDCNet813')

      def __init__(self, backbone='CatNetSmall', pretrain_model='', use_conv_last=False, *args, **kwargs):
          super(ContextPath, self).__init__()
          if backbone not in self._SUPPORTED_BACKBONES:
              # Raise instead of print + exit(0): exiting ended the whole process
              # with a *success* status on a configuration error.
              raise ValueError(
                  "backbone is not in backbone lists: %r (expected one of %s)"
                  % (backbone, ", ".join(self._SUPPORTED_BACKBONES)))
          self.backbone_name = backbone
          backbone_cls = STDCNet1446 if backbone == 'STDCNet1446' else STDCNet813
          self.backbone = backbone_cls(pretrain_model=pretrain_model, use_conv_last=use_conv_last)

          # Input channels expected for feat32 (arm32, conv_avg) and feat16
          # (arm16). The original set 1024 regardless of use_conv_last.
          # Registration order is kept so state_dict keys / init are unchanged.
          inplanes = 1024
          self.arm16 = AttentionRefinementModule(512, 128)
          self.arm32 = AttentionRefinementModule(inplanes, 128)
          self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
          self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
          self.conv_avg = ConvBNReLU(inplanes, 128, ks=1, stride=1, padding=0)
          self.init_weight()

      def forward(self, x):
          """Run the backbone and fuse its two deepest feature maps.

          Returns:
              (feat2, feat4, feat8, feat16, feat16_up, feat32_up). The first four
              are backbone outputs passed through unchanged. ``feat32_up`` is the
              refined feat32 (+ global context) resized to feat16's spatial size.
              ``feat16_up`` is the refined feat16 + feat32_up resized to feat8's
              spatial size.
          """
          feat2, feat4, feat8, feat16, feat32 = self.backbone(x)
          H8, W8 = feat8.size()[2:]
          H16, W16 = feat16.size()[2:]
          H32, W32 = feat32.size()[2:]

          # Global context: pool feat32 to 1x1, project, broadcast back to feat32's size.
          avg = F.avg_pool2d(feat32, feat32.size()[2:])
          avg = self.conv_avg(avg)
          avg_up = F.interpolate(avg, (H32, W32), mode='nearest')

          feat32_arm = self.arm32(feat32)
          feat32_sum = feat32_arm + avg_up
          feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
          feat32_up = self.conv_head32(feat32_up)

          feat16_arm = self.arm16(feat16)
          feat16_sum = feat16_arm + feat32_up
          feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
          feat16_up = self.conv_head16(feat16_up)

          return feat2, feat4, feat8, feat16, feat16_up, feat32_up

      def init_weight(self):
          """Kaiming-normal init (a=1) for direct Conv2d children; zero any bias."""
          for ly in self.children():
              if isinstance(ly, nn.Conv2d):
                  nn.init.kaiming_normal_(ly.weight, a=1)
                  if ly.bias is not None:
                      nn.init.constant_(ly.bias, 0)

      def get_params(self):
          """Split parameters (backbone included) into weight-decay groups.

          Returns:
              (wd_params, nowd_params): Linear/Conv2d weights go in the first list;
              their biases and all nn.BatchNorm2d parameters go in the second.
          """
          wd_params, nowd_params = [], []
          for module in self.modules():
              if isinstance(module, (nn.Linear, nn.Conv2d)):
                  wd_params.append(module.weight)
                  if module.bias is not None:
                      nowd_params.append(module.bias)
              elif isinstance(module, nn.BatchNorm2d):
                  nowd_params += list(module.parameters())
          return wd_params, nowd_params
  class FeatureFusionModule(nn.Module):
      """Fuse two feature maps by concatenation plus squeeze-style channel attention.

      The inputs are concatenated on the channel axis and projected by a 1x1
      ConvBNReLU. The result is re-weighted by a sigmoid gate computed from its
      global average through a conv(out_chan -> out_chan // 4) / ReLU /
      conv(-> out_chan) bottleneck. The gated features are added back to the
      ungated ones.

      Args:
          in_chan: total channels of the two inputs after concatenation.
          out_chan: channels of the fused output.
          *args, **kwargs: accepted for call-site compatibility and ignored.
      """

      def __init__(self, in_chan, out_chan, *args, **kwargs):
          super(FeatureFusionModule, self).__init__()
          squeezed = out_chan // 4
          self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
          self.conv1 = nn.Conv2d(out_chan, squeezed, kernel_size=1, stride=1, padding=0, bias=False)
          self.conv2 = nn.Conv2d(squeezed, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
          self.relu = nn.ReLU(inplace=True)
          self.sigmoid = nn.Sigmoid()
          self.init_weight()

      def forward(self, fsp, fcp):
          fused = self.convblk(torch.cat([fsp, fcp], dim=1))
          gate = F.avg_pool2d(fused, fused.size()[2:])
          gate = self.sigmoid(self.conv2(self.relu(self.conv1(gate))))
          return torch.mul(fused, gate) + fused

      def init_weight(self):
          """Kaiming-normal init (a=1) for direct Conv2d children; zero any bias."""
          for child in self.children():
              if not isinstance(child, nn.Conv2d):
                  continue
              nn.init.kaiming_normal_(child.weight, a=1)
              if child.bias is not None:
                  nn.init.constant_(child.bias, 0)

      def get_params(self):
          """Split parameters into weight-decay / no-weight-decay groups.

          Returns:
              (wd_params, nowd_params): Linear/Conv2d weights go in the first list;
              their biases and all nn.BatchNorm2d parameters go in the second.
          """
          decay, no_decay = [], []
          for module in self.modules():
              if isinstance(module, (nn.Linear, nn.Conv2d)):
                  decay.append(module.weight)
                  if module.bias is not None:
                      no_decay.append(module.bias)
              elif isinstance(module, nn.BatchNorm2d):
                  no_decay.extend(module.parameters())
          return decay, no_decay
  class BiSeNet(nn.Module):
      """STDC-backed BiSeNet segmentation network with optional auxiliary heads.

      Args:
          backbone: 'STDCNet1446' or 'STDCNet813'.
          n_classes: output channels of the three segmentation heads.
          pretrain_model: forwarded to ContextPath (and from there the backbone).
          use_boundary_2, use_boundary_4, use_boundary_8: when True, ``forward``
              also returns the single-channel output of the sp2 / sp4 / sp8 head,
              computed from the corresponding backbone feature map.
          use_boundary_16: stored on the instance; ``forward`` does not use it.
          use_conv_last: forwarded to ContextPath.
          heat_map: accepted for compatibility; unused.
          *args, **kwargs: accepted for call-site compatibility and ignored.

      Raises:
          ValueError: if ``backbone`` is not a supported name.
      """

      _SUPPORTED_BACKBONES = ('STDCNet1446', 'STDCNet813')

      def __init__(self, backbone, n_classes, pretrain_model='', use_boundary_2=False,
                   use_boundary_4=False, use_boundary_8=False, use_boundary_16=False,
                   use_conv_last=False, heat_map=False, *args, **kwargs):
          super(BiSeNet, self).__init__()
          # Validate before building anything: fail with an exception instead of
          # printing and terminating the process with exit status 0.
          if backbone not in self._SUPPORTED_BACKBONES:
              raise ValueError(
                  "backbone is not in backbone lists: %r (expected one of %s)"
                  % (backbone, ", ".join(self._SUPPORTED_BACKBONES)))
          self.use_boundary_2 = use_boundary_2
          self.use_boundary_4 = use_boundary_4
          self.use_boundary_8 = use_boundary_8
          self.use_boundary_16 = use_boundary_16
          self.cp = ContextPath(backbone, pretrain_model, use_conv_last=use_conv_last)

          # Channel widths; identical for both supported backbones.
          conv_out_inplanes = 128
          sp2_inplanes = 32
          sp4_inplanes = 64
          sp8_inplanes = 256
          sp16_inplanes = 512
          # FFM consumes cat(feat_res8, feat_cp8) along channels.
          inplane = sp8_inplanes + conv_out_inplanes

          self.ffm = FeatureFusionModule(inplane, 256)
          self.conv_out = BiSeNetOutput(256, 256, n_classes)
          self.conv_out16 = BiSeNetOutput(conv_out_inplanes, 64, n_classes)
          self.conv_out32 = BiSeNetOutput(conv_out_inplanes, 64, n_classes)
          self.conv_out_sp16 = BiSeNetOutput(sp16_inplanes, 64, 1)
          self.conv_out_sp8 = BiSeNetOutput(sp8_inplanes, 64, 1)
          self.conv_out_sp4 = BiSeNetOutput(sp4_inplanes, 64, 1)
          self.conv_out_sp2 = BiSeNetOutput(sp2_inplanes, 64, 1)
          self.init_weight()

      def forward(self, x):
          """Return segmentation logits plus the requested auxiliary head outputs.

          Returns:
              ``(feat_out, feat_out16, feat_out32, *extra)``. The first three are
              bilinearly resized (align_corners=True) to the input's spatial size.
              ``extra`` holds, in this order, the sp2, sp4 and sp8 head outputs
              for whichever of ``use_boundary_2/4/8`` is set. These are left at
              their feature map's own resolution.
          """
          H, W = x.size()[2:]
          feat_res2, feat_res4, feat_res8, _feat_res16, feat_cp8, feat_cp16 = self.cp(x)

          feat_fuse = self.ffm(feat_res8, feat_cp8)
          logits = (
              self.conv_out(feat_fuse),
              self.conv_out16(feat_cp8),
              self.conv_out32(feat_cp16),
          )
          outputs = [F.interpolate(t, (H, W), mode='bilinear', align_corners=True) for t in logits]

          # Built generically: the old if-chain returned None for any flag
          # combination it did not list (e.g. only use_boundary_2). Heads are
          # evaluated only when returned; conv_out_sp16 is never returned and
          # is kept only so existing state_dicts still load.
          optional_heads = (
              (self.use_boundary_2, self.conv_out_sp2, feat_res2),
              (self.use_boundary_4, self.conv_out_sp4, feat_res4),
              (self.use_boundary_8, self.conv_out_sp8, feat_res8),
          )
          for enabled, head, feat in optional_heads:
              if enabled:
                  outputs.append(head(feat))
          return tuple(outputs)

      def init_weight(self):
          """Kaiming-normal init (a=1) for direct Conv2d children; zero any bias."""
          for ly in self.children():
              if isinstance(ly, nn.Conv2d):
                  nn.init.kaiming_normal_(ly.weight, a=1)
                  if ly.bias is not None:
                      nn.init.constant_(ly.bias, 0)

      def get_params(self):
          """Collect optimizer parameter groups from every direct child.

          Returns:
              (wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params).
              Parameters of FeatureFusionModule / BiSeNetOutput children land in
              the ``lr_mul_*`` lists; all other children (the context path) land
              in the plain lists. Each child must provide ``get_params()``.
          """
          wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
          for name, child in self.named_children():
              child_wd_params, child_nowd_params = child.get_params()
              if isinstance(child, (FeatureFusionModule, BiSeNetOutput)):
                  lr_mul_wd_params += child_wd_params
                  lr_mul_nowd_params += child_nowd_params
              else:
                  wd_params += child_wd_params
                  nowd_params += child_nowd_params
          return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
  if __name__ == "__main__":
      # Smoke test: build a 3-class STDCNet813 model (the original example used
      # 19 classes) and print the main output's shape for one random input.
      # Fall back to CPU so the demo also runs on machines without CUDA.
      device = 'cuda' if torch.cuda.is_available() else 'cpu'
      net = BiSeNet('STDCNet813', 3)
      net.to(device)
      net.eval()
      in_ten = torch.randn(1, 3, 768, 1536, device=device)
      # Inference only: skip building the autograd graph.
      with torch.no_grad():
          out, out16, out32 = net(in_ten)
      print(out.shape)