  1. """ Object Context Network for Scene Parsing"""
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from core.models.segbase import SegBaseModel
  6. from core.models.fcn import _FCNHead
  7. __all__ = ['OCNet', 'get_ocnet', 'get_base_ocnet_resnet101_citys',
  8. 'get_pyramid_ocnet_resnet101_citys', 'get_asp_ocnet_resnet101_citys']


class OCNet(SegBaseModel):
    r"""OCNet

    Parameters
    ----------
    nclass : int
        Number of categories for the training dataset.
    backbone : string
        Pre-trained dilated backbone network type (default: 'resnet101'; one of
        'resnet50', 'resnet101' or 'resnet152').
    norm_layer : object
        Normalization layer used in backbone network (default: :class:`nn.BatchNorm`;
        for Synchronized Cross-GPU BatchNormalization).
    aux : bool
        Auxiliary loss.

    Reference:
        Yuhui Yuan, Jingdong Wang. "OCNet: Object Context Network for Scene Parsing."
        arXiv preprint arXiv:1809.00916 (2018).
    """

    def __init__(self, nclass, backbone='resnet101', oc_arch='base', aux=False, pretrained_base=True, **kwargs):
        super(OCNet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs)
        self.head = _OCHead(nclass, oc_arch, **kwargs)
        if self.aux:
            self.auxlayer = _FCNHead(1024, nclass, **kwargs)

        self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head'])

    def forward(self, x):
        size = x.size()[2:]
        _, _, c3, c4 = self.base_forward(x)
        outputs = []
        x = self.head(c4)
        # upsample the head prediction back to the input resolution
        x = F.interpolate(x, size, mode='bilinear', align_corners=True)
        outputs.append(x)
        if self.aux:
            auxout = self.auxlayer(c3)
            auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
            outputs.append(auxout)
        # return tuple(outputs)
        return outputs[0]  # this variant returns only the main prediction, dropping the auxiliary output
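
# Usage sketch (added illustration, not part of the upstream repo): `oc_arch` selects the
# context head that _OCHead builds below, one of 'base', 'pyramid' or 'asp'.
# For example, an ASP-OC variant without a pretrained backbone:
#
#   model = OCNet(nclass=19, oc_arch='asp', pretrained_base=False)
#   out = model(torch.rand(1, 3, 224, 224))   # -> torch.Size([1, 19, 224, 224])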


class _OCHead(nn.Module):
    def __init__(self, nclass, oc_arch, norm_layer=nn.BatchNorm2d, **kwargs):
        super(_OCHead, self).__init__()
        if oc_arch == 'base':
            self.context = nn.Sequential(
                nn.Conv2d(2048, 512, 3, 1, padding=1, bias=False),
                norm_layer(512),
                nn.ReLU(True),
                BaseOCModule(512, 512, 256, 256, scales=([1]), norm_layer=norm_layer, **kwargs))
        elif oc_arch == 'pyramid':
            self.context = nn.Sequential(
                nn.Conv2d(2048, 512, 3, 1, padding=1, bias=False),
                norm_layer(512),
                nn.ReLU(True),
                PyramidOCModule(512, 512, 256, 512, scales=([1, 2, 3, 6]), norm_layer=norm_layer, **kwargs))
        elif oc_arch == 'asp':
            self.context = ASPOCModule(2048, 512, 256, 512, norm_layer=norm_layer, **kwargs)
        else:
            raise ValueError("Unknown OC architecture!")

        self.out = nn.Conv2d(512, nclass, 1)

    def forward(self, x):
        x = self.context(x)
        return self.out(x)


class BaseAttentionBlock(nn.Module):
    """The basic implementation for self-attention block/non-local block."""

    def __init__(self, in_channels, out_channels, key_channels, value_channels,
                 scale=1, norm_layer=nn.BatchNorm2d, **kwargs):
        super(BaseAttentionBlock, self).__init__()
        self.scale = scale
        self.key_channels = key_channels
        self.value_channels = value_channels
        if scale > 1:
            self.pool = nn.MaxPool2d(scale)

        self.f_value = nn.Conv2d(in_channels, value_channels, 1)
        self.f_key = nn.Sequential(
            nn.Conv2d(in_channels, key_channels, 1),
            norm_layer(key_channels),
            nn.ReLU(True)
        )
        self.f_query = self.f_key
        self.W = nn.Conv2d(value_channels, out_channels, 1)

        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)

    def forward(self, x):
        batch_size, c, w, h = x.size()
        if self.scale > 1:
            x = self.pool(x)

        value = self.f_value(x).view(batch_size, self.value_channels, -1).permute(0, 2, 1)
        query = self.f_query(x).view(batch_size, self.key_channels, -1).permute(0, 2, 1)
        key = self.f_key(x).view(batch_size, self.key_channels, -1)

        sim_map = torch.bmm(query, key) * (self.key_channels ** -.5)
        sim_map = F.softmax(sim_map, dim=-1)

        context = torch.bmm(sim_map, value).permute(0, 2, 1).contiguous()
        context = context.view(batch_size, self.value_channels, *x.size()[2:])
        context = self.W(context)
        if self.scale > 1:
            context = F.interpolate(context, size=(w, h), mode='bilinear', align_corners=True)
        return context
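
# Note (added): the forward above is standard scaled dot-product attention over all
# spatial positions, written here with the tensor names used in the code:
#
#   sim_map = softmax(query @ key / sqrt(key_channels))   # (B, H*W, H*W) similarity
#   context = sim_map @ value                              # (B, H*W, value_channels)
#
# so each position aggregates the value features of every other position, weighted
# by feature similarity ("object context").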


class BaseOCModule(nn.Module):
    """Base-OC"""

    def __init__(self, in_channels, out_channels, key_channels, value_channels,
                 scales=([1]), norm_layer=nn.BatchNorm2d, concat=True, **kwargs):
        super(BaseOCModule, self).__init__()
        self.stages = nn.ModuleList([
            BaseAttentionBlock(in_channels, out_channels, key_channels, value_channels, scale, norm_layer, **kwargs)
            for scale in scales])
        in_channels = in_channels * 2 if concat else in_channels
        self.project = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            norm_layer(out_channels),
            nn.ReLU(True),
            nn.Dropout2d(0.05)
        )
        self.concat = concat

    def forward(self, x):
        priors = [stage(x) for stage in self.stages]
        context = priors[0]
        for i in range(1, len(priors)):
            context += priors[i]
        if self.concat:
            context = torch.cat([context, x], 1)
        out = self.project(context)
        return out


class PyramidAttentionBlock(nn.Module):
    """The basic implementation for pyramid self-attention block/non-local block"""

    def __init__(self, in_channels, out_channels, key_channels, value_channels,
                 scale=1, norm_layer=nn.BatchNorm2d, **kwargs):
        super(PyramidAttentionBlock, self).__init__()
        self.scale = scale
        self.value_channels = value_channels
        self.key_channels = key_channels

        self.f_value = nn.Conv2d(in_channels, value_channels, 1)
        self.f_key = nn.Sequential(
            nn.Conv2d(in_channels, key_channels, 1),
            norm_layer(key_channels),
            nn.ReLU(True)
        )
        self.f_query = self.f_key
        self.W = nn.Conv2d(value_channels, out_channels, 1)

        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)

    def forward(self, x):
        batch_size, c, w, h = x.size()

        # split the feature map into a scale x scale grid of local regions
        local_x = list()
        local_y = list()
        step_w, step_h = w // self.scale, h // self.scale
        for i in range(self.scale):
            for j in range(self.scale):
                start_x, start_y = step_w * i, step_h * j
                end_x, end_y = min(start_x + step_w, w), min(start_y + step_h, h)
                if i == (self.scale - 1):
                    end_x = w
                if j == (self.scale - 1):
                    end_y = h
                local_x += [start_x, end_x]
                local_y += [start_y, end_y]

        value = self.f_value(x)
        query = self.f_query(x)
        key = self.f_key(x)

        # self-attention is computed independently inside each local region
        local_list = list()
        local_block_cnt = (self.scale ** 2) * 2
        for i in range(0, local_block_cnt, 2):
            value_local = value[:, :, local_x[i]:local_x[i + 1], local_y[i]:local_y[i + 1]]
            query_local = query[:, :, local_x[i]:local_x[i + 1], local_y[i]:local_y[i + 1]]
            key_local = key[:, :, local_x[i]:local_x[i + 1], local_y[i]:local_y[i + 1]]

            w_local, h_local = value_local.size(2), value_local.size(3)
            value_local = value_local.contiguous().view(batch_size, self.value_channels, -1).permute(0, 2, 1)
            query_local = query_local.contiguous().view(batch_size, self.key_channels, -1).permute(0, 2, 1)
            key_local = key_local.contiguous().view(batch_size, self.key_channels, -1)

            sim_map = torch.bmm(query_local, key_local) * (self.key_channels ** -.5)
            sim_map = F.softmax(sim_map, dim=-1)

            context_local = torch.bmm(sim_map, value_local).permute(0, 2, 1).contiguous()
            context_local = context_local.view(batch_size, self.value_channels, w_local, h_local)
            local_list.append(context_local)

        # stitch the per-region context maps back into a full-size feature map
        context_list = list()
        for i in range(0, self.scale):
            row_tmp = list()
            for j in range(self.scale):
                row_tmp.append(local_list[j + i * self.scale])
            context_list.append(torch.cat(row_tmp, 3))
        context = torch.cat(context_list, 2)
        context = self.W(context)
        return context


class PyramidOCModule(nn.Module):
    """Pyramid-OC"""

    def __init__(self, in_channels, out_channels, key_channels, value_channels,
                 scales=([1]), norm_layer=nn.BatchNorm2d, **kwargs):
        super(PyramidOCModule, self).__init__()
        self.stages = nn.ModuleList([
            PyramidAttentionBlock(in_channels, out_channels, key_channels, value_channels, scale, norm_layer, **kwargs)
            for scale in scales])
        self.up_dr = nn.Sequential(
            nn.Conv2d(in_channels, in_channels * len(scales), 1),
            norm_layer(in_channels * len(scales)),
            nn.ReLU(True)
        )
        self.project = nn.Sequential(
            nn.Conv2d(in_channels * len(scales) * 2, out_channels, 1),
            norm_layer(out_channels),
            nn.ReLU(True),
            nn.Dropout2d(0.05)
        )

    def forward(self, x):
        priors = [stage(x) for stage in self.stages]
        context = [self.up_dr(x)]
        for i in range(len(priors)):
            context += [priors[i]]
        context = torch.cat(context, 1)
        out = self.project(context)
        return out


class ASPOCModule(nn.Module):
    """ASP-OC"""

    def __init__(self, in_channels, out_channels, key_channels, value_channels,
                 atrous_rates=(12, 24, 36), norm_layer=nn.BatchNorm2d, **kwargs):
        super(ASPOCModule, self).__init__()
        self.context = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            norm_layer(out_channels),
            nn.ReLU(True),
            BaseOCModule(out_channels, out_channels, key_channels, value_channels, ([2]), norm_layer, False, **kwargs))

        rate1, rate2, rate3 = tuple(atrous_rates)
        self.b1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=rate1, dilation=rate1, bias=False),
            norm_layer(out_channels),
            nn.ReLU(True))
        self.b2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=rate2, dilation=rate2, bias=False),
            norm_layer(out_channels),
            nn.ReLU(True))
        self.b3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=rate3, dilation=rate3, bias=False),
            norm_layer(out_channels),
            nn.ReLU(True))
        self.b4 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            norm_layer(out_channels),
            nn.ReLU(True))

        self.project = nn.Sequential(
            nn.Conv2d(out_channels * 5, out_channels, 1, bias=False),
            norm_layer(out_channels),
            nn.ReLU(True),
            nn.Dropout2d(0.1)
        )

    def forward(self, x):
        feat1 = self.context(x)
        feat2 = self.b1(x)
        feat3 = self.b2(x)
        feat4 = self.b3(x)
        feat5 = self.b4(x)
        out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
        out = self.project(out)
        return out


def get_ocnet(dataset='citys', backbone='resnet50', oc_arch='base', pretrained=False, root='~/.torch/models',
              pretrained_base=True, **kwargs):
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = OCNet(datasets[dataset].NUM_CLASS, backbone=backbone, oc_arch=oc_arch,
                  pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        device = torch.device(kwargs['local_rank'])
        model.load_state_dict(torch.load(get_model_file('%s_ocnet_%s_%s' % (
            oc_arch, backbone, acronyms[dataset]), root=root),
            map_location=device))
    return model


def get_base_ocnet_resnet101_citys(**kwargs):
    return get_ocnet('citys', 'resnet101', 'base', **kwargs)


def get_pyramid_ocnet_resnet101_citys(**kwargs):
    return get_ocnet('citys', 'resnet101', 'pyramid', **kwargs)


def get_asp_ocnet_resnet101_citys(**kwargs):
    return get_ocnet('citys', 'resnet101', 'asp', **kwargs)
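
# Note (added, a sketch): with pretrained=True, get_ocnet above reads kwargs['local_rank']
# to build the map_location device before loading the '%s_ocnet_%s_%s' weights from root,
# so callers are expected to pass a local_rank, e.g. (hypothetical call):
#
#   model = get_base_ocnet_resnet101_citys(pretrained=True, local_rank=0)
#
# An untrained model needs no extra arguments:
#
#   model = get_asp_ocnet_resnet101_citys(pretrained_base=False)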


if __name__ == '__main__':
    # img = torch.randn(1, 3, 256, 256)
    # model = get_asp_ocnet_resnet101_citys()
    # outputs = model(img)
    img = torch.rand(1, 3, 224, 224)
    model = OCNet(4, pretrained_base=False)
    # target = torch.zeros(4, 512, 512).cuda()
    # model.eval()
    # print(model)
    output = model(img)
    print(output, output.shape)

    # from torchsummary import summary
    # summary(model, (3, 224, 224))  # prints a table of each layer's output shape and parameter count, in order

    from thop import profile
    # thop.profile expects example input tensors rather than an input_size tuple
    flops, params = profile(model, inputs=(torch.randn(1, 3, 512, 512),))
    print('flops:{:.3f}G\nparams:{:.3f}M'.format(flops / 1e9, params / 1e6))