Nie możesz wybrać więcej, niż 25 tematów Tematy muszą się zaczynać od litery lub cyfry, mogą zawierać myślniki ('-') i mogą mieć do 35 znaków.

195 lines
7.4KB

  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. """
  3. Mostly copy-paste from torchvision references.
  4. """
import os

import torch
import torch.nn as nn
  7. __all__ = [
  8. 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
  9. 'vgg19_bn', 'vgg19',
  10. ]
  11. model_urls = {
  12. 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
  13. 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
  14. 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
  15. 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
  16. 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
  17. 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
  18. 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
  19. 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
  20. }
  21. model_paths = {
  22. 'vgg16_bn': './vggWeights/vgg16_bn-6c64b313.pth',
  23. 'vgg16': './vggWeights/vgg16-397923af.pth',
  24. }
  25. class VGG(nn.Module):
  26. def __init__(self, features, num_classes=1000, init_weights=True):
  27. super(VGG, self).__init__()
  28. self.features = features
  29. self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
  30. self.classifier = nn.Sequential(
  31. nn.Linear(512 * 7 * 7, 4096),
  32. nn.ReLU(True),
  33. nn.Dropout(),
  34. nn.Linear(4096, 4096),
  35. nn.ReLU(True),
  36. nn.Dropout(),
  37. nn.Linear(4096, num_classes),
  38. )
  39. if init_weights:
  40. self._initialize_weights()
  41. def forward(self, x):
  42. x = self.features(x)
  43. x = self.avgpool(x)
  44. x = torch.flatten(x, 1)
  45. x = self.classifier(x)
  46. return x
  47. def _initialize_weights(self):
  48. for m in self.modules():
  49. if isinstance(m, nn.Conv2d):
  50. nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  51. if m.bias is not None:
  52. nn.init.constant_(m.bias, 0)
  53. elif isinstance(m, nn.BatchNorm2d):
  54. nn.init.constant_(m.weight, 1)
  55. nn.init.constant_(m.bias, 0)
  56. elif isinstance(m, nn.Linear):
  57. nn.init.normal_(m.weight, 0, 0.01)
  58. nn.init.constant_(m.bias, 0)
  59. def make_layers(cfg, batch_norm=False, sync=False):
  60. layers = []
  61. in_channels = 3
  62. for v in cfg:
  63. if v == 'M':
  64. layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
  65. else:
  66. conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
  67. if batch_norm:
  68. if sync:
  69. print('use sync backbone')
  70. layers += [conv2d, nn.SyncBatchNorm(v), nn.ReLU(inplace=True)]
  71. else:
  72. layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
  73. else:
  74. layers += [conv2d, nn.ReLU(inplace=True)]
  75. in_channels = v
  76. return nn.Sequential(*layers)
  77. cfgs = {
  78. 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  79. 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  80. 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
  81. 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
  82. }
  83. def _vgg(arch, cfg, batch_norm, pretrained, progress, sync=False, **kwargs):
  84. if pretrained:
  85. kwargs['init_weights'] = False
  86. model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm, sync=sync), **kwargs)
  87. if pretrained:
  88. state_dict = torch.load(model_paths[arch])
  89. model.load_state_dict(state_dict)
  90. return model
  91. def vgg11(pretrained=False, progress=True, **kwargs):
  92. r"""VGG 11-layer model (configuration "A") from
  93. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
  94. Args:
  95. pretrained (bool): If True, returns a model pre-trained on ImageNet
  96. progress (bool): If True, displays a progress bar of the download to stderr
  97. """
  98. return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)
  99. def vgg11_bn(pretrained=False, progress=True, **kwargs):
  100. r"""VGG 11-layer model (configuration "A") with batch normalization
  101. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
  102. Args:
  103. pretrained (bool): If True, returns a model pre-trained on ImageNet
  104. progress (bool): If True, displays a progress bar of the download to stderr
  105. """
  106. return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)
  107. def vgg13(pretrained=False, progress=True, **kwargs):
  108. r"""VGG 13-layer model (configuration "B")
  109. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
  110. Args:
  111. pretrained (bool): If True, returns a model pre-trained on ImageNet
  112. progress (bool): If True, displays a progress bar of the download to stderr
  113. """
  114. return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)
  115. def vgg13_bn(pretrained=False, progress=True, **kwargs):
  116. r"""VGG 13-layer model (configuration "B") with batch normalization
  117. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
  118. Args:
  119. pretrained (bool): If True, returns a model pre-trained on ImageNet
  120. progress (bool): If True, displays a progress bar of the download to stderr
  121. """
  122. return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)
  123. def vgg16(pretrained=False, progress=True, **kwargs):
  124. r"""VGG 16-layer model (configuration "D")
  125. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
  126. Args:
  127. pretrained (bool): If True, returns a model pre-trained on ImageNet
  128. progress (bool): If True, displays a progress bar of the download to stderr
  129. """
  130. return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
  131. def vgg16_bn(pretrained=False, progress=True, sync=False, **kwargs):
  132. r"""VGG 16-layer model (configuration "D") with batch normalization
  133. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
  134. Args:
  135. pretrained (bool): If True, returns a model pre-trained on ImageNet
  136. progress (bool): If True, displays a progress bar of the download to stderr
  137. """
  138. return _vgg('vgg16_bn', 'D', True, pretrained, progress, sync=sync, **kwargs)
  139. def vgg19(pretrained=False, progress=True, **kwargs):
  140. r"""VGG 19-layer model (configuration "E")
  141. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
  142. Args:
  143. pretrained (bool): If True, returns a model pre-trained on ImageNet
  144. progress (bool): If True, displays a progress bar of the download to stderr
  145. """
  146. return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)
  147. def vgg19_bn(pretrained=False, progress=True, **kwargs):
  148. r"""VGG 19-layer model (configuration 'E') with batch normalization
  149. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
  150. Args:
  151. pretrained (bool): If True, returns a model pre-trained on ImageNet
  152. progress (bool): If True, displays a progress bar of the download to stderr
  153. """
  154. return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)