用kafka接收消息
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

236 lines
8.4KB

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys

# Extend the import path so the project-local `core` package below can be found.
# This must run before the `core` import, so keep the statement order.
# NOTE(review): the first entry is an absolute path on one developer machine, and
# the relative entries ('../..', '..') are resolved against the current working
# directory, not this file's location -- confirm this is still intended.
sys.path.extend(['/home/thsw2/WJ/src/yolov5/segutils/','../..','..' ])
from core.models.base_models.vgg import vgg16

# Public API of this module: the three FCN factory functions plus their
# Pascal VOC / VGG16 presets. The model classes themselves are not listed.
__all__ = ['get_fcn32s', 'get_fcn16s', 'get_fcn8s',
           'get_fcn32s_vgg16_voc', 'get_fcn16s_vgg16_voc', 'get_fcn8s_vgg16_voc']
  10. class FCN32s(nn.Module):
  11. """There are some difference from original fcn"""
  12. def __init__(self, nclass, backbone='vgg16', aux=False, pretrained_base=True,
  13. norm_layer=nn.BatchNorm2d, **kwargs):
  14. super(FCN32s, self).__init__()
  15. self.aux = aux
  16. if backbone == 'vgg16':
  17. self.pretrained = vgg16(pretrained=pretrained_base).features
  18. else:
  19. raise RuntimeError('unknown backbone: {}'.format(backbone))
  20. self.head = _FCNHead(512, nclass, norm_layer)
  21. if aux:
  22. self.auxlayer = _FCNHead(512, nclass, norm_layer)
  23. self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head'])
  24. def forward(self, x):
  25. size = x.size()[2:]
  26. pool5 = self.pretrained(x)
  27. outputs = []
  28. out = self.head(pool5)
  29. out = F.interpolate(out, size, mode='bilinear', align_corners=True)
  30. outputs.append(out)
  31. if self.aux:
  32. auxout = self.auxlayer(pool5)
  33. auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
  34. outputs.append(auxout)
  35. return tuple(outputs)
  36. class FCN16s(nn.Module):
  37. def __init__(self, nclass, backbone='vgg16', aux=False, pretrained_base=True, norm_layer=nn.BatchNorm2d, **kwargs):
  38. super(FCN16s, self).__init__()
  39. self.aux = aux
  40. if backbone == 'vgg16':
  41. self.pretrained = vgg16(pretrained=pretrained_base).features
  42. else:
  43. raise RuntimeError('unknown backbone: {}'.format(backbone))
  44. self.pool4 = nn.Sequential(*self.pretrained[:24])
  45. self.pool5 = nn.Sequential(*self.pretrained[24:])
  46. self.head = _FCNHead(512, nclass, norm_layer)
  47. self.score_pool4 = nn.Conv2d(512, nclass, 1)
  48. if aux:
  49. self.auxlayer = _FCNHead(512, nclass, norm_layer)
  50. self.__setattr__('exclusive', ['head', 'score_pool4', 'auxlayer'] if aux else ['head', 'score_pool4'])
  51. def forward(self, x):
  52. pool4 = self.pool4(x)
  53. pool5 = self.pool5(pool4)
  54. outputs = []
  55. score_fr = self.head(pool5)
  56. score_pool4 = self.score_pool4(pool4)
  57. upscore2 = F.interpolate(score_fr, score_pool4.size()[2:], mode='bilinear', align_corners=True)
  58. fuse_pool4 = upscore2 + score_pool4
  59. out = F.interpolate(fuse_pool4, x.size()[2:], mode='bilinear', align_corners=True)
  60. outputs.append(out)
  61. if self.aux:
  62. auxout = self.auxlayer(pool5)
  63. auxout = F.interpolate(auxout, x.size()[2:], mode='bilinear', align_corners=True)
  64. outputs.append(auxout)
  65. #return tuple(outputs)
  66. return outputs[0]
  67. class FCN8s(nn.Module):
  68. def __init__(self, nclass, backbone='vgg16', aux=False, pretrained_base=True, norm_layer=nn.BatchNorm2d, **kwargs):
  69. super(FCN8s, self).__init__()
  70. self.aux = aux
  71. if backbone == 'vgg16':
  72. self.pretrained = vgg16(pretrained=pretrained_base).features
  73. else:
  74. raise RuntimeError('unknown backbone: {}'.format(backbone))
  75. self.pool3 = nn.Sequential(*self.pretrained[:17])
  76. self.pool4 = nn.Sequential(*self.pretrained[17:24])
  77. self.pool5 = nn.Sequential(*self.pretrained[24:])
  78. self.head = _FCNHead(512, nclass, norm_layer)
  79. self.score_pool3 = nn.Conv2d(256, nclass, 1)
  80. self.score_pool4 = nn.Conv2d(512, nclass, 1)
  81. if aux:
  82. self.auxlayer = _FCNHead(512, nclass, norm_layer)
  83. self.__setattr__('exclusive',
  84. ['head', 'score_pool3', 'score_pool4', 'auxlayer'] if aux else ['head', 'score_pool3',
  85. 'score_pool4'])
  86. def forward(self, x):
  87. pool3 = self.pool3(x)
  88. pool4 = self.pool4(pool3)
  89. pool5 = self.pool5(pool4)
  90. outputs = []
  91. score_fr = self.head(pool5)
  92. score_pool4 = self.score_pool4(pool4)
  93. score_pool3 = self.score_pool3(pool3)
  94. upscore2 = F.interpolate(score_fr, score_pool4.size()[2:], mode='bilinear', align_corners=True)
  95. fuse_pool4 = upscore2 + score_pool4
  96. upscore_pool4 = F.interpolate(fuse_pool4, score_pool3.size()[2:], mode='bilinear', align_corners=True)
  97. fuse_pool3 = upscore_pool4 + score_pool3
  98. out = F.interpolate(fuse_pool3, x.size()[2:], mode='bilinear', align_corners=True)
  99. outputs.append(out)
  100. if self.aux:
  101. auxout = self.auxlayer(pool5)
  102. auxout = F.interpolate(auxout, x.size()[2:], mode='bilinear', align_corners=True)
  103. outputs.append(auxout)
  104. return tuple(outputs)
  105. class _FCNHead(nn.Module):
  106. def __init__(self, in_channels, channels, norm_layer=nn.BatchNorm2d, **kwargs):
  107. super(_FCNHead, self).__init__()
  108. inter_channels = in_channels // 4
  109. self.block = nn.Sequential(
  110. nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
  111. norm_layer(inter_channels),
  112. nn.ReLU(inplace=True),
  113. nn.Dropout(0.1),
  114. nn.Conv2d(inter_channels, channels, 1)
  115. )
  116. def forward(self, x):
  117. return self.block(x)
  118. def get_fcn32s(dataset='pascal_voc', backbone='vgg16', pretrained=False, root='~/.torch/models',
  119. pretrained_base=True, **kwargs):
  120. acronyms = {
  121. 'pascal_voc': 'pascal_voc',
  122. 'pascal_aug': 'pascal_aug',
  123. 'ade20k': 'ade',
  124. 'coco': 'coco',
  125. 'citys': 'citys',
  126. }
  127. from ..data.dataloader import datasets
  128. model = FCN32s(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
  129. if pretrained:
  130. from .model_store import get_model_file
  131. device = torch.device(kwargs['local_rank'])
  132. model.load_state_dict(torch.load(get_model_file('fcn32s_%s_%s' % (backbone, acronyms[dataset]), root=root),
  133. map_location=device))
  134. return model
  135. def get_fcn16s(dataset='pascal_voc', backbone='vgg16', pretrained=False, root='~/.torch/models',
  136. pretrained_base=True, **kwargs):
  137. acronyms = {
  138. 'pascal_voc': 'pascal_voc',
  139. 'pascal_aug': 'pascal_aug',
  140. 'ade20k': 'ade',
  141. 'coco': 'coco',
  142. 'citys': 'citys',
  143. }
  144. from ..data.dataloader import datasets
  145. model = FCN16s(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
  146. if pretrained:
  147. from .model_store import get_model_file
  148. device = torch.device(kwargs['local_rank'])
  149. model.load_state_dict(torch.load(get_model_file('fcn16s_%s_%s' % (backbone, acronyms[dataset]), root=root),
  150. map_location=device))
  151. return model
  152. def get_fcn8s(dataset='pascal_voc', backbone='vgg16', pretrained=False, root='~/.torch/models',
  153. pretrained_base=True, **kwargs):
  154. acronyms = {
  155. 'pascal_voc': 'pascal_voc',
  156. 'pascal_aug': 'pascal_aug',
  157. 'ade20k': 'ade',
  158. 'coco': 'coco',
  159. 'citys': 'citys',
  160. }
  161. from ..data.dataloader import datasets
  162. model = FCN8s(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
  163. if pretrained:
  164. from .model_store import get_model_file
  165. device = torch.device(kwargs['local_rank'])
  166. model.load_state_dict(torch.load(get_model_file('fcn8s_%s_%s' % (backbone, acronyms[dataset]), root=root),
  167. map_location=device))
  168. return model
  169. def get_fcn32s_vgg16_voc(**kwargs):
  170. return get_fcn32s('pascal_voc', 'vgg16', **kwargs)
  171. def get_fcn16s_vgg16_voc(**kwargs):
  172. return get_fcn16s('pascal_voc', 'vgg16', **kwargs)
  173. def get_fcn8s_vgg16_voc(**kwargs):
  174. return get_fcn8s('pascal_voc', 'vgg16', **kwargs)
  175. if __name__ == "__main__":
  176. model = FCN16s(21)
  177. print(model)
  178. input = torch.rand(2, 3, 224,224)
  179. #target = torch.zeros(4, 512, 512).cuda()
  180. #model.eval()
  181. #print(model)
  182. loss = model(input)
  183. print(loss)
  184. print(loss.shape)
  185. import torch
  186. from thop import profile
  187. from torchsummary import summary
  188. flop,params=profile(model,input_size=(1,3,512,512))
  189. print('flops:{:.3f}G\nparams:{:.3f}M'.format(flop/1e9, params/1e6))