用kafka接收消息
Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

303 lines
11KB

  1. import torch
  2. import torch.nn as nn
  3. from torch.nn import init
  4. import math
  class ConvX(nn.Module):
      """Conv2d -> BatchNorm2d -> ReLU unit used as the basic layer of the network.

      Args:
          in_planes: number of input channels.
          out_planes: number of output channels.
          kernel: square kernel size; padding is ``kernel // 2``.
          stride: convolution stride.
      """

      def __init__(self, in_planes, out_planes, kernel=3, stride=1):
          super().__init__()
          pad = kernel // 2
          # No conv bias: the BatchNorm that follows carries its own affine shift.
          self.conv = nn.Conv2d(
              in_planes,
              out_planes,
              kernel_size=kernel,
              stride=stride,
              padding=pad,
              bias=False,
          )
          self.bn = nn.BatchNorm2d(out_planes)
          self.relu = nn.ReLU(inplace=True)

      def forward(self, x):
          """Apply convolution, batch-norm and ReLU in sequence."""
          y = self.conv(x)
          y = self.bn(y)
          return self.relu(y)
  class AddBottleneck(nn.Module):
      """STDC bottleneck: concatenated ConvX outputs plus a residual shortcut.

      ``block_num`` ConvX layers run in sequence. Their widths are out_planes/2,
      out_planes/4, ..., and the last layer repeats the previous width. The
      concatenation therefore has ``out_planes`` channels, assuming out_planes is
      divisible by ``2 ** (block_num - 1)`` (not checked).

      With ``stride == 2`` the first ConvX output is downsampled by ``avd_layer``.
      The input is projected by ``skip`` to match before the addition. Otherwise
      the input is added unchanged, which requires ``in_planes == out_planes``
      (not checked).

      Args:
          in_planes: input channels.
          out_planes: output channels.
          block_num: number of ConvX layers; must be greater than 1.
          stride: 1 or 2 (2 downsamples spatially by 2).

      Raises:
          ValueError: if ``block_num <= 1``.
      """

      def __init__(self, in_planes, out_planes, block_num=3, stride=1):
          super().__init__()
          # Explicit check: the former ``assert ..., print(...)`` raised an
          # AssertionError with a None message and disappeared under ``python -O``.
          if block_num <= 1:
              raise ValueError("block number should be larger than 1.")
          self.conv_list = nn.ModuleList()
          self.stride = stride
          half = out_planes // 2
          if stride == 2:
              # Depthwise stride-2 conv applied to the first ConvX output.
              self.avd_layer = nn.Sequential(
                  nn.Conv2d(half, half, kernel_size=3, stride=2, padding=1, groups=half, bias=False),
                  nn.BatchNorm2d(half),
              )
              # Projection shortcut: depthwise stride-2 conv, then 1x1 to out_planes.
              self.skip = nn.Sequential(
                  nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=2, padding=1, groups=in_planes, bias=False),
                  nn.BatchNorm2d(in_planes),
                  nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False),
                  nn.BatchNorm2d(out_planes),
              )
              # Downsampling is done by avd_layer, so the ConvX layers stay at stride 1.
              stride = 1
          for idx in range(block_num):
              if idx == 0:
                  self.conv_list.append(ConvX(in_planes, half, kernel=1))
                  continue
              in_ch = out_planes // 2 ** idx
              # Each layer halves the width except the last, which keeps it.
              out_ch = in_ch if idx == block_num - 1 else out_planes // 2 ** (idx + 1)
              # Only the second ConvX receives ``stride``. It is 1 here unless the
              # caller passed a value other than 1 or 2.
              self.conv_list.append(ConvX(in_ch, out_ch, stride=stride if idx == 1 else 1))

      def forward(self, x):
          """Return ``cat(ConvX outputs, dim=1) + shortcut(x)``."""
          out_list = []
          out = x
          for idx, conv in enumerate(self.conv_list):
              out = conv(out)
              if idx == 0 and self.stride == 2:
                  out = self.avd_layer(out)
              out_list.append(out)
          if self.stride == 2:
              x = self.skip(x)
          return torch.cat(out_list, dim=1) + x
  class CatBottleneck(nn.Module):
      """STDC bottleneck whose output is the concatenation of all ConvX outputs.

      ``block_num`` ConvX layers run in sequence. Their widths are out_planes/2,
      out_planes/4, ..., and the last layer repeats the previous width. The
      concatenated output therefore has ``out_planes`` channels, assuming
      out_planes is divisible by ``2 ** (block_num - 1)`` (not checked). There is
      no residual addition.

      With ``stride == 2`` the first ConvX output is downsampled by ``avd_layer``
      before the second ConvX sees it. The copy that is concatenated is
      downsampled by a 3x3 stride-2 average pool (``skip``).

      Args:
          in_planes: input channels.
          out_planes: output channels.
          block_num: number of ConvX layers; must be greater than 1.
          stride: 1 or 2 (2 downsamples spatially by 2).

      Raises:
          ValueError: if ``block_num <= 1``.
      """

      def __init__(self, in_planes, out_planes, block_num=3, stride=1):
          super().__init__()
          # Explicit check: the former ``assert ..., print(...)`` raised an
          # AssertionError with a None message and disappeared under ``python -O``.
          if block_num <= 1:
              raise ValueError("block number should be larger than 1.")
          self.conv_list = nn.ModuleList()
          self.stride = stride
          half = out_planes // 2
          if stride == 2:
              # Depthwise stride-2 conv feeding the second ConvX.
              self.avd_layer = nn.Sequential(
                  nn.Conv2d(half, half, kernel_size=3, stride=2, padding=1, groups=half, bias=False),
                  nn.BatchNorm2d(half),
              )
              # Downsamples the first ConvX output before it is concatenated.
              self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
              # Downsampling is done by avd_layer/skip, so the ConvX layers stay at stride 1.
              stride = 1
          for idx in range(block_num):
              if idx == 0:
                  self.conv_list.append(ConvX(in_planes, half, kernel=1))
                  continue
              in_ch = out_planes // 2 ** idx
              # Each layer halves the width except the last, which keeps it.
              out_ch = in_ch if idx == block_num - 1 else out_planes // 2 ** (idx + 1)
              # Only the second ConvX receives ``stride``. It is 1 here unless the
              # caller passed a value other than 1 or 2.
              self.conv_list.append(ConvX(in_ch, out_ch, stride=stride if idx == 1 else 1))

      def forward(self, x):
          """Return ``cat([first ConvX output, later ConvX outputs...], dim=1)``."""
          out1 = self.conv_list[0](x)
          out = self.avd_layer(out1) if self.stride == 2 else out1
          out_list = []
          for conv in self.conv_list[1:]:
              out = conv(out)
              out_list.append(out)
          if self.stride == 2:
              out1 = self.skip(out1)
          return torch.cat([out1] + out_list, dim=1)
  96. #STDC2Net
  class STDCNet1446(nn.Module):
      """STDC backbone with three bottleneck stages (default ``layers=(4, 5, 3)``).

      ``forward`` returns five feature maps ``(feat2, feat4, feat8, feat16, feat32)``.
      They are taken after each of the two stride-2 stem convs and after each of
      the three stages, whose first block has stride 2. Channels are
      ``base // 2``, ``base``, ``base * 4``, ``base * 8`` and ``base * 16``. When
      ``use_conv_last`` is set, feat32 instead has ``max(1024, base * 16)``
      channels. ``forward_impl`` is the classification path built on ``features``.

      Args:
          base: channel width after the stem.
          layers: number of bottleneck blocks in each of the three stages.
          block_num: ConvX layers per bottleneck.
          type: "cat" (CatBottleneck) or "add" (AddBottleneck).
          num_classes: output size of the final ``linear`` layer.
          dropout: dropout probability before ``linear``.
          pretrain_model: checkpoint path whose ``"state_dict"`` is loaded. An
              empty string means random initialization via ``init_params``.
          use_conv_last: apply ``conv_last`` to feat32 in ``forward``.

      Raises:
          ValueError: if ``type`` is neither "cat" nor "add".
      """

      def __init__(self, base=64, layers=(4, 5, 3), block_num=4, type="cat",
                   num_classes=1000, dropout=0.20, pretrain_model='',
                   use_conv_last=False):
          super().__init__()
          if type == "cat":
              block = CatBottleneck
          elif type == "add":
              block = AddBottleneck
          else:
              # Previously fell through and crashed with UnboundLocalError on `block`.
              raise ValueError('type must be "cat" or "add", got {!r}'.format(type))
          self.use_conv_last = use_conv_last
          self.features = self._make_layers(base, layers, block_num, block)
          head_channels = max(1024, base * 16)
          # NOTE: base*16 equals the last stage's width only for three stages.
          self.conv_last = ConvX(base * 16, head_channels, 1, 1)
          self.gap = nn.AdaptiveAvgPool2d(1)
          self.fc = nn.Linear(head_channels, head_channels, bias=False)
          self.bn = nn.BatchNorm1d(head_channels)
          self.relu = nn.ReLU(inplace=True)
          self.dropout = nn.Dropout(p=dropout)
          self.linear = nn.Linear(head_channels, num_classes, bias=False)
          # The stage views share modules with self.features. Each slice is wrapped
          # as a single child (not unpacked), which keeps the existing state_dict
          # key layout. Bounds come from `layers` instead of being hard-coded for
          # the default configuration.
          stage0, stage1, stage2 = self._stage_bounds(layers)
          self.x2 = nn.Sequential(self.features[:1])
          self.x4 = nn.Sequential(self.features[1:stage0])
          self.x8 = nn.Sequential(self.features[stage0:stage1])
          self.x16 = nn.Sequential(self.features[stage1:stage2])
          self.x32 = nn.Sequential(self.features[stage2:])
          if pretrain_model:
              print('use pretrain model {}'.format(pretrain_model))
              self.init_weight(pretrain_model)
          else:
              self.init_params()

      @staticmethod
      def _stage_bounds(layers):
          """Return the ``features`` indices where stages 0, 1 and 2 begin.

          Indices 0 and 1 are the stem convs, so stage 0 starts at 2. Each later
          stage starts after the previous stage's ``layers[k]`` blocks.
          """
          counts = list(layers)
          first = 2
          return first, first + sum(counts[:1]), first + sum(counts[:2])

      def init_weight(self, pretrain_model):
          """Overlay ``torch.load(pretrain_model)["state_dict"]`` onto this model's weights.

          Checkpoint entries replace matching keys. Keys the model lacks are passed
          through as well, so the strict ``load_state_dict`` raises on them.
          """
          # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
          # load_state_dict copies the values into this module's existing tensors.
          # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
          state_dict = torch.load(pretrain_model, map_location="cpu")["state_dict"]
          self_state_dict = self.state_dict()
          self_state_dict.update(state_dict)
          self.load_state_dict(self_state_dict)

      def init_params(self):
          """Initialize weights of the registered modules.

          Conv2d layers get Kaiming-normal (fan_out) weights. BatchNorm2d layers
          get weight 1. Linear layers get normal weights with std 0.001. All
          biases present are set to 0.
          """
          for m in self.modules():
              if isinstance(m, nn.Conv2d):
                  init.kaiming_normal_(m.weight, mode='fan_out')
                  if m.bias is not None:
                      init.constant_(m.bias, 0)
              elif isinstance(m, nn.BatchNorm2d):
                  init.constant_(m.weight, 1)
                  init.constant_(m.bias, 0)
              elif isinstance(m, nn.Linear):
                  init.normal_(m.weight, std=0.001)
                  if m.bias is not None:
                      init.constant_(m.bias, 0)

      def _make_layers(self, base, layers, block_num, block):
          """Build the two stride-2 stem ConvX layers followed by the bottleneck stages.

          Stage ``i`` outputs ``base * 2 ** (i + 2)`` channels. Its first block
          downsamples with stride 2; the remaining blocks keep resolution.
          """
          features = [ConvX(3, base // 2, 3, 2), ConvX(base // 2, base, 3, 2)]
          for i, layer in enumerate(layers):
              out_ch = base * 2 ** (i + 2)
              for j in range(layer):
                  if j == 0:
                      # Stage 0 consumes the stem's `base` channels; later stages
                      # consume the previous stage's width.
                      in_ch = base if i == 0 else base * 2 ** (i + 1)
                      features.append(block(in_ch, out_ch, block_num, 2))
                  else:
                      features.append(block(out_ch, out_ch, block_num, 1))
          return nn.Sequential(*features)

      def forward(self, x):
          """Return the five multi-scale feature maps (feat2, feat4, feat8, feat16, feat32)."""
          feat2 = self.x2(x)
          feat4 = self.x4(feat2)
          feat8 = self.x8(feat4)
          feat16 = self.x16(feat8)
          feat32 = self.x32(feat16)
          if self.use_conv_last:
              feat32 = self.conv_last(feat32)
          return feat2, feat4, feat8, feat16, feat32

      def forward_impl(self, x):
          """Classification path: features, conv_last, square, GAP, fc, ReLU, dropout, linear."""
          out = self.features(x)
          out = self.conv_last(out).pow(2)
          out = self.gap(out).flatten(1)
          # self.bn is not applied on this path (a BN step here was disabled). It
          # stays registered so checkpoints containing bn.* keys still load strictly.
          out = self.relu(self.fc(out))
          out = self.dropout(out)
          return self.linear(out)
  175. # STDC1Net
  class STDCNet813(nn.Module):
      """STDC backbone with three bottleneck stages (default ``layers=(2, 2, 2)``).

      ``forward`` returns five feature maps ``(feat2, feat4, feat8, feat16, feat32)``.
      They are taken after each of the two stride-2 stem convs and after each of
      the three stages, whose first block has stride 2. Channels are
      ``base // 2``, ``base``, ``base * 4``, ``base * 8`` and ``base * 16``. When
      ``use_conv_last`` is set, feat32 instead has ``max(1024, base * 16)``
      channels. ``forward_impl`` is the classification path built on ``features``.

      Args:
          base: channel width after the stem.
          layers: number of bottleneck blocks in each of the three stages.
          block_num: ConvX layers per bottleneck.
          type: "cat" (CatBottleneck) or "add" (AddBottleneck).
          num_classes: output size of the final ``linear`` layer.
          dropout: dropout probability before ``linear``.
          pretrain_model: checkpoint path whose ``"state_dict"`` is loaded. An
              empty string means random initialization via ``init_params``.
          use_conv_last: apply ``conv_last`` to feat32 in ``forward``.

      Raises:
          ValueError: if ``type`` is neither "cat" nor "add".
      """

      def __init__(self, base=64, layers=(2, 2, 2), block_num=4, type="cat",
                   num_classes=1000, dropout=0.20, pretrain_model='',
                   use_conv_last=False):
          super().__init__()
          if type == "cat":
              block = CatBottleneck
          elif type == "add":
              block = AddBottleneck
          else:
              # Previously fell through and crashed with UnboundLocalError on `block`.
              raise ValueError('type must be "cat" or "add", got {!r}'.format(type))
          self.use_conv_last = use_conv_last
          self.features = self._make_layers(base, layers, block_num, block)
          head_channels = max(1024, base * 16)
          # NOTE: base*16 equals the last stage's width only for three stages.
          self.conv_last = ConvX(base * 16, head_channels, 1, 1)
          self.gap = nn.AdaptiveAvgPool2d(1)
          self.fc = nn.Linear(head_channels, head_channels, bias=False)
          self.bn = nn.BatchNorm1d(head_channels)
          self.relu = nn.ReLU(inplace=True)
          self.dropout = nn.Dropout(p=dropout)
          self.linear = nn.Linear(head_channels, num_classes, bias=False)
          # The stage views share modules with self.features. Each slice is wrapped
          # as a single child (not unpacked), which keeps the existing state_dict
          # key layout. Bounds come from `layers` instead of being hard-coded for
          # the default configuration.
          stage0, stage1, stage2 = self._stage_bounds(layers)
          self.x2 = nn.Sequential(self.features[:1])
          self.x4 = nn.Sequential(self.features[1:stage0])
          self.x8 = nn.Sequential(self.features[stage0:stage1])
          self.x16 = nn.Sequential(self.features[stage1:stage2])
          self.x32 = nn.Sequential(self.features[stage2:])
          if pretrain_model:
              print('use pretrain model {}'.format(pretrain_model))
              self.init_weight(pretrain_model)
          else:
              self.init_params()

      @staticmethod
      def _stage_bounds(layers):
          """Return the ``features`` indices where stages 0, 1 and 2 begin.

          Indices 0 and 1 are the stem convs, so stage 0 starts at 2. Each later
          stage starts after the previous stage's ``layers[k]`` blocks.
          """
          counts = list(layers)
          first = 2
          return first, first + sum(counts[:1]), first + sum(counts[:2])

      def init_weight(self, pretrain_model):
          """Overlay ``torch.load(pretrain_model)["state_dict"]`` onto this model's weights.

          Checkpoint entries replace matching keys. Keys the model lacks are passed
          through as well, so the strict ``load_state_dict`` raises on them.
          """
          # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
          # load_state_dict copies the values into this module's existing tensors.
          # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
          state_dict = torch.load(pretrain_model, map_location="cpu")["state_dict"]
          self_state_dict = self.state_dict()
          self_state_dict.update(state_dict)
          self.load_state_dict(self_state_dict)

      def init_params(self):
          """Initialize weights of the registered modules.

          Conv2d layers get Kaiming-normal (fan_out) weights. BatchNorm2d layers
          get weight 1. Linear layers get normal weights with std 0.001. All
          biases present are set to 0.
          """
          for m in self.modules():
              if isinstance(m, nn.Conv2d):
                  init.kaiming_normal_(m.weight, mode='fan_out')
                  if m.bias is not None:
                      init.constant_(m.bias, 0)
              elif isinstance(m, nn.BatchNorm2d):
                  init.constant_(m.weight, 1)
                  init.constant_(m.bias, 0)
              elif isinstance(m, nn.Linear):
                  init.normal_(m.weight, std=0.001)
                  if m.bias is not None:
                      init.constant_(m.bias, 0)

      def _make_layers(self, base, layers, block_num, block):
          """Build the two stride-2 stem ConvX layers followed by the bottleneck stages.

          Stage ``i`` outputs ``base * 2 ** (i + 2)`` channels. Its first block
          downsamples with stride 2; the remaining blocks keep resolution.
          """
          features = [ConvX(3, base // 2, 3, 2), ConvX(base // 2, base, 3, 2)]
          for i, layer in enumerate(layers):
              out_ch = base * 2 ** (i + 2)
              for j in range(layer):
                  if j == 0:
                      # Stage 0 consumes the stem's `base` channels; later stages
                      # consume the previous stage's width.
                      in_ch = base if i == 0 else base * 2 ** (i + 1)
                      features.append(block(in_ch, out_ch, block_num, 2))
                  else:
                      features.append(block(out_ch, out_ch, block_num, 1))
          return nn.Sequential(*features)

      def forward(self, x):
          """Return the five multi-scale feature maps (feat2, feat4, feat8, feat16, feat32)."""
          feat2 = self.x2(x)
          feat4 = self.x4(feat2)
          feat8 = self.x8(feat4)
          feat16 = self.x16(feat8)
          feat32 = self.x32(feat16)
          if self.use_conv_last:
              feat32 = self.conv_last(feat32)
          return feat2, feat4, feat8, feat16, feat32

      def forward_impl(self, x):
          """Classification path: features, conv_last, square, GAP, fc, ReLU, dropout, linear."""
          out = self.features(x)
          out = self.conv_last(out).pow(2)
          out = self.gap(out).flatten(1)
          # self.bn is not applied on this path (a BN step here was disabled). It
          # stays registered so checkpoints containing bn.* keys still load strictly.
          out = self.relu(self.fc(out))
          out = self.dropout(out)
          return self.linear(out)
  if __name__ == "__main__":
      # Smoke test: build STDCNet813, run one forward pass, and save its weights.
      model = STDCNet813(num_classes=1000, dropout=0.00, block_num=4)
      model.eval()
      x = torch.randn(1, 3, 224, 224)
      with torch.no_grad():
          feats = model(x)
      torch.save(model.state_dict(), 'cat.pth')
      # forward() returns a tuple of five feature maps, not a single tensor, so
      # the former `y.size()` raised AttributeError; report each shape instead.
      for feat in feats:
          print(feat.size())