@@ -0,0 +1,298 @@ | |||
"""Bilateral Segmentation Network""" | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
import numpy as np | |||
from core.models.base_models.resnet import resnet18, resnet50
from core.nn import _ConvBNReLU | |||
__all__ = ['BiSeNet', 'BiSeNet_MultiOutput', 'get_bisenet', 'get_bisenet_resnet18_citys']
class BiSeNet(nn.Module): | |||
def __init__(self, nclass, backbone='resnet18', aux=False, jpu=False, pretrained_base=True, **kwargs): | |||
super(BiSeNet, self).__init__() | |||
self.aux = aux | |||
self.spatial_path = SpatialPath(3, 128, **kwargs) | |||
self.context_path = ContextPath(backbone, pretrained_base, **kwargs) | |||
self.ffm = FeatureFusion(256, 256, 4, **kwargs) | |||
self.head = _BiSeHead(256, 64, nclass, **kwargs) | |||
if aux: | |||
self.auxlayer1 = _BiSeHead(128, 256, nclass, **kwargs) | |||
self.auxlayer2 = _BiSeHead(128, 256, nclass, **kwargs) | |||
self.__setattr__('exclusive', | |||
['spatial_path', 'context_path', 'ffm', 'head', 'auxlayer1', 'auxlayer2'] if aux else [ | |||
'spatial_path', 'context_path', 'ffm', 'head']) | |||
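        # 'exclusive' names the decoder modules trained from scratch, so a training
        # script can give them a different learning rate than the pretrained backbone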
    def forward(self, x, outsize=None, test_flag=False):
size = x.size()[2:] | |||
spatial_out = self.spatial_path(x) | |||
context_out = self.context_path(x) | |||
fusion_out = self.ffm(spatial_out, context_out[-1]) | |||
outputs = [] | |||
x = self.head(fusion_out) | |||
x = F.interpolate(x, size, mode='bilinear', align_corners=True) | |||
        if outsize:
            # optional caller-specified output size (e.g. the original image size)
            x = F.interpolate(x, outsize, mode='bilinear', align_corners=True)
outputs.append(x) | |||
if self.aux: | |||
auxout1 = self.auxlayer1(context_out[0]) | |||
auxout1 = F.interpolate(auxout1, size, mode='bilinear', align_corners=True) | |||
outputs.append(auxout1) | |||
auxout2 = self.auxlayer2(context_out[1]) | |||
auxout2 = F.interpolate(auxout2, size, mode='bilinear', align_corners=True) | |||
outputs.append(auxout2) | |||
        if test_flag:
            outputs = [torch.argmax(outputx, dim=1) for outputx in outputs]
        return outputs[0]
class BiSeNet_MultiOutput(nn.Module): | |||
def __init__(self, nclass, backbone='resnet18', aux=False, jpu=False, pretrained_base=True, **kwargs): | |||
super(BiSeNet_MultiOutput, self).__init__() | |||
self.aux = aux | |||
self.spatial_path = SpatialPath(3, 128, **kwargs) | |||
self.context_path = ContextPath(backbone, pretrained_base, **kwargs) | |||
self.ffm = FeatureFusion(256, 256, 4, **kwargs) | |||
        assert isinstance(nclass, list)
        self.outCnt = len(nclass)
        for ii, nclassii in enumerate(nclass):
            setattr(self, 'head%d' % ii, _BiSeHead(256, 64, nclassii, **kwargs))
        if aux:
            # the aux heads take the two 128-channel refined context features; they
            # supervise the first task's classes here (nclass is a list in this
            # variant, which would have crashed nn.Conv2d as originally written)
            self.auxlayer1 = _BiSeHead(128, 256, nclass[0], **kwargs)
            self.auxlayer2 = _BiSeHead(128, 256, nclass[0], **kwargs)
        heads = ['head%d' % ii for ii in range(self.outCnt)]
        self.__setattr__('exclusive',
                         ['spatial_path', 'context_path', 'ffm'] + heads +
                         (['auxlayer1', 'auxlayer2'] if aux else []))
    def forward(self, x, outsize=None, test_flag=False, smooth_kernel=0):
size = x.size()[2:] | |||
spatial_out = self.spatial_path(x) | |||
context_out = self.context_path(x) | |||
fusion_out = self.ffm(spatial_out, context_out[-1]) | |||
outputs = [] | |||
        for ii in range(self.outCnt):
            x = getattr(self, 'head%d' % ii)(fusion_out)
            x = F.interpolate(x, size, mode='bilinear', align_corners=True)
            if outsize:
                # mirror BiSeNet: optional caller-specified output size
                # (outsize was accepted but silently ignored in the original)
                x = F.interpolate(x, outsize, mode='bilinear', align_corners=True)
            outputs.append(x)
if self.aux: | |||
auxout1 = self.auxlayer1(context_out[0]) | |||
auxout1 = F.interpolate(auxout1, size, mode='bilinear', align_corners=True) | |||
outputs.append(auxout1) | |||
auxout2 = self.auxlayer2(context_out[1]) | |||
auxout2 = F.interpolate(auxout2, size, mode='bilinear', align_corners=True) | |||
outputs.append(auxout2) | |||
        if test_flag:
            outputs = [torch.argmax(outputx, dim=1) for outputx in outputs]
            if smooth_kernel > 0:
                # box filter over the predicted label maps (np.ones, despite the
                # "gaussian" name in the original): sums the labels inside each
                # smooth_kernel x smooth_kernel window
                box_kernel = torch.from_numpy(
                    np.ones((1, 1, smooth_kernel, smooth_kernel))).to(x.device)
                pad = (smooth_kernel - 1) // 2
                outputs = [out.unsqueeze(1).double() for out in outputs]
                outputs = [torch.conv2d(out, box_kernel, padding=pad) for out in outputs]
                outputs = [out.squeeze(1).long() for out in outputs]
        return outputs
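# Usage sketch for the multi-head variant (class counts are illustrative, not
# taken from a training config):
#   model = BiSeNet_MultiOutput(nclass=[2, 4])
#   outs = model(torch.rand(1, 3, 360, 640))  # list with one logits tensor per head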
class _BiSeHead(nn.Module): | |||
def __init__(self, in_channels, inter_channels, nclass, norm_layer=nn.BatchNorm2d, **kwargs): | |||
super(_BiSeHead, self).__init__() | |||
self.block = nn.Sequential( | |||
_ConvBNReLU(in_channels, inter_channels, 3, 1, 1, norm_layer=norm_layer), | |||
nn.Dropout(0.1), | |||
nn.Conv2d(inter_channels, nclass, 1) | |||
) | |||
def forward(self, x): | |||
x = self.block(x) | |||
return x | |||
class SpatialPath(nn.Module): | |||
"""Spatial path""" | |||
def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d, **kwargs): | |||
super(SpatialPath, self).__init__() | |||
inter_channels = 64 | |||
self.conv7x7 = _ConvBNReLU(in_channels, inter_channels, 7, 2, 3, norm_layer=norm_layer) | |||
self.conv3x3_1 = _ConvBNReLU(inter_channels, inter_channels, 3, 2, 1, norm_layer=norm_layer) | |||
self.conv3x3_2 = _ConvBNReLU(inter_channels, inter_channels, 3, 2, 1, norm_layer=norm_layer) | |||
self.conv1x1 = _ConvBNReLU(inter_channels, out_channels, 1, 1, 0, norm_layer=norm_layer) | |||
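        # three stride-2 convs in sequence: the output is 1/8 of the input
        # resolution, out_channels wide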
def forward(self, x): | |||
x = self.conv7x7(x) | |||
x = self.conv3x3_1(x) | |||
x = self.conv3x3_2(x) | |||
x = self.conv1x1(x) | |||
return x | |||
class _GlobalAvgPooling(nn.Module): | |||
def __init__(self, in_channels, out_channels, norm_layer, **kwargs): | |||
super(_GlobalAvgPooling, self).__init__() | |||
self.gap = nn.Sequential( | |||
nn.AdaptiveAvgPool2d(1), | |||
nn.Conv2d(in_channels, out_channels, 1, bias=False), | |||
norm_layer(out_channels), | |||
nn.ReLU(True) | |||
) | |||
def forward(self, x): | |||
size = x.size()[2:] | |||
pool = self.gap(x) | |||
out = F.interpolate(pool, size, mode='bilinear', align_corners=True) | |||
return out | |||
class AttentionRefinmentModule(nn.Module): | |||
def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d, **kwargs): | |||
super(AttentionRefinmentModule, self).__init__() | |||
self.conv3x3 = _ConvBNReLU(in_channels, out_channels, 3, 1, 1, norm_layer=norm_layer) | |||
self.channel_attention = nn.Sequential( | |||
nn.AdaptiveAvgPool2d(1), | |||
_ConvBNReLU(out_channels, out_channels, 1, 1, 0, norm_layer=norm_layer), | |||
nn.Sigmoid() | |||
) | |||
def forward(self, x): | |||
x = self.conv3x3(x) | |||
attention = self.channel_attention(x) | |||
x = x * attention | |||
return x | |||
class ContextPath(nn.Module): | |||
def __init__(self, backbone='resnet18', pretrained_base=True, norm_layer=nn.BatchNorm2d, **kwargs): | |||
super(ContextPath, self).__init__() | |||
if backbone == 'resnet18': | |||
pretrained = resnet18(pretrained=pretrained_base, **kwargs) | |||
        elif backbone == 'resnet50':
pretrained = resnet50(pretrained=pretrained_base, **kwargs) | |||
else: | |||
raise RuntimeError('unknown backbone: {}'.format(backbone)) | |||
self.conv1 = pretrained.conv1 | |||
self.bn1 = pretrained.bn1 | |||
self.relu = pretrained.relu | |||
self.maxpool = pretrained.maxpool | |||
self.layer1 = pretrained.layer1 | |||
self.layer2 = pretrained.layer2 | |||
self.layer3 = pretrained.layer3 | |||
self.layer4 = pretrained.layer4 | |||
inter_channels = 128 | |||
self.global_context = _GlobalAvgPooling(512, inter_channels, norm_layer) | |||
self.arms = nn.ModuleList( | |||
[AttentionRefinmentModule(512, inter_channels, norm_layer, **kwargs), | |||
AttentionRefinmentModule(256, inter_channels, norm_layer, **kwargs)] | |||
) | |||
self.refines = nn.ModuleList( | |||
[_ConvBNReLU(inter_channels, inter_channels, 3, 1, 1, norm_layer=norm_layer), | |||
_ConvBNReLU(inter_channels, inter_channels, 3, 1, 1, norm_layer=norm_layer)] | |||
) | |||
def forward(self, x): | |||
x = self.conv1(x) | |||
x = self.bn1(x) | |||
x = self.relu(x) | |||
x = self.maxpool(x) | |||
x = self.layer1(x) | |||
context_blocks = [] | |||
context_blocks.append(x) | |||
x = self.layer2(x) | |||
context_blocks.append(x) | |||
c3 = self.layer3(x) | |||
context_blocks.append(c3) | |||
c4 = self.layer4(c3) | |||
context_blocks.append(c4) | |||
context_blocks.reverse() | |||
global_context = self.global_context(c4) | |||
last_feature = global_context | |||
context_outputs = [] | |||
for i, (feature, arm, refine) in enumerate(zip(context_blocks[:2], self.arms, self.refines)): | |||
feature = arm(feature) | |||
feature += last_feature | |||
last_feature = F.interpolate(feature, size=context_blocks[i + 1].size()[2:], | |||
mode='bilinear', align_corners=True) | |||
last_feature = refine(last_feature) | |||
context_outputs.append(last_feature) | |||
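        # context_outputs[0]: refined 1/16-scale feature (128 ch, fed to aux head 1);
        # context_outputs[1]: refined 1/8-scale feature (128 ch, fused with the spatial path)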
return context_outputs | |||
class FeatureFusion(nn.Module): | |||
def __init__(self, in_channels, out_channels, reduction=1, norm_layer=nn.BatchNorm2d, **kwargs): | |||
super(FeatureFusion, self).__init__() | |||
self.conv1x1 = _ConvBNReLU(in_channels, out_channels, 1, 1, 0, norm_layer=norm_layer, **kwargs) | |||
self.channel_attention = nn.Sequential( | |||
nn.AdaptiveAvgPool2d(1), | |||
_ConvBNReLU(out_channels, out_channels // reduction, 1, 1, 0, norm_layer=norm_layer), | |||
_ConvBNReLU(out_channels // reduction, out_channels, 1, 1, 0, norm_layer=norm_layer), | |||
nn.Sigmoid() | |||
) | |||
def forward(self, x1, x2): | |||
fusion = torch.cat([x1, x2], dim=1) | |||
out = self.conv1x1(fusion) | |||
attention = self.channel_attention(out) | |||
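        # residual channel reweighting: equivalent to out * (1 + attention)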
out = out + out * attention | |||
return out | |||
def get_bisenet(dataset='citys', backbone='resnet18', pretrained=False, root='~/.torch/models', | |||
pretrained_base=True, **kwargs): | |||
acronyms = { | |||
'pascal_voc': 'pascal_voc', | |||
'pascal_aug': 'pascal_aug', | |||
'ade20k': 'ade', | |||
'coco': 'coco', | |||
'citys': 'citys', | |||
} | |||
from ..data.dataloader import datasets | |||
model = BiSeNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) | |||
    if pretrained:
        from .model_store import get_model_file
        # map to the local device when a local_rank is given; default mapping otherwise
        device = torch.device(kwargs['local_rank']) if 'local_rank' in kwargs else None
        model.load_state_dict(torch.load(get_model_file('bisenet_%s_%s' % (backbone, acronyms[dataset]), root=root),
                                         map_location=device))
return model | |||
def get_bisenet_resnet18_citys(**kwargs): | |||
return get_bisenet('citys', 'resnet18', **kwargs) | |||
if __name__ == '__main__':
    x = torch.rand(2, 3, 224, 224)
    model = BiSeNet(4, pretrained_base=True)
    out = model(x)
    print(out.shape)

    # from torchsummary import summary
    # summary(model, (3, 224, 224))  # prints each layer's output shape and parameter count
    # FLOPs/params estimate; recent thop versions expect `inputs` (a tuple of
    # example tensors) rather than the old `input_size` keyword
    from thop import profile
    flops, params = profile(model, inputs=(torch.randn(1, 3, 512, 512),))
    print('flops:{:.3f}G\nparams:{:.3f}M'.format(flops / 1e9, params / 1e6))
@@ -0,0 +1,3 @@ | |||
name,r,g,b | |||
0,0,0,0 | |||
1,255,255,255 |
@@ -0,0 +1,337 @@ | |||
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

from models_725.stdcnet import STDCNet1446, STDCNet813
# from models_725.bn import InPlaceABNSync as BatchNorm2d  # optional sync-BN drop-in
class ConvBNReLU(nn.Module): | |||
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): | |||
super(ConvBNReLU, self).__init__() | |||
self.conv = nn.Conv2d(in_chan, | |||
out_chan, | |||
kernel_size = ks, | |||
stride = stride, | |||
padding = padding, | |||
bias = False) | |||
        self.bn = nn.BatchNorm2d(out_chan)
self.relu = nn.ReLU() | |||
self.init_weight() | |||
def forward(self, x): | |||
x = self.conv(x) | |||
x = self.bn(x) | |||
x = self.relu(x) | |||
return x | |||
def init_weight(self): | |||
for ly in self.children(): | |||
if isinstance(ly, nn.Conv2d): | |||
nn.init.kaiming_normal_(ly.weight, a=1) | |||
                if ly.bias is not None: nn.init.constant_(ly.bias, 0)
class BiSeNetOutput(nn.Module): | |||
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): | |||
super(BiSeNetOutput, self).__init__() | |||
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) | |||
self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) | |||
self.init_weight() | |||
def forward(self, x): | |||
x = self.conv(x) | |||
x = self.conv_out(x) | |||
return x | |||
def init_weight(self): | |||
for ly in self.children(): | |||
if isinstance(ly, nn.Conv2d): | |||
nn.init.kaiming_normal_(ly.weight, a=1) | |||
                if ly.bias is not None: nn.init.constant_(ly.bias, 0)
def get_params(self): | |||
wd_params, nowd_params = [], [] | |||
for name, module in self.named_modules(): | |||
if isinstance(module, (nn.Linear, nn.Conv2d)): | |||
wd_params.append(module.weight) | |||
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters()) | |||
return wd_params, nowd_params | |||
class AttentionRefinementModule(nn.Module): | |||
def __init__(self, in_chan, out_chan, *args, **kwargs): | |||
super(AttentionRefinementModule, self).__init__() | |||
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) | |||
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) | |||
        self.bn_atten = nn.BatchNorm2d(out_chan)
self.sigmoid_atten = nn.Sigmoid() | |||
self.init_weight() | |||
def forward(self, x): | |||
feat = self.conv(x) | |||
atten = F.avg_pool2d(feat, feat.size()[2:]) | |||
atten = self.conv_atten(atten) | |||
atten = self.bn_atten(atten) | |||
atten = self.sigmoid_atten(atten) | |||
out = torch.mul(feat, atten) | |||
return out | |||
def init_weight(self): | |||
for ly in self.children(): | |||
if isinstance(ly, nn.Conv2d): | |||
nn.init.kaiming_normal_(ly.weight, a=1) | |||
                if ly.bias is not None: nn.init.constant_(ly.bias, 0)
class ContextPath(nn.Module): | |||
def __init__(self, backbone='CatNetSmall', pretrain_model='', use_conv_last=False, *args, **kwargs): | |||
super(ContextPath, self).__init__() | |||
self.backbone_name = backbone | |||
        if backbone == 'STDCNet1446':
            self.backbone = STDCNet1446(pretrain_model=pretrain_model, use_conv_last=use_conv_last)
        elif backbone == 'STDCNet813':
            self.backbone = STDCNet813(pretrain_model=pretrain_model, use_conv_last=use_conv_last)
        else:
            raise ValueError('backbone is not in backbone lists: {}'.format(backbone))
        # both STDC variants end at base*16 = 1024 channels, and conv_last keeps
        # max(1024, base*16) = 1024, so inplanes is 1024 with or without conv_last
        inplanes = 1024
        self.arm16 = AttentionRefinementModule(512, 128)
        self.arm32 = AttentionRefinementModule(inplanes, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(inplanes, 128, ks=1, stride=1, padding=0)
        self.init_weight()
def forward(self, x): | |||
H0, W0 = x.size()[2:] | |||
feat2, feat4, feat8, feat16, feat32 = self.backbone(x) | |||
H8, W8 = feat8.size()[2:] | |||
H16, W16 = feat16.size()[2:] | |||
H32, W32 = feat32.size()[2:] | |||
avg = F.avg_pool2d(feat32, feat32.size()[2:]) | |||
avg = self.conv_avg(avg) | |||
avg_up = F.interpolate(avg, (H32, W32), mode='nearest') | |||
feat32_arm = self.arm32(feat32) | |||
feat32_sum = feat32_arm + avg_up | |||
feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') | |||
feat32_up = self.conv_head32(feat32_up) | |||
feat16_arm = self.arm16(feat16) | |||
feat16_sum = feat16_arm + feat32_up | |||
feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') | |||
feat16_up = self.conv_head16(feat16_up) | |||
        # backbone features at 1/2..1/16 plus the refined context features:
        # feat16_up at 1/8 scale and feat32_up at 1/16 scale
        return feat2, feat4, feat8, feat16, feat16_up, feat32_up
def init_weight(self): | |||
for ly in self.children(): | |||
if isinstance(ly, nn.Conv2d): | |||
nn.init.kaiming_normal_(ly.weight, a=1) | |||
                if ly.bias is not None: nn.init.constant_(ly.bias, 0)
def get_params(self): | |||
wd_params, nowd_params = [], [] | |||
for name, module in self.named_modules(): | |||
if isinstance(module, (nn.Linear, nn.Conv2d)): | |||
wd_params.append(module.weight) | |||
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters()) | |||
return wd_params, nowd_params | |||
class FeatureFusionModule(nn.Module): | |||
def __init__(self, in_chan, out_chan, *args, **kwargs): | |||
super(FeatureFusionModule, self).__init__() | |||
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) | |||
self.conv1 = nn.Conv2d(out_chan, | |||
out_chan//4, | |||
kernel_size = 1, | |||
stride = 1, | |||
padding = 0, | |||
bias = False) | |||
self.conv2 = nn.Conv2d(out_chan//4, | |||
out_chan, | |||
kernel_size = 1, | |||
stride = 1, | |||
padding = 0, | |||
bias = False) | |||
self.relu = nn.ReLU(inplace=True) | |||
self.sigmoid = nn.Sigmoid() | |||
self.init_weight() | |||
def forward(self, fsp, fcp): | |||
fcat = torch.cat([fsp, fcp], dim=1) | |||
feat = self.convblk(fcat) | |||
atten = F.avg_pool2d(feat, feat.size()[2:]) | |||
atten = self.conv1(atten) | |||
atten = self.relu(atten) | |||
atten = self.conv2(atten) | |||
atten = self.sigmoid(atten) | |||
feat_atten = torch.mul(feat, atten) | |||
feat_out = feat_atten + feat | |||
return feat_out | |||
def init_weight(self): | |||
for ly in self.children(): | |||
if isinstance(ly, nn.Conv2d): | |||
nn.init.kaiming_normal_(ly.weight, a=1) | |||
                if ly.bias is not None: nn.init.constant_(ly.bias, 0)
def get_params(self): | |||
wd_params, nowd_params = [], [] | |||
for name, module in self.named_modules(): | |||
if isinstance(module, (nn.Linear, nn.Conv2d)): | |||
wd_params.append(module.weight) | |||
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
nowd_params += list(module.parameters()) | |||
return wd_params, nowd_params | |||
class BiSeNet(nn.Module): | |||
def __init__(self, backbone, n_classes, pretrain_model='', use_boundary_2=False, use_boundary_4=False, | |||
use_boundary_8=False, use_boundary_16=False, use_conv_last=False): | |||
super(BiSeNet, self).__init__() | |||
self.use_boundary_2 = use_boundary_2 | |||
self.use_boundary_4 = use_boundary_4 | |||
self.use_boundary_8 = use_boundary_8 | |||
self.use_boundary_16 = use_boundary_16 | |||
self.cp = ContextPath(backbone, pretrain_model, use_conv_last=use_conv_last) | |||
        # both STDC variants expose the same per-stage widths
        if backbone in ('STDCNet1446', 'STDCNet813'):
            conv_out_inplanes = 128
            sp2_inplanes = 32
            sp4_inplanes = 64
            sp8_inplanes = 256
            sp16_inplanes = 512
            inplane = sp8_inplanes + conv_out_inplanes
        else:
            raise ValueError("backbone is not in backbone lists: {}".format(backbone))
self.ffm = FeatureFusionModule(inplane, 256) | |||
self.conv_out = BiSeNetOutput(256, 256, n_classes) | |||
self.conv_out16 = BiSeNetOutput(conv_out_inplanes, 64, n_classes) | |||
self.conv_out32 = BiSeNetOutput(conv_out_inplanes, 64, n_classes) | |||
self.conv_out_sp16 = BiSeNetOutput(sp16_inplanes, 64, 1) | |||
self.conv_out_sp8 = BiSeNetOutput(sp8_inplanes, 64, 1) | |||
self.conv_out_sp4 = BiSeNetOutput(sp4_inplanes, 64, 1) | |||
self.conv_out_sp2 = BiSeNetOutput(sp2_inplanes, 64, 1) | |||
self.init_weight() | |||
def forward(self, x): | |||
H, W = x.size()[2:] | |||
        feat_res2, feat_res4, feat_res8, feat_res16, feat_cp8, feat_cp16 = self.cp(x)

        # Boundary heads (enabled via use_boundary_*) were used for training
        # supervision and are bypassed in this inference-only variant:
        # feat_out_sp2 = self.conv_out_sp2(feat_res2)
        # feat_out_sp4 = self.conv_out_sp4(feat_res4)
        # feat_out_sp8 = self.conv_out_sp8(feat_res8)
        # feat_out_sp16 = self.conv_out_sp16(feat_res16)

        feat_fuse = self.ffm(feat_res8, feat_cp8)
        feat_out = self.conv_out(feat_fuse)
        feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)

        # Training additionally upsamples the auxiliary heads and returns them
        # together with the selected boundary maps, e.g.:
        # feat_out16 = self.conv_out16(feat_cp8)
        # feat_out32 = self.conv_out32(feat_cp16)
        # feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
        # feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
        # return feat_out, feat_out16, feat_out32, feat_out_sp4, feat_out_sp8
        return feat_out
def init_weight(self): | |||
for ly in self.children(): | |||
if isinstance(ly, nn.Conv2d): | |||
nn.init.kaiming_normal_(ly.weight, a=1) | |||
                if ly.bias is not None: nn.init.constant_(ly.bias, 0)
def get_params(self): | |||
wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] | |||
for name, child in self.named_children(): | |||
child_wd_params, child_nowd_params = child.get_params() | |||
if isinstance(child, (FeatureFusionModule, BiSeNetOutput)): | |||
lr_mul_wd_params += child_wd_params | |||
lr_mul_nowd_params += child_nowd_params | |||
else: | |||
wd_params += child_wd_params | |||
nowd_params += child_nowd_params | |||
return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params | |||
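# Illustrative optimizer wiring for the four parameter groups (hyperparameters
# are assumptions, not taken from this repo): BN/bias terms skip weight decay,
# and the decoder (FeatureFusionModule + output heads) runs at a multiplied LR.
#   wd, nowd, lr_mul_wd, lr_mul_nowd = net.get_params()
#   optimizer = torch.optim.SGD(
#       [{'params': wd, 'weight_decay': 5e-4},
#        {'params': nowd, 'weight_decay': 0},
#        {'params': lr_mul_wd, 'weight_decay': 5e-4, 'lr': 10 * base_lr},
#        {'params': lr_mul_nowd, 'weight_decay': 0, 'lr': 10 * base_lr}],
#       lr=base_lr, momentum=0.9)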
if __name__ == "__main__": | |||
net = BiSeNet('STDCNet813', 19) | |||
net.cuda() | |||
net.eval() | |||
in_ten = torch.randn(1, 3, 768, 1536).cuda() | |||
out, out16, out32 = net(in_ten) | |||
print(out.shape) | |||
# torch.save(net.state_dict(), 'STDCNet813.pth')### | |||
@@ -0,0 +1,69 @@ | |||
import torch
import torch.nn.functional as F
import cv2
import numpy as np
from models_725.model_stages import BiSeNet
class SegModel(object):
    def __init__(self, nclass=2, device='cuda:0',
                 respth='./model_maxmIOU75_1720_0.946_360640.pth',
                 multiOutput=False):
        self.device = device
        self.multiOutput = multiOutput
        self.model = BiSeNet(backbone='STDCNet813', n_classes=nclass,
                             use_boundary_2=False, use_boundary_4=False,
                             use_boundary_8=True, use_boundary_16=False,
                             use_conv_last=False)
        self.model.load_state_dict(torch.load(respth, map_location=self.device))
        self.model = self.model.to(self.device)
        self.model.eval()
        # ImageNet normalization constants, applied manually in preprocess_image
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)
    def eval(self, image):
        H, W, _ = image.shape
        img = self.preprocess_image(image).to(self.device)
        with torch.no_grad():
            logits = self.model(img)
            logits = F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=True)
            probs = torch.softmax(logits, dim=1)
            pred = torch.argmax(probs, dim=1)
        # label map plus the foreground (class-1) probability map
        return pred, probs[0][1]
    def preprocess_image(self, image):
        # resize to the training resolution (W=640, H=360), then normalize
        image = cv2.resize(image, (640, 360), interpolation=cv2.INTER_LINEAR)
image = image.astype(np.float32) | |||
image /= 255.0 | |||
image[:, :, 0] -= self.mean[0] | |||
image[:, :, 1] -= self.mean[1] | |||
image[:, :, 2] -= self.mean[2] | |||
image[:, :, 0] /= self.std[0] | |||
image[:, :, 1] /= self.std[1] | |||
image[:, :, 2] /= self.std[2] | |||
image = np.transpose(image, (2, 0, 1)) | |||
image = torch.from_numpy(image).float() | |||
image = image.unsqueeze(0) | |||
return image | |||
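# Minimal usage sketch (paths are assumptions; class 1 is foreground per the
# two-row name,r,g,b palette CSV, so the mask maps 0 -> black, 1 -> white):
#   segmodel = SegModel(nclass=2, respth='./model_maxmIOU75_1720_0.946_360640.pth')
#   image = cv2.cvtColor(cv2.imread('demo.jpg'), cv2.COLOR_BGR2RGB)
#   pred, fg_prob = segmodel.eval(image)
#   mask = (pred[0].cpu().numpy() * 255).astype(np.uint8)
#   cv2.imwrite('mask.png', mask)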
@@ -0,0 +1,302 @@ | |||
import torch | |||
import torch.nn as nn | |||
from torch.nn import init | |||
import math | |||
class ConvX(nn.Module): | |||
def __init__(self, in_planes, out_planes, kernel=3, stride=1): | |||
super(ConvX, self).__init__() | |||
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel, stride=stride, padding=kernel//2, bias=False) | |||
self.bn = nn.BatchNorm2d(out_planes) | |||
self.relu = nn.ReLU(inplace=True) | |||
def forward(self, x): | |||
out = self.relu(self.bn(self.conv(x))) | |||
return out | |||
class AddBottleneck(nn.Module): | |||
def __init__(self, in_planes, out_planes, block_num=3, stride=1): | |||
super(AddBottleneck, self).__init__() | |||
        assert block_num > 1, "block number should be larger than 1."
self.conv_list = nn.ModuleList() | |||
self.stride = stride | |||
if stride == 2: | |||
self.avd_layer = nn.Sequential( | |||
nn.Conv2d(out_planes//2, out_planes//2, kernel_size=3, stride=2, padding=1, groups=out_planes//2, bias=False), | |||
nn.BatchNorm2d(out_planes//2), | |||
) | |||
self.skip = nn.Sequential( | |||
nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=2, padding=1, groups=in_planes, bias=False), | |||
nn.BatchNorm2d(in_planes), | |||
nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False), | |||
nn.BatchNorm2d(out_planes), | |||
) | |||
stride = 1 | |||
for idx in range(block_num): | |||
if idx == 0: | |||
self.conv_list.append(ConvX(in_planes, out_planes//2, kernel=1)) | |||
elif idx == 1 and block_num == 2: | |||
self.conv_list.append(ConvX(out_planes//2, out_planes//2, stride=stride)) | |||
elif idx == 1 and block_num > 2: | |||
self.conv_list.append(ConvX(out_planes//2, out_planes//4, stride=stride)) | |||
elif idx < block_num - 1: | |||
self.conv_list.append(ConvX(out_planes//int(math.pow(2, idx)), out_planes//int(math.pow(2, idx+1)))) | |||
else: | |||
self.conv_list.append(ConvX(out_planes//int(math.pow(2, idx)), out_planes//int(math.pow(2, idx)))) | |||
def forward(self, x): | |||
out_list = [] | |||
out = x | |||
for idx, conv in enumerate(self.conv_list): | |||
if idx == 0 and self.stride == 2: | |||
out = self.avd_layer(conv(out)) | |||
else: | |||
out = conv(out) | |||
out_list.append(out) | |||
if self.stride == 2: | |||
x = self.skip(x) | |||
return torch.cat(out_list, dim=1) + x | |||
class CatBottleneck(nn.Module): | |||
def __init__(self, in_planes, out_planes, block_num=3, stride=1): | |||
super(CatBottleneck, self).__init__() | |||
        assert block_num > 1, "block number should be larger than 1."
self.conv_list = nn.ModuleList() | |||
self.stride = stride | |||
if stride == 2: | |||
self.avd_layer = nn.Sequential( | |||
nn.Conv2d(out_planes//2, out_planes//2, kernel_size=3, stride=2, padding=1, groups=out_planes//2, bias=False), | |||
nn.BatchNorm2d(out_planes//2), | |||
) | |||
self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) | |||
stride = 1 | |||
for idx in range(block_num): | |||
if idx == 0: | |||
self.conv_list.append(ConvX(in_planes, out_planes//2, kernel=1)) | |||
elif idx == 1 and block_num == 2: | |||
self.conv_list.append(ConvX(out_planes//2, out_planes//2, stride=stride)) | |||
elif idx == 1 and block_num > 2: | |||
self.conv_list.append(ConvX(out_planes//2, out_planes//4, stride=stride)) | |||
elif idx < block_num - 1: | |||
self.conv_list.append(ConvX(out_planes//int(math.pow(2, idx)), out_planes//int(math.pow(2, idx+1)))) | |||
else: | |||
self.conv_list.append(ConvX(out_planes//int(math.pow(2, idx)), out_planes//int(math.pow(2, idx)))) | |||
def forward(self, x): | |||
out_list = [] | |||
out1 = self.conv_list[0](x) | |||
for idx, conv in enumerate(self.conv_list[1:]): | |||
if idx == 0: | |||
if self.stride == 2: | |||
out = conv(self.avd_layer(out1)) | |||
else: | |||
out = conv(out1) | |||
else: | |||
out = conv(out) | |||
out_list.append(out) | |||
if self.stride == 2: | |||
out1 = self.skip(out1) | |||
out_list.insert(0, out1) | |||
out = torch.cat(out_list, dim=1) | |||
return out | |||
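# Channel bookkeeping for CatBottleneck with block_num=4: the concatenated
# branches carry out/2 + out/4 + out/8 + out/8 = out_planes channels, so the
# block preserves the requested width without an extra projection.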
#STDC2Net | |||
class STDCNet1446(nn.Module): | |||
def __init__(self, base=64, layers=[4,5,3], block_num=4, type="cat", num_classes=1000, dropout=0.20, pretrain_model='', use_conv_last=False): | |||
super(STDCNet1446, self).__init__() | |||
if type == "cat": | |||
block = CatBottleneck | |||
elif type == "add": | |||
block = AddBottleneck | |||
self.use_conv_last = use_conv_last | |||
self.features = self._make_layers(base, layers, block_num, block) | |||
self.conv_last = ConvX(base*16, max(1024, base*16), 1, 1) | |||
self.gap = nn.AdaptiveAvgPool2d(1) | |||
self.fc = nn.Linear(max(1024, base*16), max(1024, base*16), bias=False) | |||
self.bn = nn.BatchNorm1d(max(1024, base*16)) | |||
self.relu = nn.ReLU(inplace=True) | |||
self.dropout = nn.Dropout(p=dropout) | |||
self.linear = nn.Linear(max(1024, base*16), num_classes, bias=False) | |||
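        # stage slicing for layers=[4,5,3]: features[0] -> 1/2, features[1] -> 1/4,
        # then 4 blocks at 1/8, 5 blocks at 1/16, 3 blocks at 1/32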
self.x2 = nn.Sequential(self.features[:1]) | |||
self.x4 = nn.Sequential(self.features[1:2]) | |||
self.x8 = nn.Sequential(self.features[2:6]) | |||
self.x16 = nn.Sequential(self.features[6:11]) | |||
self.x32 = nn.Sequential(self.features[11:]) | |||
if pretrain_model: | |||
print('use pretrain model {}'.format(pretrain_model)) | |||
self.init_weight(pretrain_model) | |||
else: | |||
self.init_params() | |||
    def init_weight(self, pretrain_model):
        state_dict = torch.load(pretrain_model)["state_dict"]
        self_state_dict = self.state_dict()
        self_state_dict.update(state_dict)
        self.load_state_dict(self_state_dict)
def init_params(self): | |||
for m in self.modules(): | |||
if isinstance(m, nn.Conv2d): | |||
init.kaiming_normal_(m.weight, mode='fan_out') | |||
if m.bias is not None: | |||
init.constant_(m.bias, 0) | |||
elif isinstance(m, nn.BatchNorm2d): | |||
init.constant_(m.weight, 1) | |||
init.constant_(m.bias, 0) | |||
elif isinstance(m, nn.Linear): | |||
init.normal_(m.weight, std=0.001) | |||
if m.bias is not None: | |||
init.constant_(m.bias, 0) | |||
def _make_layers(self, base, layers, block_num, block): | |||
features = [] | |||
features += [ConvX(3, base//2, 3, 2)] | |||
features += [ConvX(base//2, base, 3, 2)] | |||
for i, layer in enumerate(layers): | |||
for j in range(layer): | |||
if i == 0 and j == 0: | |||
features.append(block(base, base*4, block_num, 2)) | |||
elif j == 0: | |||
features.append(block(base*int(math.pow(2,i+1)), base*int(math.pow(2,i+2)), block_num, 2)) | |||
else: | |||
features.append(block(base*int(math.pow(2,i+2)), base*int(math.pow(2,i+2)), block_num, 1)) | |||
return nn.Sequential(*features) | |||
def forward(self, x): | |||
feat2 = self.x2(x) | |||
feat4 = self.x4(feat2) | |||
feat8 = self.x8(feat4) | |||
feat16 = self.x16(feat8) | |||
feat32 = self.x32(feat16) | |||
if self.use_conv_last: | |||
feat32 = self.conv_last(feat32) | |||
return feat2, feat4, feat8, feat16, feat32 | |||
def forward_impl(self, x): | |||
out = self.features(x) | |||
out = self.conv_last(out).pow(2) | |||
out = self.gap(out).flatten(1) | |||
out = self.fc(out) | |||
        out = self.relu(out)
out = self.dropout(out) | |||
out = self.linear(out) | |||
return out | |||
# STDC1Net | |||
class STDCNet813(nn.Module): | |||
def __init__(self, base=64, layers=[2,2,2], block_num=4, type="cat", num_classes=1000, dropout=0.20, pretrain_model='', use_conv_last=False): | |||
super(STDCNet813, self).__init__() | |||
if type == "cat": | |||
block = CatBottleneck | |||
elif type == "add": | |||
block = AddBottleneck | |||
self.use_conv_last = use_conv_last | |||
self.features = self._make_layers(base, layers, block_num, block) | |||
self.conv_last = ConvX(base*16, max(1024, base*16), 1, 1) | |||
self.gap = nn.AdaptiveAvgPool2d(1) | |||
self.fc = nn.Linear(max(1024, base*16), max(1024, base*16), bias=False) | |||
self.bn = nn.BatchNorm1d(max(1024, base*16)) | |||
self.relu = nn.ReLU(inplace=True) | |||
self.dropout = nn.Dropout(p=dropout) | |||
self.linear = nn.Linear(max(1024, base*16), num_classes, bias=False) | |||
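        # stage slicing for layers=[2,2,2]: features[0] -> 1/2, features[1] -> 1/4,
        # then 2 blocks at 1/8, 2 blocks at 1/16, 2 blocks at 1/32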
self.x2 = nn.Sequential(self.features[:1]) | |||
self.x4 = nn.Sequential(self.features[1:2]) | |||
self.x8 = nn.Sequential(self.features[2:4]) | |||
self.x16 = nn.Sequential(self.features[4:6]) | |||
self.x32 = nn.Sequential(self.features[6:]) | |||
if pretrain_model: | |||
print('use pretrain model {}'.format(pretrain_model)) | |||
self.init_weight(pretrain_model) | |||
else: | |||
self.init_params() | |||
    def init_weight(self, pretrain_model):
        state_dict = torch.load(pretrain_model)["state_dict"]
        self_state_dict = self.state_dict()
        self_state_dict.update(state_dict)
        self.load_state_dict(self_state_dict)
def init_params(self): | |||
for m in self.modules(): | |||
if isinstance(m, nn.Conv2d): | |||
init.kaiming_normal_(m.weight, mode='fan_out') | |||
if m.bias is not None: | |||
init.constant_(m.bias, 0) | |||
elif isinstance(m, nn.BatchNorm2d): | |||
init.constant_(m.weight, 1) | |||
init.constant_(m.bias, 0) | |||
elif isinstance(m, nn.Linear): | |||
init.normal_(m.weight, std=0.001) | |||
if m.bias is not None: | |||
init.constant_(m.bias, 0) | |||
def _make_layers(self, base, layers, block_num, block): | |||
features = [] | |||
features += [ConvX(3, base//2, 3, 2)] | |||
features += [ConvX(base//2, base, 3, 2)] | |||
for i, layer in enumerate(layers): | |||
for j in range(layer): | |||
if i == 0 and j == 0: | |||
features.append(block(base, base*4, block_num, 2)) | |||
elif j == 0: | |||
features.append(block(base*int(math.pow(2,i+1)), base*int(math.pow(2,i+2)), block_num, 2)) | |||
else: | |||
features.append(block(base*int(math.pow(2,i+2)), base*int(math.pow(2,i+2)), block_num, 1)) | |||
return nn.Sequential(*features) | |||
def forward(self, x): | |||
feat2 = self.x2(x) | |||
feat4 = self.x4(feat2) | |||
feat8 = self.x8(feat4) | |||
feat16 = self.x16(feat8) | |||
feat32 = self.x32(feat16) | |||
if self.use_conv_last: | |||
feat32 = self.conv_last(feat32) | |||
return feat2, feat4, feat8, feat16, feat32 | |||
def forward_impl(self, x): | |||
out = self.features(x) | |||
out = self.conv_last(out).pow(2) | |||
out = self.gap(out).flatten(1) | |||
out = self.fc(out) | |||
        out = self.relu(out)
out = self.dropout(out) | |||
out = self.linear(out) | |||
return out | |||
if __name__ == "__main__": | |||
model = STDCNet813(num_classes=1000, dropout=0.00, block_num=4) | |||
model.eval() | |||
x = torch.randn(1,3,224,224) | |||
y = model(x) | |||
torch.save(model.state_dict(), 'cat.pth') | |||
print(y.size()) |
@@ -21,6 +21,7 @@ def predict_lunkuo(impth=None): | |||
if __name__ == '__main__': | |||
impth = 'images/examples' | |||
    outpth = 'images/results'
folders = os.listdir(impth) | |||
segmodel = SegModel() | |||
for i in range(len(folders)): | |||
@@ -29,5 +30,6 @@ if __name__ == '__main__': | |||
img = Image.open(imgpath).convert('RGB') | |||
img = np.array(img) | |||
time11 = time.time() | |||
        img = predict_lunkuo(impth=img)
        cv2.imwrite(os.path.join(outpth, folders[i]), img)
print('----all_process', (time.time() - time11) * 1000) |