import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_models.xception import get_xception
from .deeplabv3 import _ASPP
from .fcn import _FCNHead
from ..nn import _ConvBNReLU

__all__ = ['DeepLabV3Plus', 'get_deeplabv3_plus', 'get_deeplabv3_plus_xception_voc']


class DeepLabV3Plus(nn.Module):
    r"""DeepLabV3Plus

    Parameters
    ----------
    nclass : int
        Number of categories for the training dataset.
    backbone : string
        Pre-trained dilated backbone network type (default: 'xception').
    norm_layer : object
        Normalization layer used in backbone network (default: :class:`nn.BatchNorm2d`;
        pass a synchronized cross-GPU BatchNormalization layer for multi-GPU training).
    aux : bool
        Auxiliary loss.

    Reference:
        Chen, Liang-Chieh, et al. "Encoder-Decoder with Atrous Separable Convolution for Semantic
        Image Segmentation."
    """
    def __init__(self, nclass, backbone='xception', aux=True, pretrained_base=True, dilated=True, **kwargs):
        super(DeepLabV3Plus, self).__init__()
        self.aux = aux
        self.nclass = nclass
        output_stride = 8 if dilated else 32
        self.pretrained = get_xception(pretrained=pretrained_base, output_stride=output_stride, **kwargs)

        # deeplabv3 plus
        self.head = _DeepLabHead(nclass, **kwargs)
        if aux:
            self.auxlayer = _FCNHead(728, nclass, **kwargs)

    def base_forward(self, x):
        # Entry flow
        x = self.pretrained.conv1(x)
        x = self.pretrained.bn1(x)
        x = self.pretrained.relu(x)
        x = self.pretrained.conv2(x)
        x = self.pretrained.bn2(x)
        x = self.pretrained.relu(x)
        x = self.pretrained.block1(x)
        # add relu here
        x = self.pretrained.relu(x)
        low_level_feat = x

        x = self.pretrained.block2(x)
        x = self.pretrained.block3(x)

        # Middle flow
        x = self.pretrained.midflow(x)
        mid_level_feat = x

        # Exit flow
        x = self.pretrained.block20(x)
        x = self.pretrained.relu(x)
        x = self.pretrained.conv3(x)
        x = self.pretrained.bn3(x)
        x = self.pretrained.relu(x)
        x = self.pretrained.conv4(x)
        x = self.pretrained.bn4(x)
        x = self.pretrained.relu(x)
        x = self.pretrained.conv5(x)
        x = self.pretrained.bn5(x)
        x = self.pretrained.relu(x)
        return low_level_feat, mid_level_feat, x

    def forward(self, x):
        size = x.size()[2:]
        c1, c3, c4 = self.base_forward(x)
        outputs = list()
        x = self.head(c4, c1)
        x = F.interpolate(x, size, mode='bilinear', align_corners=True)
        outputs.append(x)
        if self.aux:
            auxout = self.auxlayer(c3)
            auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
            outputs.append(auxout)
        return tuple(outputs)
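

# Decoder head used above: `x` is the 2048-channel exit-flow feature map fed through
# ASPP, `c1` the low-level entry-flow feature map reduced to 48 channels; the
# upsampled ASPP output and the reduced c1 are concatenated into the 304-channel
# input of the fuse block.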
class _DeepLabHead(nn.Module):
    def __init__(self, nclass, c1_channels=128, norm_layer=nn.BatchNorm2d, **kwargs):
        super(_DeepLabHead, self).__init__()
        self.aspp = _ASPP(2048, [12, 24, 36], norm_layer=norm_layer, **kwargs)
        self.c1_block = _ConvBNReLU(c1_channels, 48, 3, padding=1, norm_layer=norm_layer)
        self.block = nn.Sequential(
            _ConvBNReLU(304, 256, 3, padding=1, norm_layer=norm_layer),
            nn.Dropout(0.5),
            _ConvBNReLU(256, 256, 3, padding=1, norm_layer=norm_layer),
            nn.Dropout(0.1),
            nn.Conv2d(256, nclass, 1))

    def forward(self, x, c1):
        size = c1.size()[2:]
        c1 = self.c1_block(c1)
        x = self.aspp(x)
        x = F.interpolate(x, size, mode='bilinear', align_corners=True)
        return self.block(torch.cat([x, c1], dim=1))


def get_deeplabv3_plus(dataset='pascal_voc', backbone='xception', pretrained=False, root='~/.torch/models',
                       pretrained_base=True, **kwargs):
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = DeepLabV3Plus(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        device = torch.device(kwargs['local_rank'])
        model.load_state_dict(
            torch.load(get_model_file('deeplabv3_plus_%s_%s' % (backbone, acronyms[dataset]), root=root),
                       map_location=device))
    return model


def get_deeplabv3_plus_xception_voc(**kwargs):
    return get_deeplabv3_plus('pascal_voc', 'xception', **kwargs)


if __name__ == '__main__':
    model = get_deeplabv3_plus_xception_voc()
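    # Quick smoke-test sketch: building the model above assumes the pretrained
    # Xception backbone weights can be resolved (pretrained_base=True is the default
    # in get_deeplabv3_plus); the 480x480 input size is arbitrary.
    model.eval()
    with torch.no_grad():
        img = torch.randn(1, 3, 480, 480)
        # forward() returns a tuple: (main output, auxiliary output) when aux=True,
        # both bilinearly upsampled back to the input resolution.
        outputs = model(img)
    print([out.shape for out in outputs])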