用kafka接收消息
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

213 lines
7.3KB

  1. """Context Encoding for Semantic Segmentation"""
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from .segbase import SegBaseModel
  6. from .fcn import _FCNHead
  7. __all__ = ['EncNet', 'EncModule', 'get_encnet', 'get_encnet_resnet50_ade',
  8. 'get_encnet_resnet101_ade', 'get_encnet_resnet152_ade']
  9. class EncNet(SegBaseModel):
  10. def __init__(self, nclass, backbone='resnet50', aux=True, se_loss=True, lateral=False,
  11. pretrained_base=True, **kwargs):
  12. super(EncNet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs)
  13. self.head = _EncHead(2048, nclass, se_loss=se_loss, lateral=lateral, **kwargs)
  14. if aux:
  15. self.auxlayer = _FCNHead(1024, nclass, **kwargs)
  16. self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head'])
  17. def forward(self, x):
  18. size = x.size()[2:]
  19. features = self.base_forward(x)
  20. x = list(self.head(*features))
  21. x[0] = F.interpolate(x[0], size, mode='bilinear', align_corners=True)
  22. if self.aux:
  23. auxout = self.auxlayer(features[2])
  24. auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
  25. x.append(auxout)
  26. return tuple(x)
  27. class _EncHead(nn.Module):
  28. def __init__(self, in_channels, nclass, se_loss=True, lateral=True,
  29. norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
  30. super(_EncHead, self).__init__()
  31. self.lateral = lateral
  32. self.conv5 = nn.Sequential(
  33. nn.Conv2d(in_channels, 512, 3, padding=1, bias=False),
  34. norm_layer(512, **({} if norm_kwargs is None else norm_kwargs)),
  35. nn.ReLU(True)
  36. )
  37. if lateral:
  38. self.connect = nn.ModuleList([
  39. nn.Sequential(
  40. nn.Conv2d(512, 512, 1, bias=False),
  41. norm_layer(512, **({} if norm_kwargs is None else norm_kwargs)),
  42. nn.ReLU(True)),
  43. nn.Sequential(
  44. nn.Conv2d(1024, 512, 1, bias=False),
  45. norm_layer(512, **({} if norm_kwargs is None else norm_kwargs)),
  46. nn.ReLU(True)),
  47. ])
  48. self.fusion = nn.Sequential(
  49. nn.Conv2d(3 * 512, 512, 3, padding=1, bias=False),
  50. norm_layer(512, **({} if norm_kwargs is None else norm_kwargs)),
  51. nn.ReLU(True)
  52. )
  53. self.encmodule = EncModule(512, nclass, ncodes=32, se_loss=se_loss,
  54. norm_layer=norm_layer, norm_kwargs=norm_kwargs, **kwargs)
  55. self.conv6 = nn.Sequential(
  56. nn.Dropout(0.1, False),
  57. nn.Conv2d(512, nclass, 1)
  58. )
  59. def forward(self, *inputs):
  60. feat = self.conv5(inputs[-1])
  61. if self.lateral:
  62. c2 = self.connect[0](inputs[1])
  63. c3 = self.connect[1](inputs[2])
  64. feat = self.fusion(torch.cat([feat, c2, c3], 1))
  65. outs = list(self.encmodule(feat))
  66. outs[0] = self.conv6(outs[0])
  67. return tuple(outs)
  68. class EncModule(nn.Module):
  69. def __init__(self, in_channels, nclass, ncodes=32, se_loss=True,
  70. norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
  71. super(EncModule, self).__init__()
  72. self.se_loss = se_loss
  73. self.encoding = nn.Sequential(
  74. nn.Conv2d(in_channels, in_channels, 1, bias=False),
  75. norm_layer(in_channels, **({} if norm_kwargs is None else norm_kwargs)),
  76. nn.ReLU(True),
  77. Encoding(D=in_channels, K=ncodes),
  78. nn.BatchNorm1d(ncodes),
  79. nn.ReLU(True),
  80. Mean(dim=1)
  81. )
  82. self.fc = nn.Sequential(
  83. nn.Linear(in_channels, in_channels),
  84. nn.Sigmoid()
  85. )
  86. if self.se_loss:
  87. self.selayer = nn.Linear(in_channels, nclass)
  88. def forward(self, x):
  89. en = self.encoding(x)
  90. b, c, _, _ = x.size()
  91. gamma = self.fc(en)
  92. y = gamma.view(b, c, 1, 1)
  93. outputs = [F.relu_(x + x * y)]
  94. if self.se_loss:
  95. outputs.append(self.selayer(en))
  96. return tuple(outputs)
  97. class Encoding(nn.Module):
  98. def __init__(self, D, K):
  99. super(Encoding, self).__init__()
  100. # init codewords and smoothing factor
  101. self.D, self.K = D, K
  102. self.codewords = nn.Parameter(torch.Tensor(K, D), requires_grad=True)
  103. self.scale = nn.Parameter(torch.Tensor(K), requires_grad=True)
  104. self.reset_params()
  105. def reset_params(self):
  106. std1 = 1. / ((self.K * self.D) ** (1 / 2))
  107. self.codewords.data.uniform_(-std1, std1)
  108. self.scale.data.uniform_(-1, 0)
  109. def forward(self, X):
  110. # input X is a 4D tensor
  111. assert (X.size(1) == self.D)
  112. B, D = X.size(0), self.D
  113. if X.dim() == 3:
  114. # BxDxN -> BxNxD
  115. X = X.transpose(1, 2).contiguous()
  116. elif X.dim() == 4:
  117. # BxDxHxW -> Bx(HW)xD
  118. X = X.view(B, D, -1).transpose(1, 2).contiguous()
  119. else:
  120. raise RuntimeError('Encoding Layer unknown input dims!')
  121. # assignment weights BxNxK
  122. A = F.softmax(self.scale_l2(X, self.codewords, self.scale), dim=2)
  123. # aggregate
  124. E = self.aggregate(A, X, self.codewords)
  125. return E
  126. def __repr__(self):
  127. return self.__class__.__name__ + '(' \
  128. + 'N x' + str(self.D) + '=>' + str(self.K) + 'x' \
  129. + str(self.D) + ')'
  130. @staticmethod
  131. def scale_l2(X, C, S):
  132. S = S.view(1, 1, C.size(0), 1)
  133. X = X.unsqueeze(2).expand(X.size(0), X.size(1), C.size(0), C.size(1))
  134. C = C.unsqueeze(0).unsqueeze(0)
  135. SL = S * (X - C)
  136. SL = SL.pow(2).sum(3)
  137. return SL
  138. @staticmethod
  139. def aggregate(A, X, C):
  140. A = A.unsqueeze(3)
  141. X = X.unsqueeze(2).expand(X.size(0), X.size(1), C.size(0), C.size(1))
  142. C = C.unsqueeze(0).unsqueeze(0)
  143. E = A * (X - C)
  144. E = E.sum(1)
  145. return E
  146. class Mean(nn.Module):
  147. def __init__(self, dim, keep_dim=False):
  148. super(Mean, self).__init__()
  149. self.dim = dim
  150. self.keep_dim = keep_dim
  151. def forward(self, input):
  152. return input.mean(self.dim, self.keep_dim)
  153. def get_encnet(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~/.torch/models',
  154. pretrained_base=True, **kwargs):
  155. acronyms = {
  156. 'pascal_voc': 'pascal_voc',
  157. 'pascal_aug': 'pascal_aug',
  158. 'ade20k': 'ade',
  159. 'coco': 'coco',
  160. 'citys': 'citys',
  161. }
  162. from ..data.dataloader import datasets
  163. model = EncNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
  164. if pretrained:
  165. from .model_store import get_model_file
  166. device = torch.device(kwargs['local_rank'])
  167. model.load_state_dict(torch.load(get_model_file('encnet_%s_%s' % (backbone, acronyms[dataset]), root=root),
  168. map_location=device))
  169. return model
  170. def get_encnet_resnet50_ade(**kwargs):
  171. return get_encnet('ade20k', 'resnet50', **kwargs)
  172. def get_encnet_resnet101_ade(**kwargs):
  173. return get_encnet('ade20k', 'resnet101', **kwargs)
  174. def get_encnet_resnet152_ade(**kwargs):
  175. return get_encnet('ade20k', 'resnet152', **kwargs)
  176. if __name__ == '__main__':
  177. img = torch.randn(2, 3, 224, 224)
  178. model = get_encnet_resnet50_ade()
  179. outputs = model(img)