  1. """Basic Module for Semantic Segmentation"""
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. __all__ = ['_ConvBNPReLU', '_ConvBN', '_BNPReLU', '_ConvBNReLU', '_DepthwiseConv', 'InvertedResidual']
  6. class _ConvBNReLU(nn.Module):
  7. def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
  8. dilation=1, groups=1, relu6=False, norm_layer=nn.BatchNorm2d, **kwargs):
  9. super(_ConvBNReLU, self).__init__()
  10. self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False)
  11. self.bn = norm_layer(out_channels)
  12. self.relu = nn.ReLU6(False) if relu6 else nn.ReLU(False)
  13. def forward(self, x):
  14. x = self.conv(x)
  15. x = self.bn(x)
  16. x = self.relu(x)
  17. return x
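
# Note: relu6=True selects nn.ReLU6, which clamps activations to [0, 6]; the
# MobileNetV2 blocks below rely on this for robustness under low-precision
# arithmetic.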


class _ConvBNPReLU(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, norm_layer=nn.BatchNorm2d, **kwargs):
        super(_ConvBNPReLU, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False)
        self.bn = norm_layer(out_channels)
        self.prelu = nn.PReLU(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.prelu(x)
        return x


class _ConvBN(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, norm_layer=nn.BatchNorm2d, **kwargs):
        super(_ConvBN, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False)
        self.bn = norm_layer(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class _BNPReLU(nn.Module):
    def __init__(self, out_channels, norm_layer=nn.BatchNorm2d, **kwargs):
        super(_BNPReLU, self).__init__()
        self.bn = norm_layer(out_channels)
        self.prelu = nn.PReLU(out_channels)

    def forward(self, x):
        x = self.bn(x)
        x = self.prelu(x)
        return x


# -----------------------------------------------------------------
# For PSPNet
# -----------------------------------------------------------------
class _PSPModule(nn.Module):
    def __init__(self, in_channels, sizes=(1, 2, 3, 6), **kwargs):
        super(_PSPModule, self).__init__()
        out_channels = int(in_channels / 4)
        self.avgpools = nn.ModuleList()
        self.convs = nn.ModuleList()
        for size in sizes:
            self.avgpools.append(nn.AdaptiveAvgPool2d(size))
            self.convs.append(_ConvBNReLU(in_channels, out_channels, 1, **kwargs))

    def forward(self, x):
        size = x.size()[2:]
        feats = [x]
        for (avgpool, conv) in zip(self.avgpools, self.convs):
            feats.append(F.interpolate(conv(avgpool(x)), size, mode='bilinear', align_corners=True))
        return torch.cat(feats, dim=1)
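
# Channel arithmetic: with in_channels=2048 and the default sizes=(1, 2, 3, 6),
# each pooled branch yields 2048 // 4 = 512 channels, so concatenating the input
# with the four upsampled branches gives 2048 + 4 * 512 = 4096 channels,
# i.e. twice the input width.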


# -----------------------------------------------------------------
# For MobileNet
# -----------------------------------------------------------------
class _DepthwiseConv(nn.Module):
    """conv_dw in MobileNet"""

    def __init__(self, in_channels, out_channels, stride, norm_layer=nn.BatchNorm2d, **kwargs):
        super(_DepthwiseConv, self).__init__()
        self.conv = nn.Sequential(
            _ConvBNReLU(in_channels, in_channels, 3, stride, 1, groups=in_channels, norm_layer=norm_layer),
            _ConvBNReLU(in_channels, out_channels, 1, norm_layer=norm_layer))

    def forward(self, x):
        return self.conv(x)
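
# Cost note: factoring a dense 3x3 convolution into the 3x3 depthwise plus 1x1
# pointwise pair above cuts the weight count from 9 * in_channels * out_channels
# to 9 * in_channels + in_channels * out_channels, roughly an 8-9x saving at
# typical channel widths.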


# -----------------------------------------------------------------
# For MobileNetV2
# -----------------------------------------------------------------
class InvertedResidual(nn.Module):
    def __init__(self, in_channels, out_channels, stride, expand_ratio, norm_layer=nn.BatchNorm2d, **kwargs):
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2]
        self.use_res_connect = stride == 1 and in_channels == out_channels
        layers = list()
        inter_channels = int(round(in_channels * expand_ratio))
        if expand_ratio != 1:
            # pw
            layers.append(_ConvBNReLU(in_channels, inter_channels, 1, relu6=True, norm_layer=norm_layer))
        layers.extend([
            # dw
            _ConvBNReLU(inter_channels, inter_channels, 3, stride, 1,
                        groups=inter_channels, relu6=True, norm_layer=norm_layer),
            # pw-linear
            nn.Conv2d(inter_channels, out_channels, 1, bias=False),
            norm_layer(out_channels)])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


if __name__ == '__main__':
    x = torch.randn(1, 32, 64, 64)
    model = InvertedResidual(32, 64, 2, 1)
    out = model(x)
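    # Quick shape checks (a minimal smoke-test sketch; the expected sizes follow
    # from the module definitions above, not from the original script).
    assert out.size() == (1, 64, 32, 32)      # stride 2 halves the spatial dims
    psp = _PSPModule(32)
    assert psp(x).size() == (1, 64, 64, 64)   # concatenation doubles the channels
    dw = _DepthwiseConv(32, 64, 1)
    assert dw(x).size() == (1, 64, 64, 64)    # 3x3 depthwise + 1x1 pointwise
    print(out.size())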