高速公路违停检测
Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

85 lines
3.4KB

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as functional
  4. from models._util import try_index
  5. from .bn import ABN
  6. class DeeplabV3(nn.Module):
  7. def __init__(self,
  8. in_channels,
  9. out_channels,
  10. hidden_channels=256,
  11. dilations=(12, 24, 36),
  12. norm_act=ABN,
  13. pooling_size=None):
  14. super(DeeplabV3, self).__init__()
  15. self.pooling_size = pooling_size
  16. self.map_convs = nn.ModuleList([
  17. nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
  18. nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[0], padding=dilations[0]),
  19. nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[1], padding=dilations[1]),
  20. nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[2], padding=dilations[2])
  21. ])
  22. self.map_bn = norm_act(hidden_channels * 4)
  23. self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
  24. self.global_pooling_bn = norm_act(hidden_channels)
  25. self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False)
  26. self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False)
  27. self.red_bn = norm_act(out_channels)
  28. self.reset_parameters(self.map_bn.activation, self.map_bn.slope)
  29. def reset_parameters(self, activation, slope):
  30. gain = nn.init.calculate_gain(activation, slope)
  31. for m in self.modules():
  32. if isinstance(m, nn.Conv2d):
  33. nn.init.xavier_normal_(m.weight.data, gain)
  34. if hasattr(m, "bias") and m.bias is not None:
  35. nn.init.constant_(m.bias, 0)
  36. elif isinstance(m, ABN):
  37. if hasattr(m, "weight") and m.weight is not None:
  38. nn.init.constant_(m.weight, 1)
  39. if hasattr(m, "bias") and m.bias is not None:
  40. nn.init.constant_(m.bias, 0)
  41. def forward(self, x):
  42. # Map convolutions
  43. out = torch.cat([m(x) for m in self.map_convs], dim=1)
  44. out = self.map_bn(out)
  45. out = self.red_conv(out)
  46. # Global pooling
  47. pool = self._global_pooling(x)
  48. pool = self.global_pooling_conv(pool)
  49. pool = self.global_pooling_bn(pool)
  50. pool = self.pool_red_conv(pool)
  51. if self.training or self.pooling_size is None:
  52. pool = pool.repeat(1, 1, x.size(2), x.size(3))
  53. out += pool
  54. out = self.red_bn(out)
  55. return out
  56. def _global_pooling(self, x):
  57. if self.training or self.pooling_size is None:
  58. pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1)
  59. pool = pool.view(x.size(0), x.size(1), 1, 1)
  60. else:
  61. pooling_size = (min(try_index(self.pooling_size, 0), x.shape[2]),
  62. min(try_index(self.pooling_size, 1), x.shape[3]))
  63. padding = (
  64. (pooling_size[1] - 1) // 2,
  65. (pooling_size[1] - 1) // 2 if pooling_size[1] % 2 == 1 else (pooling_size[1] - 1) // 2 + 1,
  66. (pooling_size[0] - 1) // 2,
  67. (pooling_size[0] - 1) // 2 if pooling_size[0] % 2 == 1 else (pooling_size[0] - 1) // 2 + 1
  68. )
  69. pool = functional.avg_pool2d(x, pooling_size, stride=1)
  70. pool = functional.pad(pool, pad=padding, mode="replicate")
  71. return pool