用kafka接收消息
Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

256 lines
12KB

  1. import torch
  2. from torch import nn
  3. import torch.nn.functional as F
  4. import core.lib.psa.functional as PF
  5. import modeling.backbone.resnet_real as models
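# Fails to run: compact mode works, but over-completed (non-compact) mode does not. This is related to the psamask implementation, which uses a custom torch.autograd.Function backed by C++ sources (importing the `_C` module fails).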
  7. #
  8. # from . import functions
  9. #
  10. #
  11. # def psa_mask(input, psa_type=0, mask_H_=None, mask_W_=None):
  12. # return functions.psa_mask(input, psa_type, mask_H_, mask_W_)
  13. #
  14. #
  15. # import torch
  16. # from torch.autograd import Function
  17. # from .. import src
  18. # class PSAMask(Function):
  19. # @staticmethod
  20. # def forward(ctx, input, psa_type=0, mask_H_=None, mask_W_=None):
  21. # assert psa_type in [0, 1] # 0-col, 1-dis
  22. # assert (mask_H_ is None and mask_W_ is None) or (mask_H_ is not None and mask_W_ is not None)
  23. # num_, channels_, feature_H_, feature_W_ = input.size()
  24. # if mask_H_ is None and mask_W_ is None:
  25. # mask_H_, mask_W_ = 2 * feature_H_ - 1, 2 * feature_W_ - 1
  26. # assert (mask_H_ % 2 == 1) and (mask_W_ % 2 == 1)
  27. # assert channels_ == mask_H_ * mask_W_
  28. # half_mask_H_, half_mask_W_ = (mask_H_ - 1) // 2, (mask_W_ - 1) // 2
  29. # output = torch.zeros([num_, feature_H_ * feature_W_, feature_H_, feature_W_], dtype=input.dtype, device=input.device)
  30. # if not input.is_cuda:
  31. # src.cpu.psamask_forward(psa_type, input, output, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
  32. # else:
  33. # output = output.cuda()
  34. # src.gpu.psamask_forward(psa_type, input, output, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
  35. # ctx.psa_type, ctx.num_, ctx.channels_, ctx.feature_H_, ctx.feature_W_ = psa_type, num_, channels_, feature_H_, feature_W_
  36. # ctx.mask_H_, ctx.mask_W_, ctx.half_mask_H_, ctx.half_mask_W_ = mask_H_, mask_W_, half_mask_H_, half_mask_W_
  37. # return output
  38. #
  39. # @staticmethod
  40. # def backward(ctx, grad_output):
  41. # psa_type, num_, channels_, feature_H_, feature_W_ = ctx.psa_type, ctx.num_, ctx.channels_, ctx.feature_H_, ctx.feature_W_
  42. # mask_H_, mask_W_, half_mask_H_, half_mask_W_ = ctx.mask_H_, ctx.mask_W_, ctx.half_mask_H_, ctx.half_mask_W_
  43. # grad_input = torch.zeros([num_, channels_, feature_H_, feature_W_], dtype=grad_output.dtype, device=grad_output.device)
  44. # if not grad_output.is_cuda:
  45. # src.cpu.psamask_backward(psa_type, grad_output, grad_input, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
  46. # else:
  47. # src.gpu.psamask_backward(psa_type, grad_output, grad_input, num_, feature_H_, feature_W_, mask_H_, mask_W_, half_mask_H_, half_mask_W_)
  48. # return grad_input, None, None, None
  49. # psa_mask = PSAMask.apply
  50. class PSA(nn.Module):
  51. def __init__(self, in_channels=2048, mid_channels=512, psa_type=2, compact=False, shrink_factor=2, mask_h=59,
  52. mask_w=59, normalization_factor=1.0, psa_softmax=True):
  53. super(PSA, self).__init__()
  54. assert psa_type in [0, 1, 2]
  55. self.psa_type = psa_type
  56. self.compact = compact
  57. self.shrink_factor = shrink_factor
  58. self.mask_h = mask_h
  59. self.mask_w = mask_w
  60. self.psa_softmax = psa_softmax
  61. if normalization_factor is None:
  62. normalization_factor = mask_h * mask_w
  63. self.normalization_factor = normalization_factor
  64. self.reduce = nn.Sequential(
  65. nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
  66. nn.BatchNorm2d(mid_channels),
  67. nn.ReLU(inplace=True)
  68. )
  69. self.attention = nn.Sequential(
  70. nn.Conv2d(mid_channels, mid_channels, kernel_size=1, bias=False),
  71. nn.BatchNorm2d(mid_channels),
  72. nn.ReLU(inplace=True),
  73. nn.Conv2d(mid_channels, mask_h*mask_w, kernel_size=1, bias=False),
  74. )
  75. if psa_type == 2:
  76. self.reduce_p = nn.Sequential(
  77. nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
  78. nn.BatchNorm2d(mid_channels),
  79. nn.ReLU(inplace=True)
  80. )
  81. self.attention_p = nn.Sequential(
  82. nn.Conv2d(mid_channels, mid_channels, kernel_size=1, bias=False),
  83. nn.BatchNorm2d(mid_channels),
  84. nn.ReLU(inplace=True),
  85. nn.Conv2d(mid_channels, mask_h*mask_w, kernel_size=1, bias=False),
  86. )
  87. self.proj = nn.Sequential(
  88. nn.Conv2d(mid_channels * (2 if psa_type == 2 else 1), in_channels, kernel_size=1, bias=False),
  89. nn.BatchNorm2d(in_channels),
  90. nn.ReLU(inplace=True)
  91. )
  92. def forward(self, x):
  93. out = x
  94. if self.psa_type in [0, 1]:
  95. x = self.reduce(x)
  96. n, c, h, w = x.size()
  97. if self.shrink_factor != 1:
  98. h = (h - 1) // self.shrink_factor + 1#可以理解为这样做的目的是向上取整。
  99. w = (w - 1) // self.shrink_factor + 1
  100. x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
  101. y = self.attention(x)
  102. if self.compact:
  103. if self.psa_type == 1:
  104. y = y.view(n, h * w, h * w).transpose(1, 2).view(n, h * w, h, w)
  105. else:
  106. y = PF.psa_mask(y, self.psa_type, self.mask_h, self.mask_w)
  107. if self.psa_softmax:
  108. y = F.softmax(y, dim=1)
  109. x = torch.bmm(x.view(n, c, h * w), y.view(n, h * w, h * w)).view(n, c, h, w) * (1.0 / self.normalization_factor)
  110. elif self.psa_type == 2:
  111. x_col = self.reduce(x)
  112. x_dis = self.reduce_p(x)
  113. n, c, h, w = x_col.size()
  114. if self.shrink_factor != 1:
  115. h = (h - 1) // self.shrink_factor + 1
  116. w = (w - 1) // self.shrink_factor + 1
  117. x_col = F.interpolate(x_col, size=(h, w), mode='bilinear', align_corners=True)
  118. x_dis = F.interpolate(x_dis, size=(h, w), mode='bilinear', align_corners=True)
  119. y_col = self.attention(x_col)
  120. y_dis = self.attention_p(x_dis)
  121. if self.compact:
  122. y_dis = y_dis.view(n, h * w, h * w).transpose(1, 2).view(n, h * w, h, w)
  123. else:
  124. y_col = PF.psa_mask(y_col, 0, self.mask_h, self.mask_w)
  125. y_dis = PF.psa_mask(y_dis, 1, self.mask_h, self.mask_w)
  126. if self.psa_softmax:
  127. y_col = F.softmax(y_col, dim=1)
  128. y_dis = F.softmax(y_dis, dim=1)
  129. x_col = torch.bmm(x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view(n, c, h, w) * (1.0 / self.normalization_factor)
  130. x_dis = torch.bmm(x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(n, c, h, w) * (1.0 / self.normalization_factor)
  131. x = torch.cat([x_col, x_dis], 1)
  132. x = self.proj(x)
  133. if self.shrink_factor != 1:
  134. h = (h - 1) * self.shrink_factor + 1
  135. w = (w - 1) * self.shrink_factor + 1
  136. x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
  137. return torch.cat((out, x), 1)
  138. class PSANet(nn.Module):
  139. def __init__(self, layers=50, dropout=0.1, classes=2, zoom_factor=8, use_psa=True, psa_type=2, compact=False,
  140. shrink_factor=2, mask_h=59, mask_w=59, normalization_factor=1.0, psa_softmax=True,
  141. criterion=nn.CrossEntropyLoss(ignore_index=255), pretrained=True):
  142. super(PSANet, self).__init__()
  143. assert layers in [50, 101, 152]
  144. assert classes > 1
  145. assert zoom_factor in [1, 2, 4, 8]
  146. assert psa_type in [0, 1, 2]
  147. self.zoom_factor = zoom_factor
  148. self.use_psa = use_psa
  149. self.criterion = criterion
  150. if layers == 50:
  151. resnet = models.resnet50(pretrained=pretrained,deep_base=True)
  152. elif layers == 101:
  153. resnet = models.resnet101(pretrained=pretrained)
  154. else:
  155. resnet = models.resnet152(pretrained=pretrained)
  156. self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.conv2, resnet.bn2, resnet.relu, resnet.conv3, resnet.bn3, resnet.relu, resnet.maxpool)
  157. self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4
  158. for n, m in self.layer3.named_modules():
  159. if 'conv2' in n:
  160. m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
  161. elif 'downsample.0' in n:
  162. m.stride = (1, 1)
  163. for n, m in self.layer4.named_modules():
  164. if 'conv2' in n:
  165. m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
  166. elif 'downsample.0' in n:
  167. m.stride = (1, 1)
  168. fea_dim = 2048
  169. if use_psa:
  170. self.psa = PSA(fea_dim, 512, psa_type, compact, shrink_factor, mask_h, mask_w, normalization_factor, psa_softmax)
  171. fea_dim *= 2
  172. self.cls = nn.Sequential(
  173. nn.Conv2d(fea_dim, 512, kernel_size=3, padding=1, bias=False),
  174. nn.BatchNorm2d(512),
  175. nn.ReLU(inplace=True),
  176. nn.Dropout2d(p=dropout),
  177. nn.Conv2d(512, classes, kernel_size=1)
  178. )
  179. if self.training:
  180. self.aux = nn.Sequential(
  181. nn.Conv2d(1024, 256, kernel_size=3, padding=1, bias=False),
  182. nn.BatchNorm2d(256),
  183. nn.ReLU(inplace=True),
  184. nn.Dropout2d(p=dropout),
  185. nn.Conv2d(256, classes, kernel_size=1)
  186. )
  187. def forward(self, x, y=None):
  188. x_size = x.size()
  189. assert (x_size[2] - 1) % 8 == 0 and (x_size[3] - 1) % 8 == 0
  190. h = int((x_size[2] - 1) / 8 * self.zoom_factor + 1)
  191. w = int((x_size[3] - 1) / 8 * self.zoom_factor + 1)
  192. x = self.layer0(x)
  193. x = self.layer1(x)
  194. x = self.layer2(x)
  195. x_tmp = self.layer3(x)
  196. x = self.layer4(x_tmp)
  197. if self.use_psa:
  198. x = self.psa(x)
  199. x = self.cls(x)
  200. if self.zoom_factor != 1:
  201. x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
  202. if self.training:
  203. aux = self.aux(x_tmp)
  204. if self.zoom_factor != 1:
  205. aux = F.interpolate(aux, size=(h, w), mode='bilinear', align_corners=True)
  206. main_loss = self.criterion(x, y)
  207. aux_loss = self.criterion(aux, y)
  208. return x.max(1)[1], main_loss, aux_loss
  209. else:
  210. return x
  211. if __name__ == '__main__':
  212. import os
  213. os.environ["CUDA_VISIBLE_DEVICES"] = '0'
  214. crop_h = crop_w = 465
  215. input = torch.rand(4, 3, crop_h, crop_w).cuda()
  216. compact = False
  217. mask_h, mask_w = None, None
  218. shrink_factor = 2
  219. if compact:
  220. mask_h = (crop_h - 1) // (8 * shrink_factor) + 1
  221. mask_w = (crop_w - 1) // (8 * shrink_factor) + 1
  222. else:
  223. assert (mask_h is None and mask_w is None) or (mask_h is not None and mask_w is not None)
  224. if mask_h is None and mask_w is None:
  225. mask_h = 2 * ((crop_h - 1) // (8 * shrink_factor) + 1) - 1
  226. mask_w = 2 * ((crop_w - 1) // (8 * shrink_factor) + 1) - 1
  227. else:
  228. assert (mask_h % 2 == 1) and (mask_h >= 3) and (mask_h <= 2 * ((crop_h - 1) // (8 * shrink_factor) + 1) - 1)
  229. assert (mask_w % 2 == 1) and (mask_w >= 3) and (mask_w <= 2 * ((crop_h - 1) // (8 * shrink_factor) + 1) - 1)
  230. model = PSANet(layers=50, dropout=0.1, classes=21, zoom_factor=8, use_psa=True, psa_type=2, compact=compact,
  231. shrink_factor=shrink_factor, mask_h=mask_h, mask_w=mask_w, psa_softmax=True, pretrained=False).cuda()
  232. print(model)
  233. model.eval()
  234. output = model(input)
  235. print('PSANet', output.size())