高速公路违停检测
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

43 lines
1.4KB

  1. from collections import OrderedDict
  2. import torch
  3. import torch.nn as nn
  4. from .bn import ABN
  5. class DenseModule(nn.Module):
  6. def __init__(self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1):
  7. super(DenseModule, self).__init__()
  8. self.in_channels = in_channels
  9. self.growth = growth
  10. self.layers = layers
  11. self.convs1 = nn.ModuleList()
  12. self.convs3 = nn.ModuleList()
  13. for i in range(self.layers):
  14. self.convs1.append(nn.Sequential(OrderedDict([
  15. ("bn", norm_act(in_channels)),
  16. ("conv", nn.Conv2d(in_channels, self.growth * bottleneck_factor, 1, bias=False))
  17. ])))
  18. self.convs3.append(nn.Sequential(OrderedDict([
  19. ("bn", norm_act(self.growth * bottleneck_factor)),
  20. ("conv", nn.Conv2d(self.growth * bottleneck_factor, self.growth, 3, padding=dilation, bias=False,
  21. dilation=dilation))
  22. ])))
  23. in_channels += self.growth
  24. @property
  25. def out_channels(self):
  26. return self.in_channels + self.growth * self.layers
  27. def forward(self, x):
  28. inputs = [x]
  29. for i in range(self.layers):
  30. x = torch.cat(inputs, dim=1)
  31. x = self.convs1[i](x)
  32. x = self.convs3[i](x)
  33. inputs += [x]
  34. return torch.cat(inputs, dim=1)