落水人员检测
Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

43 linhas
1.4KB

  1. from collections import OrderedDict
  2. import torch
  3. import torch.nn as nn
  4. from .bn import ABN
  5. class DenseModule(nn.Module):
  6. def __init__(self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1):
  7. super(DenseModule, self).__init__()
  8. self.in_channels = in_channels
  9. self.growth = growth
  10. self.layers = layers
  11. self.convs1 = nn.ModuleList()
  12. self.convs3 = nn.ModuleList()
  13. for i in range(self.layers):
  14. self.convs1.append(nn.Sequential(OrderedDict([
  15. ("bn", norm_act(in_channels)),
  16. ("conv", nn.Conv2d(in_channels, self.growth * bottleneck_factor, 1, bias=False))
  17. ])))
  18. self.convs3.append(nn.Sequential(OrderedDict([
  19. ("bn", norm_act(self.growth * bottleneck_factor)),
  20. ("conv", nn.Conv2d(self.growth * bottleneck_factor, self.growth, 3, padding=dilation, bias=False,
  21. dilation=dilation))
  22. ])))
  23. in_channels += self.growth
  24. @property
  25. def out_channels(self):
  26. return self.in_channels + self.growth * self.layers
  27. def forward(self, x):
  28. inputs = [x]
  29. for i in range(self.layers):
  30. x = torch.cat(inputs, dim=1)
  31. x = self.convs1[i](x)
  32. x = self.convs3[i](x)
  33. inputs += [x]
  34. return torch.cat(inputs, dim=1)