用kafka接收消息
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

83 lines
2.8KB

  1. """Fully Convolutional Network with Stride of 8"""
  2. from __future__ import division
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from .segbase import SegBaseModel
# Names exported by ``from <module> import *``: the FCN model class, the
# generic factory, and the three Pascal VOC convenience factories.
__all__ = ['FCN', 'get_fcn', 'get_fcn_resnet50_voc',
           'get_fcn_resnet101_voc', 'get_fcn_resnet152_voc']
  9. class FCN(SegBaseModel):
  10. def __init__(self, nclass, backbone='resnet50', aux=True, pretrained_base=True, **kwargs):
  11. super(FCN, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs)
  12. self.head = _FCNHead(2048, nclass, **kwargs)
  13. if aux:
  14. self.auxlayer = _FCNHead(1024, nclass, **kwargs)
  15. self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head'])
  16. def forward(self, x):
  17. size = x.size()[2:]
  18. _, _, c3, c4 = self.base_forward(x)
  19. outputs = []
  20. x = self.head(c4)
  21. x = F.interpolate(x, size, mode='bilinear', align_corners=True)
  22. outputs.append(x)
  23. if self.aux:
  24. auxout = self.auxlayer(c3)
  25. auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
  26. outputs.append(auxout)
  27. return tuple(outputs)
  28. class _FCNHead(nn.Module):
  29. def __init__(self, in_channels, channels, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
  30. super(_FCNHead, self).__init__()
  31. inter_channels = in_channels // 4
  32. self.block = nn.Sequential(
  33. nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
  34. norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
  35. nn.ReLU(True),
  36. nn.Dropout(0.1),
  37. nn.Conv2d(inter_channels, channels, 1)
  38. )
  39. def forward(self, x):
  40. return self.block(x)
  41. def get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~/.torch/models',
  42. pretrained_base=True, **kwargs):
  43. acronyms = {
  44. 'pascal_voc': 'pascal_voc',
  45. 'pascal_aug': 'pascal_aug',
  46. 'ade20k': 'ade',
  47. 'coco': 'coco',
  48. 'citys': 'citys',
  49. }
  50. from ..data.dataloader import datasets
  51. model = FCN(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
  52. if pretrained:
  53. from .model_store import get_model_file
  54. device = torch.device(kwargs['local_rank'])
  55. model.load_state_dict(torch.load(get_model_file('fcn_%s_%s' % (backbone, acronyms[dataset]), root=root),
  56. map_location=device))
  57. return model
  58. def get_fcn_resnet50_voc(**kwargs):
  59. return get_fcn('pascal_voc', 'resnet50', **kwargs)
  60. def get_fcn_resnet101_voc(**kwargs):
  61. return get_fcn('pascal_voc', 'resnet101', **kwargs)
  62. def get_fcn_resnet152_voc(**kwargs):
  63. return get_fcn('pascal_voc', 'resnet152', **kwargs)