You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

model_stages.py 14KB

1 year ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. #!/usr/bin/python
  2. # -*- encoding: utf-8 -*-
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. import torchvision
  7. import time
  8. from stdcnet import STDCNet1446, STDCNet813
  9. #from models_725.bn import InPlaceABNSync as BatchNorm2d
  10. # BatchNorm2d = nn.BatchNorm2d
# Fixed network input resolution, ordered (H, W): index 0 sizes the first
# dimension and index 1 the second dimension of the fixed average-pooling
# kernels computed from this tuple elsewhere in this file, and the message
# printed below labels it (H,W) as well.
modelSize=(360,640) ## (H,W)
# Printed once at import time as a reminder of the resolution in use.
print('######Attention model input(H,W):',modelSize)
  13. class ConvBNReLU(nn.Module):
  14. def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
  15. super(ConvBNReLU, self).__init__()
  16. self.conv = nn.Conv2d(in_chan,
  17. out_chan,
  18. kernel_size = ks,
  19. stride = stride,
  20. padding = padding,
  21. bias = False)
  22. # self.bn = BatchNorm2d(out_chan)
  23. # self.bn = BatchNorm2d(out_chan, activation='none')
  24. self.bn = nn.BatchNorm2d(out_chan)
  25. self.relu = nn.ReLU()
  26. self.init_weight()
  27. def forward(self, x):
  28. x = self.conv(x)
  29. x = self.bn(x)
  30. x = self.relu(x)
  31. return x
  32. def init_weight(self):
  33. for ly in self.children():
  34. if isinstance(ly, nn.Conv2d):
  35. nn.init.kaiming_normal_(ly.weight, a=1)
  36. if not ly.bias is None: nn.init.constant_(ly.bias, 0)
  37. class BiSeNetOutput(nn.Module):
  38. def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
  39. super(BiSeNetOutput, self).__init__()
  40. self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
  41. self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
  42. self.init_weight()
  43. def forward(self, x):
  44. x = self.conv(x)
  45. x = self.conv_out(x)
  46. return x
  47. def init_weight(self):
  48. for ly in self.children():
  49. if isinstance(ly, nn.Conv2d):
  50. nn.init.kaiming_normal_(ly.weight, a=1)
  51. if not ly.bias is None: nn.init.constant_(ly.bias, 0)
  52. def get_params(self):
  53. wd_params, nowd_params = [], []
  54. for name, module in self.named_modules():
  55. if isinstance(module, (nn.Linear, nn.Conv2d)):
  56. wd_params.append(module.weight)
  57. if not module.bias is None:
  58. nowd_params.append(module.bias)
  59. elif isinstance(module, nn.BatchNorm2d):######################1
  60. nowd_params += list(module.parameters())
  61. return wd_params, nowd_params
  62. class AttentionRefinementModule(nn.Module):
  63. def __init__(self, in_chan, out_chan,avg_pool2d_kernel_size, *args, **kwargs):
  64. super(AttentionRefinementModule, self).__init__()
  65. self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
  66. self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
  67. # self.bn_atten = nn.BatchNorm2d(out_chan)
  68. # self.bn_atten = BatchNorm2d(out_chan, activation='none')
  69. self.bn_atten = nn.BatchNorm2d(out_chan)########################2
  70. self.sigmoid_atten = nn.Sigmoid()
  71. self.avg_pool2d_kernel_size = avg_pool2d_kernel_size
  72. self.init_weight()
  73. def forward(self, x):
  74. feat = self.conv(x)
  75. #atten = F.avg_pool2d(feat, feat.size()[2:])
  76. atten = F.avg_pool2d(feat, self.avg_pool2d_kernel_size)
  77. #print('------------------newline86:','out:',atten.size(),'in:',feat.size())
  78. atten = self.conv_atten(atten)
  79. atten = self.bn_atten(atten)
  80. atten = self.sigmoid_atten(atten)
  81. out = torch.mul(feat, atten)
  82. return out
  83. def init_weight(self):
  84. for ly in self.children():
  85. if isinstance(ly, nn.Conv2d):
  86. nn.init.kaiming_normal_(ly.weight, a=1)
  87. if not ly.bias is None: nn.init.constant_(ly.bias, 0)
  88. class ContextPath(nn.Module):
  89. def __init__(self, backbone='CatNetSmall', pretrain_model='', use_conv_last=False, *args, **kwargs):
  90. super(ContextPath, self).__init__()
  91. self.backbone_name = backbone
  92. self.avg_pool_kernel_size_32=[ int(modelSize[0]/32+0.999), int( modelSize[1]/32+0.999 ) ]
  93. self.avg_pool_kernel_size_16=[ int(modelSize[0]/16+0.999), int( modelSize[1]/16+0.999 ) ]
  94. if backbone == 'STDCNet1446':
  95. self.backbone = STDCNet1446(pretrain_model=pretrain_model, use_conv_last=use_conv_last)
  96. self.arm16 = AttentionRefinementModule(512, 128)
  97. inplanes = 1024
  98. if use_conv_last:
  99. inplanes = 1024
  100. self.arm32 = AttentionRefinementModule(inplanes, 128)
  101. self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
  102. self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
  103. self.conv_avg = ConvBNReLU(inplanes, 128, ks=1, stride=1, padding=0)
  104. elif backbone == 'STDCNet813':
  105. self.backbone = STDCNet813(pretrain_model=pretrain_model, use_conv_last=use_conv_last)
  106. self.arm16 = AttentionRefinementModule(512, 128,self.avg_pool_kernel_size_16)
  107. inplanes = 1024
  108. if use_conv_last:
  109. inplanes = 1024
  110. self.arm32 = AttentionRefinementModule(inplanes, 128,self.avg_pool_kernel_size_32)
  111. self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
  112. self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
  113. self.conv_avg = ConvBNReLU(inplanes, 128, ks=1, stride=1, padding=0)
  114. else:
  115. print("backbone is not in backbone lists")
  116. exit(0)
  117. self.init_weight()
  118. def forward(self, x):
  119. H0, W0 = x.size()[2:]
  120. feat2, feat4, feat8, feat16, feat32 = self.backbone(x)
  121. H8, W8 = feat8.size()[2:]
  122. H16, W16 = feat16.size()[2:]
  123. H32, W32 = feat32.size()[2:]
  124. #avg = F.avg_pool2d(feat32, feat32.size()[2:])
  125. #print('line147:self.avg_pool_kernel_size_32:',self.avg_pool_kernel_size_32)
  126. avg = F.avg_pool2d(feat32, self.avg_pool_kernel_size_32)
  127. #print('------------------newline140:','out:','out;',avg.size(),' in:',feat32.size())
  128. avg = self.conv_avg(avg)
  129. avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
  130. #print('------------line143,arm32:',feat32.size())
  131. feat32_arm = self.arm32(feat32)
  132. feat32_sum = feat32_arm + avg_up
  133. feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
  134. feat32_up = self.conv_head32(feat32_up)
  135. #print('------------line148,arm16:',feat16.size())
  136. feat16_arm = self.arm16(feat16)
  137. feat16_sum = feat16_arm + feat32_up
  138. feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
  139. feat16_up = self.conv_head16(feat16_up)
  140. return feat2, feat4, feat8, feat16, feat16_up, feat32_up # x8, x16
  141. # return feat8, feat16_up # x8, x16
  142. def init_weight(self):
  143. for ly in self.children():
  144. if isinstance(ly, nn.Conv2d):
  145. nn.init.kaiming_normal_(ly.weight, a=1)
  146. if not ly.bias is None: nn.init.constant_(ly.bias, 0)
  147. def get_params(self):
  148. wd_params, nowd_params = [], []
  149. for name, module in self.named_modules():
  150. if isinstance(module, (nn.Linear, nn.Conv2d)):
  151. wd_params.append(module.weight)
  152. if not module.bias is None:
  153. nowd_params.append(module.bias)
  154. elif isinstance(module, nn.BatchNorm2d):#################3
  155. nowd_params += list(module.parameters())
  156. return wd_params, nowd_params
  157. class FeatureFusionModule(nn.Module):
  158. def __init__(self, in_chan, out_chan, *args, **kwargs):
  159. super(FeatureFusionModule, self).__init__()
  160. self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
  161. self.avg_pool_kernel_size=[ int(modelSize[0]/8+0.9999), int( modelSize[1]/8+0.9999 ) ]
  162. self.conv1 = nn.Conv2d(out_chan,
  163. out_chan//4,
  164. kernel_size = 1,
  165. stride = 1,
  166. padding = 0,
  167. bias = False)
  168. self.conv2 = nn.Conv2d(out_chan//4,
  169. out_chan,
  170. kernel_size = 1,
  171. stride = 1,
  172. padding = 0,
  173. bias = False)
  174. self.relu = nn.ReLU(inplace=True)
  175. self.sigmoid = nn.Sigmoid()
  176. self.init_weight()
  177. def forward(self, fsp, fcp):
  178. fcat = torch.cat([fsp, fcp], dim=1)
  179. feat = self.convblk(fcat)
  180. #atten = F.avg_pool2d(feat, feat.size()[2:])
  181. atten = F.avg_pool2d(feat, kernel_size=self.avg_pool_kernel_size)
  182. #print('------------------newline199:',' out:',atten.size(),'in:',feat.size())
  183. atten = self.conv1(atten)
  184. atten = self.relu(atten)
  185. atten = self.conv2(atten)
  186. atten = self.sigmoid(atten)
  187. feat_atten = torch.mul(feat, atten)
  188. feat_out = feat_atten + feat
  189. return feat_out
  190. def init_weight(self):
  191. for ly in self.children():
  192. if isinstance(ly, nn.Conv2d):
  193. nn.init.kaiming_normal_(ly.weight, a=1)
  194. if not ly.bias is None: nn.init.constant_(ly.bias, 0)
  195. def get_params(self):
  196. wd_params, nowd_params = [], []
  197. for name, module in self.named_modules():
  198. if isinstance(module, (nn.Linear, nn.Conv2d)):
  199. wd_params.append(module.weight)
  200. if not module.bias is None:
  201. nowd_params.append(module.bias)
  202. elif isinstance(module, nn.BatchNorm2d):##################4
  203. nowd_params += list(module.parameters())
  204. return wd_params, nowd_params
  205. class BiSeNet_STDC(nn.Module):
  206. def __init__(self, backbone, n_classes, pretrain_model='', use_boundary_2=False, use_boundary_4=False,
  207. use_boundary_8=False, use_boundary_16=False, use_conv_last=False):
  208. super(BiSeNet_STDC, self).__init__()
  209. self.use_boundary_2 = use_boundary_2
  210. self.use_boundary_4 = use_boundary_4
  211. self.use_boundary_8 = use_boundary_8
  212. self.use_boundary_16 = use_boundary_16
  213. # self.heat_map = heat_map
  214. self.cp = ContextPath(backbone, pretrain_model, use_conv_last=use_conv_last)
  215. if backbone == 'STDCNet1446':
  216. conv_out_inplanes = 128
  217. sp2_inplanes = 32
  218. sp4_inplanes = 64
  219. sp8_inplanes = 256
  220. sp16_inplanes = 512
  221. inplane = sp8_inplanes + conv_out_inplanes
  222. elif backbone == 'STDCNet813':
  223. conv_out_inplanes = 128
  224. sp2_inplanes = 32
  225. sp4_inplanes = 64
  226. sp8_inplanes = 256
  227. sp16_inplanes = 512
  228. inplane = sp8_inplanes + conv_out_inplanes
  229. else:
  230. print("backbone is not in backbone lists")
  231. exit(0)
  232. self.ffm = FeatureFusionModule(inplane, 256)
  233. self.conv_out = BiSeNetOutput(256, 256, n_classes)
  234. self.conv_out16 = BiSeNetOutput(conv_out_inplanes, 64, n_classes)
  235. self.conv_out32 = BiSeNetOutput(conv_out_inplanes, 64, n_classes)
  236. self.conv_out_sp16 = BiSeNetOutput(sp16_inplanes, 64, 1)
  237. self.conv_out_sp8 = BiSeNetOutput(sp8_inplanes, 64, 1)
  238. self.conv_out_sp4 = BiSeNetOutput(sp4_inplanes, 64, 1)
  239. self.conv_out_sp2 = BiSeNetOutput(sp2_inplanes, 64, 1)
  240. self.init_weight()
  241. def forward(self, x):
  242. H, W = x.size()[2:]
  243. # time_0 = time.time()
  244. # feat_res2, feat_res4, feat_res8, feat_res16, feat_cp8, feat_cp16 = self.cp(x)
  245. feat_res2, feat_res4, feat_res8, feat_res16, feat_cp8, feat_cp16 = self.cp(x)
  246. # print('----backbone', (time.time() - time_0) * 1000)
  247. # feat_out_sp2 = self.conv_out_sp2(feat_res2)
  248. #
  249. # feat_out_sp4 = self.conv_out_sp4(feat_res4)
  250. #
  251. # feat_out_sp8 = self.conv_out_sp8(feat_res8)
  252. #
  253. # feat_out_sp16 = self.conv_out_sp16(feat_res16)
  254. # time_1 = time.time()
  255. feat_fuse = self.ffm(feat_res8, feat_cp8)
  256. # print('----ffm', (time.time() - time_1) * 1000)
  257. # time_2 = time.time()
  258. feat_out = self.conv_out(feat_fuse)
  259. # feat_out16 = self.conv_out16(feat_cp8)
  260. # feat_out32 = self.conv_out32(feat_cp16)
  261. feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
  262. # print('----conv_out', (time.time() - time_2) * 1000)
  263. # feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
  264. # feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
  265. # if self.use_boundary_2 and self.use_boundary_4 and self.use_boundary_8:
  266. # return feat_out, feat_out16, feat_out32, feat_out_sp2, feat_out_sp4, feat_out_sp8
  267. #
  268. # if (not self.use_boundary_2) and self.use_boundary_4 and self.use_boundary_8:
  269. # return feat_out, feat_out16, feat_out32, feat_out_sp4, feat_out_sp8
  270. #
  271. # if (not self.use_boundary_2) and (not self.use_boundary_4) and self.use_boundary_8:
  272. return feat_out
  273. # if (not self.use_boundary_2) and (not self.use_boundary_4) and (not self.use_boundary_8):
  274. # return feat_out, feat_out16, feat_out32
  275. def init_weight(self):
  276. for ly in self.children():
  277. if isinstance(ly, nn.Conv2d):
  278. nn.init.kaiming_normal_(ly.weight, a=1)
  279. if not ly.bias is None: nn.init.constant_(ly.bias, 0)
  280. def get_params(self):
  281. wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
  282. for name, child in self.named_children():
  283. child_wd_params, child_nowd_params = child.get_params()
  284. if isinstance(child, (FeatureFusionModule, BiSeNetOutput)):
  285. lr_mul_wd_params += child_wd_params
  286. lr_mul_nowd_params += child_nowd_params
  287. else:
  288. wd_params += child_wd_params
  289. nowd_params += child_nowd_params
  290. return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
  291. if __name__ == "__main__":
  292. net = BiSeNet('STDCNet813', 19)
  293. net.cuda()
  294. net.eval()
  295. in_ten = torch.randn(1, 3, 768, 1536).cuda()
  296. out, out16, out32 = net(in_ten)
  297. print(out.shape)
  298. # torch.save(net.state_dict(), 'STDCNet813.pth')###