高速公路违停检测 (Highway Illegal Parking Detection)

from collections import OrderedDict

import torch.nn as nn

from .bn import ABN


class IdentityResidualBlock(nn.Module):
    def __init__(self,
                 in_channels,
                 channels,
                 stride=1,
                 dilation=1,
                 groups=1,
                 norm_act=ABN,
                 dropout=None):
        """Configurable identity-mapping residual block

        Parameters
        ----------
        in_channels : int
            Number of input channels.
        channels : list of int
            Number of channels in the internal feature maps. Can either have two or three elements: if two, construct
            a residual block with two `3 x 3` convolutions; if three, construct a bottleneck block with `1 x 1`, then
            `3 x 3`, then `1 x 1` convolutions.
        stride : int
            Stride of the first convolution (the first `3 x 3` in the basic block, the `1 x 1` in the bottleneck).
        dilation : int
            Dilation to apply to the `3 x 3` convolutions.
        groups : int
            Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
            bottleneck blocks.
        norm_act : callable
            Function to create the normalization / activation Module.
        dropout : callable
            Function to create the Dropout Module.
        """
        super(IdentityResidualBlock, self).__init__()

        # Check parameters for inconsistencies
        if len(channels) != 2 and len(channels) != 3:
            raise ValueError("channels must contain either two or three values")
        if len(channels) == 2 and groups != 1:
            raise ValueError("groups > 1 are only valid if len(channels) == 3")

        is_bottleneck = len(channels) == 3
        # A 1x1 projection is needed whenever the shortcut cannot be an
        # identity, i.e. when the spatial resolution or channel count changes.
        need_proj_conv = stride != 1 or in_channels != channels[-1]

        # Pre-activation: normalization / activation is applied before the
        # first convolution rather than after the residual addition.
        self.bn1 = norm_act(in_channels)
        if not is_bottleneck:
            # Basic block: two 3x3 convolutions.
            layers = [
                ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False,
                                    dilation=dilation)),
                ("bn2", norm_act(channels[0])),
                ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
                                    dilation=dilation))
            ]
            if dropout is not None:
                # Insert dropout between bn2 and conv2
                layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
        else:
            # Bottleneck block: 1x1 reduction, (optionally grouped) 3x3, 1x1 expansion.
            layers = [
                ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=stride, padding=0, bias=False)),
                ("bn2", norm_act(channels[0])),
                ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
                                    groups=groups, dilation=dilation)),
                ("bn3", norm_act(channels[1])),
                ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False))
            ]
            if dropout is not None:
                # Insert dropout between bn3 and conv3
                layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
        self.convs = nn.Sequential(OrderedDict(layers))

        if need_proj_conv:
            self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        if hasattr(self, "proj_conv"):
            # Projection shortcut: both branches start from the normalized input.
            bn1 = self.bn1(x)
            shortcut = self.proj_conv(bn1)
        else:
            # Identity shortcut: clone so the in-place addition below cannot
            # alias the input tensor.
            shortcut = x.clone()
            bn1 = self.bn1(x)

        out = self.convs(bn1)
        out.add_(shortcut)  # in-place residual addition

        return out
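
A minimal usage sketch follows, assuming the block is used outside the full package: `bn_relu` below is a hypothetical stand-in for the `ABN` layer imported from `.bn` (which fuses batch normalization and activation), so the snippet runs with plain PyTorch.

import torch
import torch.nn as nn


def bn_relu(num_features):
    # Hypothetical stand-in for ABN: batch norm followed by ReLU.
    return nn.Sequential(nn.BatchNorm2d(num_features), nn.ReLU(inplace=True))


# Basic variant: two 3x3 convolutions, identity shortcut (no shape change).
basic = IdentityResidualBlock(64, [64, 64], norm_act=bn_relu)

# Bottleneck variant: 1x1 -> grouped 3x3 -> 1x1, with a projection shortcut
# because both the stride and the output channel count change.
bottleneck = IdentityResidualBlock(64, [64, 64, 256], stride=2, groups=4, norm_act=bn_relu)

x = torch.randn(2, 64, 56, 56)
print(basic(x).shape)       # torch.Size([2, 64, 56, 56])
print(bottleneck(x).shape)  # torch.Size([2, 256, 28, 28])

Since normalization and activation precede every convolution (the identity-mapping ResNet design of He et al.), the shortcut path stays free of non-linearities and blocks can be stacked directly.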