高速公路违停检测
No puede seleccionar más de 25 temas. Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.

43 líneas
1.4KB

  1. from collections import OrderedDict
  2. import torch
  3. import torch.nn as nn
  4. from .bn import ABN
  5. class DenseModule(nn.Module):
  6. def __init__(self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1):
  7. super(DenseModule, self).__init__()
  8. self.in_channels = in_channels
  9. self.growth = growth
  10. self.layers = layers
  11. self.convs1 = nn.ModuleList()
  12. self.convs3 = nn.ModuleList()
  13. for i in range(self.layers):
  14. self.convs1.append(nn.Sequential(OrderedDict([
  15. ("bn", norm_act(in_channels)),
  16. ("conv", nn.Conv2d(in_channels, self.growth * bottleneck_factor, 1, bias=False))
  17. ])))
  18. self.convs3.append(nn.Sequential(OrderedDict([
  19. ("bn", norm_act(self.growth * bottleneck_factor)),
  20. ("conv", nn.Conv2d(self.growth * bottleneck_factor, self.growth, 3, padding=dilation, bias=False,
  21. dilation=dilation))
  22. ])))
  23. in_channels += self.growth
  24. @property
  25. def out_channels(self):
  26. return self.in_channels + self.growth * self.layers
  27. def forward(self, x):
  28. inputs = [x]
  29. for i in range(self.layers):
  30. x = torch.cat(inputs, dim=1)
  31. x = self.convs1[i](x)
  32. x = self.convs3[i](x)
  33. inputs += [x]
  34. return torch.cat(inputs, dim=1)