Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. """
  3. Backbone modules.
  4. """
  5. from collections import OrderedDict
  6. import torch
  7. import torch.nn.functional as F
  8. import torchvision
  9. from torch import nn
  10. import sys,os
  11. sys.path.append(os.path.abspath(os.path.dirname(__file__)) )
  12. import vgg_ as models
  13. class BackboneBase_VGG(nn.Module):
  14. def __init__(self, backbone: nn.Module, num_channels: int, name: str, return_interm_layers: bool):
  15. super().__init__()
  16. features = list(backbone.features.children())
  17. if return_interm_layers:
  18. if name == 'vgg16_bn':
  19. self.body1 = nn.Sequential(*features[:13])
  20. self.body2 = nn.Sequential(*features[13:23])
  21. self.body3 = nn.Sequential(*features[23:33])
  22. self.body4 = nn.Sequential(*features[33:43])
  23. else:
  24. self.body1 = nn.Sequential(*features[:9])
  25. self.body2 = nn.Sequential(*features[9:16])
  26. self.body3 = nn.Sequential(*features[16:23])
  27. self.body4 = nn.Sequential(*features[23:30])
  28. else:
  29. if name == 'vgg16_bn':
  30. self.body = nn.Sequential(*features[:44]) # 16x down-sample
  31. elif name == 'vgg16':
  32. self.body = nn.Sequential(*features[:30]) # 16x down-sample
  33. self.num_channels = num_channels
  34. self.return_interm_layers = return_interm_layers
  35. def forward(self, tensor_list):
  36. out = []
  37. if self.return_interm_layers:
  38. xs = tensor_list
  39. for _, layer in enumerate([self.body1, self.body2, self.body3, self.body4]):
  40. xs = layer(xs)
  41. out.append(xs)
  42. else:
  43. xs = self.body(tensor_list)
  44. out.append(xs)
  45. return out
  46. class Backbone_VGG(BackboneBase_VGG):
  47. """ResNet backbone with frozen BatchNorm."""
  48. def __init__(self, name: str, return_interm_layers: bool):
  49. if name == 'vgg16_bn':
  50. backbone = models.vgg16_bn(pretrained=False)
  51. elif name == 'vgg16':
  52. backbone = models.vgg16(pretrained=False)
  53. num_channels = 256
  54. super().__init__(backbone, num_channels, name, return_interm_layers)
  55. def build_backbone(args):
  56. #backbone = Backbone_VGG(args.backbone, False)
  57. backbone = Backbone_VGG(args.backbone, True)
  58. return backbone
  59. if __name__ == '__main__':
  60. Backbone_VGG('vgg16', True)