import math

import torch
import torch.nn as nn
from torch.nn import init


class ConvX(nn.Module):
    """Conv-BN-ReLU block."""
    def __init__(self, in_planes, out_planes, kernel=3, stride=1):
        super(ConvX, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel, stride=stride,
                              padding=kernel // 2, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(self.bn(self.conv(x)))
        return out


class AddBottleneck(nn.Module):
    """STDC module, "add" variant: the concatenated multi-scale features are
    summed with a (possibly downsampled) skip branch."""
    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super(AddBottleneck, self).__init__()
        assert block_num > 1, "block number should be larger than 1."
        self.conv_list = nn.ModuleList()
        self.stride = stride
        if stride == 2:
            # Depthwise stride-2 conv downsamples the main branch after the first 1x1 conv.
            self.avd_layer = nn.Sequential(
                nn.Conv2d(out_planes // 2, out_planes // 2, kernel_size=3, stride=2,
                          padding=1, groups=out_planes // 2, bias=False),
                nn.BatchNorm2d(out_planes // 2),
            )
            # Skip branch: depthwise stride-2 conv, then 1x1 conv to match out_planes.
            self.skip = nn.Sequential(
                nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=2, padding=1,
                          groups=in_planes, bias=False),
                nn.BatchNorm2d(in_planes),
                nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_planes),
            )
            stride = 1

        # Channel widths halve at each step (out/2, out/4, ...); the last two
        # blocks share the same width so the concatenation sums to out_planes.
        for idx in range(block_num):
            if idx == 0:
                self.conv_list.append(ConvX(in_planes, out_planes // 2, kernel=1))
            elif idx == 1 and block_num == 2:
                self.conv_list.append(ConvX(out_planes // 2, out_planes // 2, stride=stride))
            elif idx == 1 and block_num > 2:
                self.conv_list.append(ConvX(out_planes // 2, out_planes // 4, stride=stride))
            elif idx < block_num - 1:
                self.conv_list.append(ConvX(out_planes // int(math.pow(2, idx)),
                                            out_planes // int(math.pow(2, idx + 1))))
            else:
                self.conv_list.append(ConvX(out_planes // int(math.pow(2, idx)),
                                            out_planes // int(math.pow(2, idx))))

    def forward(self, x):
        out_list = []
        out = x
        for idx, conv in enumerate(self.conv_list):
            if idx == 0 and self.stride == 2:
                out = self.avd_layer(conv(out))
            else:
                out = conv(out)
            out_list.append(out)
        if self.stride == 2:
            x = self.skip(x)
        return torch.cat(out_list, dim=1) + x


class CatBottleneck(nn.Module):
    """STDC module, "cat" variant: multi-scale features are concatenated, with
    the first branch (avg-pooled when stride is 2) prepended instead of added."""
    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super(CatBottleneck, self).__init__()
        assert block_num > 1, "block number should be larger than 1."
        self.conv_list = nn.ModuleList()
        self.stride = stride
        if stride == 2:
            self.avd_layer = nn.Sequential(
                nn.Conv2d(out_planes // 2, out_planes // 2, kernel_size=3, stride=2,
                          padding=1, groups=out_planes // 2, bias=False),
                nn.BatchNorm2d(out_planes // 2),
            )
            self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
            stride = 1

        for idx in range(block_num):
            if idx == 0:
                self.conv_list.append(ConvX(in_planes, out_planes // 2, kernel=1))
            elif idx == 1 and block_num == 2:
                self.conv_list.append(ConvX(out_planes // 2, out_planes // 2, stride=stride))
            elif idx == 1 and block_num > 2:
                self.conv_list.append(ConvX(out_planes // 2, out_planes // 4, stride=stride))
            elif idx < block_num - 1:
                self.conv_list.append(ConvX(out_planes // int(math.pow(2, idx)),
                                            out_planes // int(math.pow(2, idx + 1))))
            else:
                self.conv_list.append(ConvX(out_planes // int(math.pow(2, idx)),
                                            out_planes // int(math.pow(2, idx))))

    def forward(self, x):
        out_list = []
        out1 = self.conv_list[0](x)
        for idx, conv in enumerate(self.conv_list[1:]):
            if idx == 0:
                if self.stride == 2:
                    out = conv(self.avd_layer(out1))
                else:
                    out = conv(out1)
            else:
                out = conv(out)
            out_list.append(out)
        if self.stride == 2:
            out1 = self.skip(out1)
        out_list.insert(0, out1)
        out = torch.cat(out_list, dim=1)
        return out
# STDC2Net
class STDCNet1446(nn.Module):
    def __init__(self, base=64, layers=[4, 5, 3], block_num=4, type="cat", num_classes=1000,
                 dropout=0.20, pretrain_model='', use_conv_last=False):
        super(STDCNet1446, self).__init__()
        # 'type' is kept for API compatibility even though it shadows the builtin.
        if type == "cat":
            block = CatBottleneck
        elif type == "add":
            block = AddBottleneck
        else:
            raise ValueError("type should be 'cat' or 'add', got {!r}".format(type))
        self.use_conv_last = use_conv_last
        self.features = self._make_layers(base, layers, block_num, block)
        self.conv_last = ConvX(base * 16, max(1024, base * 16), 1, 1)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(max(1024, base * 16), max(1024, base * 16), bias=False)
        self.bn = nn.BatchNorm1d(max(1024, base * 16))
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout)
        self.linear = nn.Linear(max(1024, base * 16), num_classes, bias=False)

        # Stage slices exposing features at strides 2, 4, 8, 16 and 32.
        self.x2 = nn.Sequential(self.features[:1])
        self.x4 = nn.Sequential(self.features[1:2])
        self.x8 = nn.Sequential(self.features[2:6])
        self.x16 = nn.Sequential(self.features[6:11])
        self.x32 = nn.Sequential(self.features[11:])

        if pretrain_model:
            print('use pretrain model {}'.format(pretrain_model))
            self.init_weight(pretrain_model)
        else:
            self.init_params()

    def init_weight(self, pretrain_model):
        state_dict = torch.load(pretrain_model, map_location='cpu')["state_dict"]
        self_state_dict = self.state_dict()
        self_state_dict.update(state_dict)
        self.load_state_dict(self_state_dict)

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def _make_layers(self, base, layers, block_num, block):
        features = []
        features += [ConvX(3, base // 2, 3, 2)]
        features += [ConvX(base // 2, base, 3, 2)]

        for i, layer in enumerate(layers):
            for j in range(layer):
                if i == 0 and j == 0:
                    features.append(block(base, base * 4, block_num, 2))
                elif j == 0:
                    features.append(block(base * int(math.pow(2, i + 1)),
                                          base * int(math.pow(2, i + 2)), block_num, 2))
                else:
                    features.append(block(base * int(math.pow(2, i + 2)),
                                          base * int(math.pow(2, i + 2)), block_num, 1))
        return nn.Sequential(*features)

    def forward(self, x):
        # Backbone path: returns the feature pyramid used for segmentation.
        feat2 = self.x2(x)
        feat4 = self.x4(feat2)
        feat8 = self.x8(feat4)
        feat16 = self.x16(feat8)
        feat32 = self.x32(feat16)
        if self.use_conv_last:
            feat32 = self.conv_last(feat32)
        return feat2, feat4, feat8, feat16, feat32

    def forward_impl(self, x):
        # Classification path used for ImageNet-style pre-training.
        out = self.features(x)
        out = self.conv_last(out).pow(2)
        out = self.gap(out).flatten(1)
        out = self.fc(out)
        # out = self.bn(out)
        out = self.relu(out)
        # out = self.relu(self.bn(self.fc(out)))
        out = self.dropout(out)
        out = self.linear(out)
        return out
# STDC1Net
class STDCNet813(nn.Module):
    def __init__(self, base=64, layers=[2, 2, 2], block_num=4, type="cat", num_classes=1000,
                 dropout=0.20, pretrain_model='', use_conv_last=False):
        super(STDCNet813, self).__init__()
        if type == "cat":
            block = CatBottleneck
        elif type == "add":
            block = AddBottleneck
        else:
            raise ValueError("type should be 'cat' or 'add', got {!r}".format(type))
        self.use_conv_last = use_conv_last
        self.features = self._make_layers(base, layers, block_num, block)
        self.conv_last = ConvX(base * 16, max(1024, base * 16), 1, 1)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(max(1024, base * 16), max(1024, base * 16), bias=False)
        self.bn = nn.BatchNorm1d(max(1024, base * 16))
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p=dropout)
        self.linear = nn.Linear(max(1024, base * 16), num_classes, bias=False)

        # Stage slices exposing features at strides 2, 4, 8, 16 and 32.
        self.x2 = nn.Sequential(self.features[:1])
        self.x4 = nn.Sequential(self.features[1:2])
        self.x8 = nn.Sequential(self.features[2:4])
        self.x16 = nn.Sequential(self.features[4:6])
        self.x32 = nn.Sequential(self.features[6:])

        if pretrain_model:
            print('use pretrain model {}'.format(pretrain_model))
            self.init_weight(pretrain_model)
        else:
            self.init_params()

    def init_weight(self, pretrain_model):
        state_dict = torch.load(pretrain_model, map_location='cpu')["state_dict"]
        self_state_dict = self.state_dict()
        self_state_dict.update(state_dict)
        self.load_state_dict(self_state_dict)

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def _make_layers(self, base, layers, block_num, block):
        features = []
        features += [ConvX(3, base // 2, 3, 2)]
        features += [ConvX(base // 2, base, 3, 2)]

        for i, layer in enumerate(layers):
            for j in range(layer):
                if i == 0 and j == 0:
                    features.append(block(base, base * 4, block_num, 2))
                elif j == 0:
                    features.append(block(base * int(math.pow(2, i + 1)),
                                          base * int(math.pow(2, i + 2)), block_num, 2))
                else:
                    features.append(block(base * int(math.pow(2, i + 2)),
                                          base * int(math.pow(2, i + 2)), block_num, 1))
        return nn.Sequential(*features)

    def forward(self, x):
        # Backbone path: returns the feature pyramid used for segmentation.
        feat2 = self.x2(x)
        feat4 = self.x4(feat2)
        feat8 = self.x8(feat4)
        feat16 = self.x16(feat8)
        feat32 = self.x32(feat16)
        if self.use_conv_last:
            feat32 = self.conv_last(feat32)
        return feat2, feat4, feat8, feat16, feat32

    def forward_impl(self, x):
        # Classification path used for ImageNet-style pre-training.
        out = self.features(x)
        out = self.conv_last(out).pow(2)
        out = self.gap(out).flatten(1)
        out = self.fc(out)
        # out = self.bn(out)
        out = self.relu(out)
        # out = self.relu(self.bn(self.fc(out)))
        out = self.dropout(out)
        out = self.linear(out)
        return out


if __name__ == "__main__":
    model = STDCNet813(num_classes=1000, dropout=0.00, block_num=4)
    model.eval()
    x = torch.randn(1, 3, 224, 224)
    y = model(x)
    torch.save(model.state_dict(), 'cat.pth')
    # forward() returns a tuple of five feature maps, so print each size
    # (calling y.size() on the tuple would raise an AttributeError).
    print([f.size() for f in y])
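
# --- Usage sketch (added for illustration; not part of the original model code) ---
# A minimal, hedged example of how this backbone is typically driven: forward()
# yields the stride-2/4/8/16/32 feature pyramid that a segmentation decoder would
# consume, while forward_impl() runs the classification head. The function name
# `_demo` and the input size are illustrative assumptions, not part of the API.
def _demo():
    backbone = STDCNet813(num_classes=1000, dropout=0.0, block_num=4)
    backbone.eval()
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        feat2, feat4, feat8, feat16, feat32 = backbone(x)
        # With base=64 the channel widths are 32/64/256/512/1024.
        for f in (feat2, feat4, feat8, feat16, feat32):
            print(tuple(f.shape))
        logits = backbone.forward_impl(x)
    print(tuple(logits.shape))  # expected: (1, 1000)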