import torch
import torch.nn as nn
from torch.nn import init
import math


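# Basic Conv-BN-ReLU building block used throughout the network.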
class ConvX(nn.Module):
    def __init__(self, in_planes, out_planes, kernel=3, stride=1):
        super(ConvX, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel, stride=stride, padding=kernel//2, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(self.bn(self.conv(x)))
        return out


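# STDC module with an additive shortcut: a chain of ConvX blocks with halving
# channel widths whose outputs are concatenated, then added to the (optionally
# downsampled) input.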
class AddBottleneck(nn.Module):
    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super(AddBottleneck, self).__init__()
        assert block_num > 1, "block number should be larger than 1."
        self.conv_list = nn.ModuleList()
        self.stride = stride
        if stride == 2:
            # A depthwise stride-2 conv downsamples the main path; the skip
            # path is downsampled and projected to out_planes for the add.
            self.avd_layer = nn.Sequential(
                nn.Conv2d(out_planes//2, out_planes//2, kernel_size=3, stride=2, padding=1, groups=out_planes//2, bias=False),
                nn.BatchNorm2d(out_planes//2),
            )
            self.skip = nn.Sequential(
                nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=2, padding=1, groups=in_planes, bias=False),
                nn.BatchNorm2d(in_planes),
                nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_planes),
            )
            stride = 1  # later blocks run at stride 1; downsampling is handled above

        # Channel widths halve at each step (1/2, 1/4, ...); the last block
        # repeats the previous width so the concatenated widths sum to out_planes.
        for idx in range(block_num):
            if idx == 0:
                self.conv_list.append(ConvX(in_planes, out_planes//2, kernel=1))
            elif idx == 1 and block_num == 2:
                self.conv_list.append(ConvX(out_planes//2, out_planes//2, stride=stride))
            elif idx == 1 and block_num > 2:
                self.conv_list.append(ConvX(out_planes//2, out_planes//4, stride=stride))
            elif idx < block_num - 1:
                self.conv_list.append(ConvX(out_planes//int(math.pow(2, idx)), out_planes//int(math.pow(2, idx+1))))
            else:
                self.conv_list.append(ConvX(out_planes//int(math.pow(2, idx)), out_planes//int(math.pow(2, idx))))

    def forward(self, x):
        out_list = []
        out = x

        for idx, conv in enumerate(self.conv_list):
            if idx == 0 and self.stride == 2:
                out = self.avd_layer(conv(out))
            else:
                out = conv(out)
            out_list.append(out)

        if self.stride == 2:
            x = self.skip(x)

        # Concatenate the per-block features, then add the shortcut.
        return torch.cat(out_list, dim=1) + x


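# STDC module with concatenation fusion: the first block's output is kept
# (average-pooled when stride is 2) and concatenated with the later blocks,
# so no projection shortcut is needed.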
class CatBottleneck(nn.Module):
    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super(CatBottleneck, self).__init__()
        assert block_num > 1, "block number should be larger than 1."
        self.conv_list = nn.ModuleList()
        self.stride = stride
        if stride == 2:
            # A depthwise stride-2 conv downsamples the main path; the first
            # block's output is average-pooled to match on the skip path.
            self.avd_layer = nn.Sequential(
                nn.Conv2d(out_planes//2, out_planes//2, kernel_size=3, stride=2, padding=1, groups=out_planes//2, bias=False),
                nn.BatchNorm2d(out_planes//2),
            )
            self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
            stride = 1  # later blocks run at stride 1; downsampling is handled above

        # Channel widths halve at each step (1/2, 1/4, ...); the last block
        # repeats the previous width so the concatenated widths sum to out_planes.
        for idx in range(block_num):
            if idx == 0:
                self.conv_list.append(ConvX(in_planes, out_planes//2, kernel=1))
            elif idx == 1 and block_num == 2:
                self.conv_list.append(ConvX(out_planes//2, out_planes//2, stride=stride))
            elif idx == 1 and block_num > 2:
                self.conv_list.append(ConvX(out_planes//2, out_planes//4, stride=stride))
            elif idx < block_num - 1:
                self.conv_list.append(ConvX(out_planes//int(math.pow(2, idx)), out_planes//int(math.pow(2, idx+1))))
            else:
                self.conv_list.append(ConvX(out_planes//int(math.pow(2, idx)), out_planes//int(math.pow(2, idx))))

    def forward(self, x):
        out_list = []
        out1 = self.conv_list[0](x)

        # out1 feeds the chain and is also kept (pooled when stride is 2)
        # as the first chunk of the concatenated output.
        for idx, conv in enumerate(self.conv_list[1:]):
            if idx == 0:
                if self.stride == 2:
                    out = conv(self.avd_layer(out1))
                else:
                    out = conv(out1)
            else:
                out = conv(out)
            out_list.append(out)

        if self.stride == 2:
            out1 = self.skip(out1)
        out_list.insert(0, out1)

        out = torch.cat(out_list, dim=1)
        return out


# STDC2 network (STDCNet1446)
class STDCNet1446(nn.Module):
    def __init__(self, base=64, layers=[4,5,3], block_num=4, type="cat", num_classes=1000, dropout=0.20, pretrain_model='', use_conv_last=False):
        super(STDCNet1446, self).__init__()
        if type == "cat":
            block = CatBottleneck
        elif type == "add":
            block = AddBottleneck
        else:
            raise ValueError("type should be 'cat' or 'add', got {}".format(type))
        self.use_conv_last = use_conv_last
        self.features = self._make_layers(base, layers, block_num, block)
        self.conv_last = ConvX(base*16, max(1024, base*16), 1, 1)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(max(1024, base*16), max(1024, base*16), bias=False)
        self.bn = nn.BatchNorm1d(max(1024, base*16))
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout)
        self.linear = nn.Linear(max(1024, base*16), num_classes, bias=False)

        # Sub-stages grouped by output stride: 1/2, 1/4, 1/8, 1/16, 1/32.
        self.x2 = nn.Sequential(self.features[:1])
        self.x4 = nn.Sequential(self.features[1:2])
        self.x8 = nn.Sequential(self.features[2:6])
        self.x16 = nn.Sequential(self.features[6:11])
        self.x32 = nn.Sequential(self.features[11:])

        if pretrain_model:
            print('using pretrained model {}'.format(pretrain_model))
            self.init_weight(pretrain_model)
        else:
            self.init_params()

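    # Load weights from a pretrained checkpoint, skipping checkpoint keys
    # that do not exist in this model.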
    def init_weight(self, pretrain_model):
        state_dict = torch.load(pretrain_model, map_location='cpu')["state_dict"]
        self_state_dict = self.state_dict()
        for k, v in state_dict.items():
            if k in self_state_dict:
                self_state_dict[k] = v
        self.load_state_dict(self_state_dict)

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

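    # Stem of two stride-2 ConvX layers followed by the STDC stages; the
    # first block of each stage downsamples with stride 2.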
    def _make_layers(self, base, layers, block_num, block):
        features = []
        features += [ConvX(3, base//2, 3, 2)]
        features += [ConvX(base//2, base, 3, 2)]

        for i, layer in enumerate(layers):
            for j in range(layer):
                if i == 0 and j == 0:
                    features.append(block(base, base*4, block_num, 2))
                elif j == 0:
                    features.append(block(base*int(math.pow(2,i+1)), base*int(math.pow(2,i+2)), block_num, 2))
                else:
                    features.append(block(base*int(math.pow(2,i+2)), base*int(math.pow(2,i+2)), block_num, 1))

        return nn.Sequential(*features)

    def forward(self, x):
        # Return multi-scale features at output strides 2, 4, 8, 16, and 32.
        feat2 = self.x2(x)
        feat4 = self.x4(feat2)
        feat8 = self.x8(feat4)
        feat16 = self.x16(feat8)
        feat32 = self.x32(feat16)
        if self.use_conv_last:
            feat32 = self.conv_last(feat32)

        return feat2, feat4, feat8, feat16, feat32

    def forward_impl(self, x):
        # Classification path: backbone, squared activations, global pooling,
        # and the fully connected head.
        out = self.features(x)
        out = self.conv_last(out).pow(2)
        out = self.gap(out).flatten(1)
        out = self.fc(out)
        # out = self.bn(out)
        out = self.relu(out)
        # out = self.relu(self.bn(self.fc(out)))
        out = self.dropout(out)
        out = self.linear(out)
        return out


# STDC1 network (STDCNet813); mirrors STDCNet1446 with fewer blocks per stage.
class STDCNet813(nn.Module):
    def __init__(self, base=64, layers=[2,2,2], block_num=4, type="cat", num_classes=1000, dropout=0.20, pretrain_model='', use_conv_last=False):
        super(STDCNet813, self).__init__()
        if type == "cat":
            block = CatBottleneck
        elif type == "add":
            block = AddBottleneck
        else:
            raise ValueError("type should be 'cat' or 'add', got {}".format(type))
        self.use_conv_last = use_conv_last
        self.features = self._make_layers(base, layers, block_num, block)
        self.conv_last = ConvX(base*16, max(1024, base*16), 1, 1)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(max(1024, base*16), max(1024, base*16), bias=False)
        self.bn = nn.BatchNorm1d(max(1024, base*16))
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout)
        self.linear = nn.Linear(max(1024, base*16), num_classes, bias=False)

        # Sub-stages grouped by output stride: 1/2, 1/4, 1/8, 1/16, 1/32.
        self.x2 = nn.Sequential(self.features[:1])
        self.x4 = nn.Sequential(self.features[1:2])
        self.x8 = nn.Sequential(self.features[2:4])
        self.x16 = nn.Sequential(self.features[4:6])
        self.x32 = nn.Sequential(self.features[6:])

        if pretrain_model:
            print('using pretrained model {}'.format(pretrain_model))
            self.init_weight(pretrain_model)
        else:
            self.init_params()

    def init_weight(self, pretrain_model):
        state_dict = torch.load(pretrain_model, map_location='cpu')["state_dict"]
        self_state_dict = self.state_dict()
        for k, v in state_dict.items():
            if k in self_state_dict:
                self_state_dict[k] = v
        self.load_state_dict(self_state_dict)

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def _make_layers(self, base, layers, block_num, block):
        features = []
        features += [ConvX(3, base//2, 3, 2)]
        features += [ConvX(base//2, base, 3, 2)]

        for i, layer in enumerate(layers):
            for j in range(layer):
                if i == 0 and j == 0:
                    features.append(block(base, base*4, block_num, 2))
                elif j == 0:
                    features.append(block(base*int(math.pow(2,i+1)), base*int(math.pow(2,i+2)), block_num, 2))
                else:
                    features.append(block(base*int(math.pow(2,i+2)), base*int(math.pow(2,i+2)), block_num, 1))

        return nn.Sequential(*features)

    def forward(self, x):
        # Return multi-scale features at output strides 2, 4, 8, 16, and 32.
        feat2 = self.x2(x)
        feat4 = self.x4(feat2)
        feat8 = self.x8(feat4)
        feat16 = self.x16(feat8)
        feat32 = self.x32(feat16)
        if self.use_conv_last:
            feat32 = self.conv_last(feat32)

        return feat2, feat4, feat8, feat16, feat32

    def forward_impl(self, x):
        # Classification path: backbone, squared activations, global pooling,
        # and the fully connected head.
        out = self.features(x)
        out = self.conv_last(out).pow(2)
        out = self.gap(out).flatten(1)
        out = self.fc(out)
        # out = self.bn(out)
        out = self.relu(out)
        # out = self.relu(self.bn(self.fc(out)))
        out = self.dropout(out)
        out = self.linear(out)
        return out


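# Quick smoke test: run a dummy batch through STDC1 and print each feature size.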
if __name__ == "__main__":
    model = STDCNet813(num_classes=1000, dropout=0.00, block_num=4)
    model.eval()
    x = torch.randn(1, 3, 224, 224)
    feats = model(x)
    torch.save(model.state_dict(), 'cat.pth')
    for feat in feats:
        print(feat.size())