You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

265 lines
11KB

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.init as init
  4. import torchvision
  5. from torchvision import models
  6. from collections import namedtuple
  7. from packaging import version
  8. def init_weights(modules):
  9. for m in modules:
  10. if isinstance(m, nn.Conv2d):
  11. init.xavier_uniform_(m.weight.data)
  12. if m.bias is not None:
  13. m.bias.data.zero_()
  14. elif isinstance(m, nn.BatchNorm2d):
  15. m.weight.data.fill_(1)
  16. m.bias.data.zero_()
  17. elif isinstance(m, nn.Linear):
  18. m.weight.data.normal_(0, 0.01)
  19. m.bias.data.zero_()
  20. class vgg16_bn(torch.nn.Module):
  21. def __init__(self, pretrained=True, freeze=True):
  22. super(vgg16_bn, self).__init__()
  23. if version.parse(torchvision.__version__) >= version.parse('0.13'):
  24. vgg_pretrained_features = models.vgg16_bn(
  25. weights=models.VGG16_BN_Weights.DEFAULT if pretrained else None
  26. ).features
  27. else: #torchvision.__version__ < 0.13
  28. models.vgg.model_urls['vgg16_bn'] = models.vgg.model_urls['vgg16_bn'].replace('https://', 'http://')
  29. vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
  30. self.slice1 = torch.nn.Sequential()
  31. self.slice2 = torch.nn.Sequential()
  32. self.slice3 = torch.nn.Sequential()
  33. self.slice4 = torch.nn.Sequential()
  34. self.slice5 = torch.nn.Sequential()
  35. for x in range(12): # conv2_2
  36. self.slice1.add_module(str(x), vgg_pretrained_features[x])
  37. for x in range(12, 19): # conv3_3
  38. self.slice2.add_module(str(x), vgg_pretrained_features[x])
  39. for x in range(19, 29): # conv4_3
  40. self.slice3.add_module(str(x), vgg_pretrained_features[x])
  41. for x in range(29, 39): # conv5_3
  42. self.slice4.add_module(str(x), vgg_pretrained_features[x])
  43. # fc6, fc7 without atrous conv
  44. self.slice5 = torch.nn.Sequential(
  45. nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
  46. nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
  47. nn.Conv2d(1024, 1024, kernel_size=1)
  48. )
  49. if not pretrained:
  50. init_weights(self.slice1.modules())
  51. init_weights(self.slice2.modules())
  52. init_weights(self.slice3.modules())
  53. init_weights(self.slice4.modules())
  54. init_weights(self.slice5.modules()) # no pretrained model for fc6 and fc7
  55. if freeze:
  56. for param in self.slice1.parameters(): # only first conv
  57. param.requires_grad= False
  58. def forward(self, X):
  59. h = self.slice1(X)
  60. h_relu2_2 = h
  61. h = self.slice2(h)
  62. h_relu3_2 = h
  63. h = self.slice3(h)
  64. h_relu4_3 = h
  65. h = self.slice4(h)
  66. h_relu5_3 = h
  67. h = self.slice5(h)
  68. h_fc7 = h
  69. vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2'])
  70. out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
  71. return out
  72. class BidirectionalLSTM(nn.Module):
  73. def __init__(self, input_size, hidden_size, output_size):
  74. super(BidirectionalLSTM, self).__init__()
  75. self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
  76. self.linear = nn.Linear(hidden_size * 2, output_size)
  77. def forward(self, input):
  78. """
  79. input : visual feature [batch_size x T x input_size]
  80. output : contextual feature [batch_size x T x output_size]
  81. """
  82. try: # multi gpu needs this
  83. self.rnn.flatten_parameters()
  84. except: # quantization doesn't work with this
  85. pass
  86. recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
  87. output = self.linear(recurrent) # batch_size x T x output_size
  88. return output
  89. class VGG_FeatureExtractor(nn.Module):
  90. def __init__(self, input_channel, output_channel=256):
  91. super(VGG_FeatureExtractor, self).__init__()
  92. self.output_channel = [int(output_channel / 8), int(output_channel / 4),
  93. int(output_channel / 2), output_channel]
  94. self.ConvNet = nn.Sequential(
  95. nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True),
  96. nn.MaxPool2d(2, 2),
  97. nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True),
  98. nn.MaxPool2d(2, 2),
  99. nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(True),
  100. nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True),
  101. nn.MaxPool2d((2, 1), (2, 1)),
  102. nn.Conv2d(self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False),
  103. nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True),
  104. nn.Conv2d(self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False),
  105. nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True),
  106. nn.MaxPool2d((2, 1), (2, 1)),
  107. nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True))
  108. def forward(self, input):
  109. return self.ConvNet(input)
  110. class ResNet_FeatureExtractor(nn.Module):
  111. """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """
  112. def __init__(self, input_channel, output_channel=512):
  113. super(ResNet_FeatureExtractor, self).__init__()
  114. self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3])
  115. def forward(self, input):
  116. return self.ConvNet(input)
  117. class BasicBlock(nn.Module):
  118. expansion = 1
  119. def __init__(self, inplanes, planes, stride=1, downsample=None):
  120. super(BasicBlock, self).__init__()
  121. self.conv1 = self._conv3x3(inplanes, planes)
  122. self.bn1 = nn.BatchNorm2d(planes)
  123. self.conv2 = self._conv3x3(planes, planes)
  124. self.bn2 = nn.BatchNorm2d(planes)
  125. self.relu = nn.ReLU(inplace=True)
  126. self.downsample = downsample
  127. self.stride = stride
  128. def _conv3x3(self, in_planes, out_planes, stride=1):
  129. "3x3 convolution with padding"
  130. return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
  131. padding=1, bias=False)
  132. def forward(self, x):
  133. residual = x
  134. out = self.conv1(x)
  135. out = self.bn1(out)
  136. out = self.relu(out)
  137. out = self.conv2(out)
  138. out = self.bn2(out)
  139. if self.downsample is not None:
  140. residual = self.downsample(x)
  141. out += residual
  142. out = self.relu(out)
  143. return out
  144. class ResNet(nn.Module):
  145. def __init__(self, input_channel, output_channel, block, layers):
  146. super(ResNet, self).__init__()
  147. self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]
  148. self.inplanes = int(output_channel / 8)
  149. self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16),
  150. kernel_size=3, stride=1, padding=1, bias=False)
  151. self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16))
  152. self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes,
  153. kernel_size=3, stride=1, padding=1, bias=False)
  154. self.bn0_2 = nn.BatchNorm2d(self.inplanes)
  155. self.relu = nn.ReLU(inplace=True)
  156. self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
  157. self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
  158. self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[
  159. 0], kernel_size=3, stride=1, padding=1, bias=False)
  160. self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])
  161. self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
  162. self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1)
  163. self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[
  164. 1], kernel_size=3, stride=1, padding=1, bias=False)
  165. self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])
  166. self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
  167. self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1)
  168. self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[
  169. 2], kernel_size=3, stride=1, padding=1, bias=False)
  170. self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])
  171. self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1)
  172. self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
  173. 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False)
  174. self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])
  175. self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
  176. 3], kernel_size=2, stride=1, padding=0, bias=False)
  177. self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])
  178. def _make_layer(self, block, planes, blocks, stride=1):
  179. downsample = None
  180. if stride != 1 or self.inplanes != planes * block.expansion:
  181. downsample = nn.Sequential(
  182. nn.Conv2d(self.inplanes, planes * block.expansion,
  183. kernel_size=1, stride=stride, bias=False),
  184. nn.BatchNorm2d(planes * block.expansion),
  185. )
  186. layers = []
  187. layers.append(block(self.inplanes, planes, stride, downsample))
  188. self.inplanes = planes * block.expansion
  189. for i in range(1, blocks):
  190. layers.append(block(self.inplanes, planes))
  191. return nn.Sequential(*layers)
  192. def forward(self, x):
  193. x = self.conv0_1(x)
  194. x = self.bn0_1(x)
  195. x = self.relu(x)
  196. x = self.conv0_2(x)
  197. x = self.bn0_2(x)
  198. x = self.relu(x)
  199. x = self.maxpool1(x)
  200. x = self.layer1(x)
  201. x = self.conv1(x)
  202. x = self.bn1(x)
  203. x = self.relu(x)
  204. x = self.maxpool2(x)
  205. x = self.layer2(x)
  206. x = self.conv2(x)
  207. x = self.bn2(x)
  208. x = self.relu(x)
  209. x = self.maxpool3(x)
  210. x = self.layer3(x)
  211. x = self.conv3(x)
  212. x = self.bn3(x)
  213. x = self.relu(x)
  214. x = self.layer4(x)
  215. x = self.conv4_1(x)
  216. x = self.bn4_1(x)
  217. x = self.relu(x)
  218. x = self.conv4_2(x)
  219. x = self.bn4_2(x)
  220. x = self.relu(x)
  221. return x