高速公路违停检测
Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

43 Zeilen
1.4KB

  1. from collections import OrderedDict
  2. import torch
  3. import torch.nn as nn
  4. from .bn import ABN
  5. class DenseModule(nn.Module):
  6. def __init__(self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1):
  7. super(DenseModule, self).__init__()
  8. self.in_channels = in_channels
  9. self.growth = growth
  10. self.layers = layers
  11. self.convs1 = nn.ModuleList()
  12. self.convs3 = nn.ModuleList()
  13. for i in range(self.layers):
  14. self.convs1.append(nn.Sequential(OrderedDict([
  15. ("bn", norm_act(in_channels)),
  16. ("conv", nn.Conv2d(in_channels, self.growth * bottleneck_factor, 1, bias=False))
  17. ])))
  18. self.convs3.append(nn.Sequential(OrderedDict([
  19. ("bn", norm_act(self.growth * bottleneck_factor)),
  20. ("conv", nn.Conv2d(self.growth * bottleneck_factor, self.growth, 3, padding=dilation, bias=False,
  21. dilation=dilation))
  22. ])))
  23. in_channels += self.growth
  24. @property
  25. def out_channels(self):
  26. return self.in_channels + self.growth * self.layers
  27. def forward(self, x):
  28. inputs = [x]
  29. for i in range(self.layers):
  30. x = torch.cat(inputs, dim=1)
  31. x = self.convs1[i](x)
  32. x = self.convs3[i](x)
  33. inputs += [x]
  34. return torch.cat(inputs, dim=1)