用kafka接收消息
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

244 lines
9.0KB

  1. """Efficient Neural Network"""
  2. import torch
  3. import torch.nn as nn
  4. __all__ = ['ENet', 'get_enet', 'get_enet_citys']
  5. class ENet(nn.Module):
  6. """Efficient Neural Network"""
  7. def __init__(self, nclass, backbone='', aux=False, jpu=False, pretrained_base=None, **kwargs):
  8. super(ENet, self).__init__()
  9. self.initial = InitialBlock(13, **kwargs)
  10. self.bottleneck1_0 = Bottleneck(16, 16, 64, downsampling=True, **kwargs)
  11. self.bottleneck1_1 = Bottleneck(64, 16, 64, **kwargs)
  12. self.bottleneck1_2 = Bottleneck(64, 16, 64, **kwargs)
  13. self.bottleneck1_3 = Bottleneck(64, 16, 64, **kwargs)
  14. self.bottleneck1_4 = Bottleneck(64, 16, 64, **kwargs)
  15. self.bottleneck2_0 = Bottleneck(64, 32, 128, downsampling=True, **kwargs)
  16. self.bottleneck2_1 = Bottleneck(128, 32, 128, **kwargs)
  17. self.bottleneck2_2 = Bottleneck(128, 32, 128, dilation=2, **kwargs)
  18. self.bottleneck2_3 = Bottleneck(128, 32, 128, asymmetric=True, **kwargs)
  19. self.bottleneck2_4 = Bottleneck(128, 32, 128, dilation=4, **kwargs)
  20. self.bottleneck2_5 = Bottleneck(128, 32, 128, **kwargs)
  21. self.bottleneck2_6 = Bottleneck(128, 32, 128, dilation=8, **kwargs)
  22. self.bottleneck2_7 = Bottleneck(128, 32, 128, asymmetric=True, **kwargs)
  23. self.bottleneck2_8 = Bottleneck(128, 32, 128, dilation=16, **kwargs)
  24. self.bottleneck3_1 = Bottleneck(128, 32, 128, **kwargs)
  25. self.bottleneck3_2 = Bottleneck(128, 32, 128, dilation=2, **kwargs)
  26. self.bottleneck3_3 = Bottleneck(128, 32, 128, asymmetric=True, **kwargs)
  27. self.bottleneck3_4 = Bottleneck(128, 32, 128, dilation=4, **kwargs)
  28. self.bottleneck3_5 = Bottleneck(128, 32, 128, **kwargs)
  29. self.bottleneck3_6 = Bottleneck(128, 32, 128, dilation=8, **kwargs)
  30. self.bottleneck3_7 = Bottleneck(128, 32, 128, asymmetric=True, **kwargs)
  31. self.bottleneck3_8 = Bottleneck(128, 32, 128, dilation=16, **kwargs)
  32. self.bottleneck4_0 = UpsamplingBottleneck(128, 16, 64, **kwargs)
  33. self.bottleneck4_1 = Bottleneck(64, 16, 64, **kwargs)
  34. self.bottleneck4_2 = Bottleneck(64, 16, 64, **kwargs)
  35. self.bottleneck5_0 = UpsamplingBottleneck(64, 4, 16, **kwargs)
  36. self.bottleneck5_1 = Bottleneck(16, 4, 16, **kwargs)
  37. self.fullconv = nn.ConvTranspose2d(16, nclass, 2, 2, bias=False)
  38. self.__setattr__('exclusive', ['bottleneck1_0', 'bottleneck1_1', 'bottleneck1_2', 'bottleneck1_3',
  39. 'bottleneck1_4', 'bottleneck2_0', 'bottleneck2_1', 'bottleneck2_2',
  40. 'bottleneck2_3', 'bottleneck2_4', 'bottleneck2_5', 'bottleneck2_6',
  41. 'bottleneck2_7', 'bottleneck2_8', 'bottleneck3_1', 'bottleneck3_2',
  42. 'bottleneck3_3', 'bottleneck3_4', 'bottleneck3_5', 'bottleneck3_6',
  43. 'bottleneck3_7', 'bottleneck3_8', 'bottleneck4_0', 'bottleneck4_1',
  44. 'bottleneck4_2', 'bottleneck5_0', 'bottleneck5_1', 'fullconv'])
  45. def forward(self, x):
  46. # init
  47. x = self.initial(x)
  48. # stage 1
  49. x, max_indices1 = self.bottleneck1_0(x)
  50. x = self.bottleneck1_1(x)
  51. x = self.bottleneck1_2(x)
  52. x = self.bottleneck1_3(x)
  53. x = self.bottleneck1_4(x)
  54. # stage 2
  55. x, max_indices2 = self.bottleneck2_0(x)
  56. x = self.bottleneck2_1(x)
  57. x = self.bottleneck2_2(x)
  58. x = self.bottleneck2_3(x)
  59. x = self.bottleneck2_4(x)
  60. x = self.bottleneck2_5(x)
  61. x = self.bottleneck2_6(x)
  62. x = self.bottleneck2_7(x)
  63. x = self.bottleneck2_8(x)
  64. # stage 3
  65. x = self.bottleneck3_1(x)
  66. x = self.bottleneck3_2(x)
  67. x = self.bottleneck3_3(x)
  68. x = self.bottleneck3_4(x)
  69. x = self.bottleneck3_6(x)
  70. x = self.bottleneck3_7(x)
  71. x = self.bottleneck3_8(x)
  72. # stage 4
  73. x = self.bottleneck4_0(x, max_indices2)
  74. x = self.bottleneck4_1(x)
  75. x = self.bottleneck4_2(x)
  76. # stage 5
  77. x = self.bottleneck5_0(x, max_indices1)
  78. x = self.bottleneck5_1(x)
  79. # out
  80. x = self.fullconv(x)
  81. return tuple([x])
  82. class InitialBlock(nn.Module):
  83. """ENet initial block"""
  84. def __init__(self, out_channels, norm_layer=nn.BatchNorm2d, **kwargs):
  85. super(InitialBlock, self).__init__()
  86. self.conv = nn.Conv2d(3, out_channels, 3, 2, 1, bias=False)
  87. self.maxpool = nn.MaxPool2d(2, 2)
  88. self.bn = norm_layer(out_channels + 3)
  89. self.act = nn.PReLU()
  90. def forward(self, x):
  91. x_conv = self.conv(x)
  92. x_pool = self.maxpool(x)
  93. x = torch.cat([x_conv, x_pool], dim=1)
  94. x = self.bn(x)
  95. x = self.act(x)
  96. return x
  97. class Bottleneck(nn.Module):
  98. """Bottlenecks include regular, asymmetric, downsampling, dilated"""
  99. def __init__(self, in_channels, inter_channels, out_channels, dilation=1, asymmetric=False,
  100. downsampling=False, norm_layer=nn.BatchNorm2d, **kwargs):
  101. super(Bottleneck, self).__init__()
  102. self.downsamping = downsampling
  103. if downsampling:
  104. self.maxpool = nn.MaxPool2d(2, 2, return_indices=True)
  105. self.conv_down = nn.Sequential(
  106. nn.Conv2d(in_channels, out_channels, 1, bias=False),
  107. norm_layer(out_channels)
  108. )
  109. self.conv1 = nn.Sequential(
  110. nn.Conv2d(in_channels, inter_channels, 1, bias=False),
  111. norm_layer(inter_channels),
  112. nn.PReLU()
  113. )
  114. if downsampling:
  115. self.conv2 = nn.Sequential(
  116. nn.Conv2d(inter_channels, inter_channels, 2, stride=2, bias=False),
  117. norm_layer(inter_channels),
  118. nn.PReLU()
  119. )
  120. else:
  121. if asymmetric:
  122. self.conv2 = nn.Sequential(
  123. nn.Conv2d(inter_channels, inter_channels, (5, 1), padding=(2, 0), bias=False),
  124. nn.Conv2d(inter_channels, inter_channels, (1, 5), padding=(0, 2), bias=False),
  125. norm_layer(inter_channels),
  126. nn.PReLU()
  127. )
  128. else:
  129. self.conv2 = nn.Sequential(
  130. nn.Conv2d(inter_channels, inter_channels, 3, dilation=dilation, padding=dilation, bias=False),
  131. norm_layer(inter_channels),
  132. nn.PReLU()
  133. )
  134. self.conv3 = nn.Sequential(
  135. nn.Conv2d(inter_channels, out_channels, 1, bias=False),
  136. norm_layer(out_channels),
  137. nn.Dropout2d(0.1)
  138. )
  139. self.act = nn.PReLU()
  140. def forward(self, x):
  141. identity = x
  142. if self.downsamping:
  143. identity, max_indices = self.maxpool(identity)
  144. identity = self.conv_down(identity)
  145. out = self.conv1(x)
  146. out = self.conv2(out)
  147. out = self.conv3(out)
  148. out = self.act(out + identity)
  149. if self.downsamping:
  150. return out, max_indices
  151. else:
  152. return out
  153. class UpsamplingBottleneck(nn.Module):
  154. """upsampling Block"""
  155. def __init__(self, in_channels, inter_channels, out_channels, norm_layer=nn.BatchNorm2d, **kwargs):
  156. super(UpsamplingBottleneck, self).__init__()
  157. self.conv = nn.Sequential(
  158. nn.Conv2d(in_channels, out_channels, 1, bias=False),
  159. norm_layer(out_channels)
  160. )
  161. self.upsampling = nn.MaxUnpool2d(2)
  162. self.block = nn.Sequential(
  163. nn.Conv2d(in_channels, inter_channels, 1, bias=False),
  164. norm_layer(inter_channels),
  165. nn.PReLU(),
  166. nn.ConvTranspose2d(inter_channels, inter_channels, 2, 2, bias=False),
  167. norm_layer(inter_channels),
  168. nn.PReLU(),
  169. nn.Conv2d(inter_channels, out_channels, 1, bias=False),
  170. norm_layer(out_channels),
  171. nn.Dropout2d(0.1)
  172. )
  173. self.act = nn.PReLU()
  174. def forward(self, x, max_indices):
  175. out_up = self.conv(x)
  176. out_up = self.upsampling(out_up, max_indices)
  177. out_ext = self.block(x)
  178. out = self.act(out_up + out_ext)
  179. return out
  180. def get_enet(dataset='citys', backbone='', pretrained=False, root='~/.torch/models', pretrained_base=True, **kwargs):
  181. acronyms = {
  182. 'pascal_voc': 'pascal_voc',
  183. 'pascal_aug': 'pascal_aug',
  184. 'ade20k': 'ade',
  185. 'coco': 'coco',
  186. 'citys': 'citys',
  187. }
  188. from core.data.dataloader import datasets
  189. model = ENet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
  190. if pretrained:
  191. from .model_store import get_model_file
  192. device = torch.device(kwargs['local_rank'])
  193. model.load_state_dict(torch.load(get_model_file('enet_%s' % (acronyms[dataset]), root=root),
  194. map_location=device))
  195. return model
  196. def get_enet_citys(**kwargs):
  197. return get_enet('citys', '', **kwargs)
  198. if __name__ == '__main__':
  199. img = torch.randn(1, 3, 512, 512)
  200. model = get_enet_citys()
  201. output = model(img)