用kafka接收消息
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

dinknet.py 12KB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
"""
Codes of LinkNet based on https://github.com/snakers4/spacenet-three

The LinkNet34 and D-LinkNet ("DinkNet") models defined here share one layout:
a torchvision ResNet encoder, an optional dilated-convolution centre block
(Dblock / Dblock_more_dilate), and DecoderBlock stages joined to the encoder
outputs by additive skip connections.
"""
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import models
import torch.nn.functional as F
from functools import partial
# Activation used by the blocks in this module: ReLU with inplace=True, i.e. it
# overwrites the tensor it is given instead of allocating a new output tensor.
nonlinearity = partial(F.relu,inplace=True)
  11. class Dblock_more_dilate(nn.Module):
  12. def __init__(self,channel):
  13. super(Dblock_more_dilate, self).__init__()
  14. self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
  15. self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2)
  16. self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4)
  17. self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8)
  18. self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16)
  19. for m in self.modules():
  20. if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
  21. if m.bias is not None:
  22. m.bias.data.zero_()
  23. def forward(self, x):
  24. dilate1_out = nonlinearity(self.dilate1(x))
  25. dilate2_out = nonlinearity(self.dilate2(dilate1_out))
  26. dilate3_out = nonlinearity(self.dilate3(dilate2_out))
  27. dilate4_out = nonlinearity(self.dilate4(dilate3_out))
  28. dilate5_out = nonlinearity(self.dilate5(dilate4_out))
  29. out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out + dilate5_out
  30. return out
  31. class Dblock(nn.Module):
  32. def __init__(self,channel):
  33. super(Dblock, self).__init__()
  34. self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
  35. self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=2, padding=2)
  36. self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=4, padding=4)
  37. self.dilate4 = nn.Conv2d(channel, channel, kernel_size=3, dilation=8, padding=8)
  38. #self.dilate5 = nn.Conv2d(channel, channel, kernel_size=3, dilation=16, padding=16)
  39. for m in self.modules():
  40. if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
  41. if m.bias is not None:
  42. m.bias.data.zero_()
  43. def forward(self, x):
  44. dilate1_out = nonlinearity(self.dilate1(x))
  45. dilate2_out = nonlinearity(self.dilate2(dilate1_out))
  46. dilate3_out = nonlinearity(self.dilate3(dilate2_out))
  47. dilate4_out = nonlinearity(self.dilate4(dilate3_out))
  48. #dilate5_out = nonlinearity(self.dilate5(dilate4_out))
  49. out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out# + dilate5_out
  50. return out
  51. class DecoderBlock(nn.Module):
  52. def __init__(self, in_channels, n_filters):
  53. super(DecoderBlock,self).__init__()
  54. self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1)
  55. self.norm1 = nn.BatchNorm2d(in_channels // 4)
  56. self.relu1 = nonlinearity
  57. self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 3, stride=2, padding=1, output_padding=1)
  58. self.norm2 = nn.BatchNorm2d(in_channels // 4)
  59. self.relu2 = nonlinearity
  60. self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1)
  61. self.norm3 = nn.BatchNorm2d(n_filters)
  62. self.relu3 = nonlinearity
  63. def forward(self, x):
  64. x = self.conv1(x)
  65. x = self.norm1(x)
  66. x = self.relu1(x)
  67. x = self.deconv2(x)
  68. x = self.norm2(x)
  69. x = self.relu2(x)
  70. x = self.conv3(x)
  71. x = self.norm3(x)
  72. x = self.relu3(x)
  73. return x
  74. class DinkNet34_less_pool(nn.Module):
  75. def __init__(self, num_classes=1):
  76. super(DinkNet34_more_dilate, self).__init__()
  77. filters = [64, 128, 256, 512]
  78. resnet = models.resnet34(pretrained=True)
  79. self.firstconv = resnet.conv1
  80. self.firstbn = resnet.bn1
  81. self.firstrelu = resnet.relu
  82. self.firstmaxpool = resnet.maxpool
  83. self.encoder1 = resnet.layer1
  84. self.encoder2 = resnet.layer2
  85. self.encoder3 = resnet.layer3
  86. self.dblock = Dblock_more_dilate(256)
  87. self.decoder3 = DecoderBlock(filters[2], filters[1])
  88. self.decoder2 = DecoderBlock(filters[1], filters[0])
  89. self.decoder1 = DecoderBlock(filters[0], filters[0])
  90. self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
  91. self.finalrelu1 = nonlinearity
  92. self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1)
  93. self.finalrelu2 = nonlinearity
  94. self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1)
  95. def forward(self, x):
  96. # Encoder
  97. x = self.firstconv(x)
  98. x = self.firstbn(x)
  99. x = self.firstrelu(x)
  100. x = self.firstmaxpool(x)
  101. e1 = self.encoder1(x)
  102. e2 = self.encoder2(e1)
  103. e3 = self.encoder3(e2)
  104. #Center
  105. e3 = self.dblock(e3)
  106. # Decoder
  107. d3 = self.decoder3(e3) + e2
  108. d2 = self.decoder2(d3) + e1
  109. d1 = self.decoder1(d2)
  110. # Final Classification
  111. out = self.finaldeconv1(d1)
  112. out = self.finalrelu1(out)
  113. out = self.finalconv2(out)
  114. out = self.finalrelu2(out)
  115. out = self.finalconv3(out)
  116. #return F.sigmoid(out)
  117. return out
  118. class DinkNet34(nn.Module):
  119. def __init__(self, num_classes=1, num_channels=3):
  120. super(DinkNet34, self).__init__()
  121. filters = [64, 128, 256, 512]
  122. resnet = models.resnet34(pretrained=True)
  123. self.firstconv = resnet.conv1
  124. self.firstbn = resnet.bn1
  125. self.firstrelu = resnet.relu
  126. self.firstmaxpool = resnet.maxpool
  127. self.encoder1 = resnet.layer1
  128. self.encoder2 = resnet.layer2
  129. self.encoder3 = resnet.layer3
  130. self.encoder4 = resnet.layer4
  131. self.dblock = Dblock(512)
  132. self.decoder4 = DecoderBlock(filters[3], filters[2])
  133. self.decoder3 = DecoderBlock(filters[2], filters[1])
  134. self.decoder2 = DecoderBlock(filters[1], filters[0])
  135. self.decoder1 = DecoderBlock(filters[0], filters[0])
  136. self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
  137. self.finalrelu1 = nonlinearity
  138. self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1)
  139. self.finalrelu2 = nonlinearity
  140. self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1)
  141. def forward(self, x):
  142. # Encoder
  143. x = self.firstconv(x)
  144. x = self.firstbn(x)
  145. x = self.firstrelu(x)
  146. x = self.firstmaxpool(x)
  147. e1 = self.encoder1(x)
  148. e2 = self.encoder2(e1)
  149. e3 = self.encoder3(e2)
  150. e4 = self.encoder4(e3)
  151. # Center
  152. e4 = self.dblock(e4)
  153. # Decoder
  154. d4 = self.decoder4(e4) + e3
  155. d3 = self.decoder3(d4) + e2
  156. d2 = self.decoder2(d3) + e1
  157. d1 = self.decoder1(d2)
  158. out = self.finaldeconv1(d1)
  159. out = self.finalrelu1(out)
  160. out = self.finalconv2(out)
  161. out = self.finalrelu2(out)
  162. out = self.finalconv3(out)
  163. #return F.sigmoid(out)
  164. return out
  165. class DinkNet50(nn.Module):
  166. def __init__(self, num_classes=1):
  167. super(DinkNet50, self).__init__()
  168. filters = [256, 512, 1024, 2048]
  169. resnet = models.resnet50(pretrained=True)
  170. self.firstconv = resnet.conv1
  171. self.firstbn = resnet.bn1
  172. self.firstrelu = resnet.relu
  173. self.firstmaxpool = resnet.maxpool
  174. self.encoder1 = resnet.layer1
  175. self.encoder2 = resnet.layer2
  176. self.encoder3 = resnet.layer3
  177. self.encoder4 = resnet.layer4
  178. self.dblock = Dblock_more_dilate(2048)
  179. self.decoder4 = DecoderBlock(filters[3], filters[2])
  180. self.decoder3 = DecoderBlock(filters[2], filters[1])
  181. self.decoder2 = DecoderBlock(filters[1], filters[0])
  182. self.decoder1 = DecoderBlock(filters[0], filters[0])
  183. self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
  184. self.finalrelu1 = nonlinearity
  185. self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1)
  186. self.finalrelu2 = nonlinearity
  187. self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1)
  188. def forward(self, x):
  189. # Encoder
  190. x = self.firstconv(x)
  191. x = self.firstbn(x)
  192. x = self.firstrelu(x)
  193. x = self.firstmaxpool(x)
  194. e1 = self.encoder1(x)
  195. e2 = self.encoder2(e1)
  196. e3 = self.encoder3(e2)
  197. e4 = self.encoder4(e3)
  198. # Center
  199. e4 = self.dblock(e4)
  200. # Decoder
  201. d4 = self.decoder4(e4) + e3
  202. d3 = self.decoder3(d4) + e2
  203. d2 = self.decoder2(d3) + e1
  204. d1 = self.decoder1(d2)
  205. out = self.finaldeconv1(d1)
  206. out = self.finalrelu1(out)
  207. out = self.finalconv2(out)
  208. out = self.finalrelu2(out)
  209. out = self.finalconv3(out)
  210. #return F.sigmoid(out)
  211. return out
  212. class DinkNet101(nn.Module):
  213. def __init__(self, num_classes=1):
  214. super(DinkNet101, self).__init__()
  215. filters = [256, 512, 1024, 2048]
  216. resnet = models.resnet101(pretrained=True)
  217. self.firstconv = resnet.conv1
  218. self.firstbn = resnet.bn1
  219. self.firstrelu = resnet.relu
  220. self.firstmaxpool = resnet.maxpool
  221. self.encoder1 = resnet.layer1
  222. self.encoder2 = resnet.layer2
  223. self.encoder3 = resnet.layer3
  224. self.encoder4 = resnet.layer4
  225. self.dblock = Dblock_more_dilate(2048)
  226. self.decoder4 = DecoderBlock(filters[3], filters[2])
  227. self.decoder3 = DecoderBlock(filters[2], filters[1])
  228. self.decoder2 = DecoderBlock(filters[1], filters[0])
  229. self.decoder1 = DecoderBlock(filters[0], filters[0])
  230. self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
  231. self.finalrelu1 = nonlinearity
  232. self.finalconv2 = nn.Conv2d(32, 32, 3, padding=1)
  233. self.finalrelu2 = nonlinearity
  234. self.finalconv3 = nn.Conv2d(32, num_classes, 3, padding=1)
  235. def forward(self, x):
  236. # Encoder
  237. x = self.firstconv(x)
  238. x = self.firstbn(x)
  239. x = self.firstrelu(x)
  240. x = self.firstmaxpool(x)
  241. e1 = self.encoder1(x)
  242. e2 = self.encoder2(e1)
  243. e3 = self.encoder3(e2)
  244. e4 = self.encoder4(e3)
  245. # Center
  246. e4 = self.dblock(e4)
  247. # Decoder
  248. d4 = self.decoder4(e4) + e3
  249. d3 = self.decoder3(d4) + e2
  250. d2 = self.decoder2(d3) + e1
  251. d1 = self.decoder1(d2)
  252. out = self.finaldeconv1(d1)
  253. out = self.finalrelu1(out)
  254. out = self.finalconv2(out)
  255. out = self.finalrelu2(out)
  256. out = self.finalconv3(out)
  257. #return F.sigmoid(out)
  258. return out
  259. class LinkNet34(nn.Module):
  260. def __init__(self, num_classes=1):
  261. super(LinkNet34, self).__init__()
  262. filters = [64, 128, 256, 512]
  263. resnet = models.resnet34(pretrained=True)
  264. self.firstconv = resnet.conv1
  265. self.firstbn = resnet.bn1
  266. self.firstrelu = resnet.relu
  267. self.firstmaxpool = resnet.maxpool
  268. self.encoder1 = resnet.layer1
  269. self.encoder2 = resnet.layer2
  270. self.encoder3 = resnet.layer3
  271. self.encoder4 = resnet.layer4
  272. self.decoder4 = DecoderBlock(filters[3], filters[2])
  273. self.decoder3 = DecoderBlock(filters[2], filters[1])
  274. self.decoder2 = DecoderBlock(filters[1], filters[0])
  275. self.decoder1 = DecoderBlock(filters[0], filters[0])
  276. self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 3, stride=2)
  277. self.finalrelu1 = nonlinearity
  278. self.finalconv2 = nn.Conv2d(32, 32, 3)
  279. self.finalrelu2 = nonlinearity
  280. self.finalconv3 = nn.Conv2d(32, num_classes, 2, padding=1)
  281. def forward(self, x):
  282. # Encoder
  283. x = self.firstconv(x)
  284. x = self.firstbn(x)
  285. x = self.firstrelu(x)
  286. x = self.firstmaxpool(x)
  287. e1 = self.encoder1(x)
  288. e2 = self.encoder2(e1)
  289. e3 = self.encoder3(e2)
  290. e4 = self.encoder4(e3)
  291. # Decoder
  292. d4 = self.decoder4(e4) + e3
  293. d3 = self.decoder3(d4) + e2
  294. d2 = self.decoder2(d3) + e1
  295. d1 = self.decoder1(d2)
  296. out = self.finaldeconv1(d1)
  297. out = self.finalrelu1(out)
  298. out = self.finalconv2(out)
  299. out = self.finalrelu2(out)
  300. out = self.finalconv3(out)
  301. #return F.sigmoid(out)
  302. return out