  1. """Image Cascade Network"""
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from core.models.segbase import SegBaseModel
  6. __all__ = ['ICNet', 'get_icnet', 'get_icnet_resnet50_citys',
  7. 'get_icnet_resnet101_citys', 'get_icnet_resnet152_citys']
  8. class ICNet(SegBaseModel):
  9. """Image Cascade Network"""
  10. def __init__(self, nclass, backbone='resnet50', aux=False, jpu=False, pretrained_base=True, **kwargs):
  11. super(ICNet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs)
  12. self.conv_sub1 = nn.Sequential(
  13. _ConvBNReLU(3, 32, 3, 2, **kwargs),
  14. _ConvBNReLU(32, 32, 3, 2, **kwargs),
  15. _ConvBNReLU(32, 64, 3, 2, **kwargs)
  16. )
  17. self.ppm = PyramidPoolingModule()
  18. self.head = _ICHead(nclass, **kwargs)
  19. self.__setattr__('exclusive', ['conv_sub1', 'head'])
  20. def forward(self, x):
  21. # sub 1
  22. x_sub1 = self.conv_sub1(x)
  23. # sub 2
  24. x_sub2 = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True)
  25. _, x_sub2, _, _ = self.base_forward(x_sub2)
  26. # sub 4
  27. x_sub4 = F.interpolate(x, scale_factor=0.25, mode='bilinear', align_corners=True)
  28. _, _, _, x_sub4 = self.base_forward(x_sub4)
  29. # add PyramidPoolingModule
  30. x_sub4 = self.ppm(x_sub4)
  31. outputs = self.head(x_sub1, x_sub2, x_sub4)
  32. return tuple(outputs)
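
# Note: the sub-1 branch keeps the full-resolution image and downsamples it to
# 1/8 with three stride-2 convs, while the 1/2- and 1/4-resolution copies share
# the SegBaseModel backbone; x_sub2 is its stride-8 feature map and x_sub4 its
# deepest feature map, which goes through pyramid pooling before fusion in the head.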


class PyramidPoolingModule(nn.Module):
    def __init__(self, pyramids=[1, 2, 3, 6]):
        super(PyramidPoolingModule, self).__init__()
        self.pyramids = pyramids

    def forward(self, input):
        feat = input
        height, width = input.shape[2:]
        for bin_size in self.pyramids:
            x = F.adaptive_avg_pool2d(input, output_size=bin_size)
            x = F.interpolate(x, size=(height, width), mode='bilinear', align_corners=True)
            feat = feat + x
        return feat
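
# Note: this PyramidPoolingModule is parameter-free. It average-pools the input
# to 1x1, 2x2, 3x3 and 6x6 bins, upsamples each pooled map back to the input
# size and sums it onto the original features, so the output shape equals the
# input shape (no per-branch convolutions or channel concatenation here).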


class _ICHead(nn.Module):
    def __init__(self, nclass, norm_layer=nn.BatchNorm2d, **kwargs):
        super(_ICHead, self).__init__()
        # self.cff_12 = CascadeFeatureFusion(512, 64, 128, nclass, norm_layer, **kwargs)
        self.cff_12 = CascadeFeatureFusion(128, 64, 128, nclass, norm_layer, **kwargs)
        self.cff_24 = CascadeFeatureFusion(2048, 512, 128, nclass, norm_layer, **kwargs)
        self.conv_cls = nn.Conv2d(128, nclass, 1, bias=False)

    def forward(self, x_sub1, x_sub2, x_sub4):
        outputs = list()
        x_cff_24, x_24_cls = self.cff_24(x_sub4, x_sub2)
        outputs.append(x_24_cls)
        # x_cff_12, x_12_cls = self.cff_12(x_sub2, x_sub1)
        x_cff_12, x_12_cls = self.cff_12(x_cff_24, x_sub1)
        outputs.append(x_12_cls)
        up_x2 = F.interpolate(x_cff_12, scale_factor=2, mode='bilinear', align_corners=True)
        up_x2 = self.conv_cls(up_x2)
        outputs.append(up_x2)
        up_x8 = F.interpolate(up_x2, scale_factor=4, mode='bilinear', align_corners=True)
        outputs.append(up_x8)
        # 1 -> 1/4 -> 1/8 -> 1/16
        outputs.reverse()
        return outputs
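
# Note: after reverse() the head returns score maps at full, 1/4, 1/8 and 1/16
# resolution, each with nclass channels; the 1/8 and 1/16 maps come from the
# auxiliary classifiers inside the two CascadeFeatureFusion units and are
# typically used only for deep supervision during training.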


class _ConvBNReLU(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1,
                 groups=1, norm_layer=nn.BatchNorm2d, bias=False, **kwargs):
        super(_ConvBNReLU, self).__init__()
        # all Conv2d arguments, including bias, are passed positionally here
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.bn = norm_layer(out_channels)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class CascadeFeatureFusion(nn.Module):
    """CFF Unit"""

    def __init__(self, low_channels, high_channels, out_channels, nclass, norm_layer=nn.BatchNorm2d, **kwargs):
        super(CascadeFeatureFusion, self).__init__()
        self.conv_low = nn.Sequential(
            nn.Conv2d(low_channels, out_channels, 3, padding=2, dilation=2, bias=False),
            norm_layer(out_channels)
        )
        self.conv_high = nn.Sequential(
            nn.Conv2d(high_channels, out_channels, 1, bias=False),
            norm_layer(out_channels)
        )
        self.conv_low_cls = nn.Conv2d(out_channels, nclass, 1, bias=False)

    def forward(self, x_low, x_high):
        x_low = F.interpolate(x_low, size=x_high.size()[2:], mode='bilinear', align_corners=True)
        x_low = self.conv_low(x_low)
        x_high = self.conv_high(x_high)
        x = x_low + x_high
        x = F.relu(x, inplace=True)
        x_low_cls = self.conv_low_cls(x_low)
        return x, x_low_cls
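
# Note: CascadeFeatureFusion upsamples the coarse branch to the fine branch's
# spatial size, applies a dilated 3x3 conv (dilation=2) to it and a 1x1 conv to
# the fine branch, then sums the two and applies ReLU; the extra 1x1 classifier
# on the upsampled coarse features produces the auxiliary score map used above.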


def get_icnet(dataset='citys', backbone='resnet50', pretrained=False, root='~/.torch/models',
              pretrained_base=True, **kwargs):
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = ICNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        device = torch.device(kwargs['local_rank'])
        model.load_state_dict(torch.load(get_model_file('icnet_%s_%s' % (backbone, acronyms[dataset]), root=root),
                                         map_location=device))
    return model


def get_icnet_resnet50_citys(**kwargs):
    return get_icnet('citys', 'resnet50', **kwargs)


def get_icnet_resnet101_citys(**kwargs):
    return get_icnet('citys', 'resnet101', **kwargs)


def get_icnet_resnet152_citys(**kwargs):
    return get_icnet('citys', 'resnet152', **kwargs)
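
# Minimal usage sketch (assumes the 'citys' entry in ..data.dataloader is the
# Cityscapes dataset with 19 classes):
#
#     model = get_icnet_resnet50_citys(pretrained_base=False)
#     model.eval()
#     with torch.no_grad():
#         out = model(torch.randn(1, 3, 512, 1024))[0]   # full-resolution score map, (1, 19, 512, 1024)
#         pred = out.argmax(dim=1)                       # per-pixel class labels, (1, 512, 1024)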


if __name__ == '__main__':
    # img = torch.randn(1, 3, 256, 256)
    # model = get_icnet_resnet50_citys()
    # outputs = model(img)
    input = torch.rand(2, 3, 224, 224)
    model = ICNet(4, pretrained_base=False)
    # target = torch.zeros(4, 512, 512).cuda()
    # model.eval()
    # print(model)
    outputs = model(input)
    # print([out.shape for out in outputs])

    # from torchsummary import summary
    # summary(model, (3, 224, 224))  # prints a table with each layer's output shape and parameter count

    from thop import profile

    # newer thop versions take example inputs via `inputs=` rather than the old `input_size=`;
    # thop counts multiply-accumulate operations, printed here as GFLOPs
    flops, params = profile(model, inputs=(torch.randn(1, 3, 512, 512),))
    print('flops:{:.3f}G\nparams:{:.3f}M'.format(flops / 1e9, params / 1e6))
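
    # quick shape check (a sketch): the head returns score maps at full, 1/4,
    # 1/8 and 1/16 resolution, so for the 224x224 input above the expected
    # shapes are (2, 4, 224, 224), (2, 4, 56, 56), (2, 4, 28, 28), (2, 4, 14, 14)
    for out in outputs:
        print(out.shape)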