543 lines
22 KiB
Python
543 lines
22 KiB
Python
#!/usr/bin/python
|
||
# -*- encoding: utf-8 -*-
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
import torchvision
|
||
import time
|
||
from stdcnet import STDCNet1446, STDCNet813
|
||
#from models_725.bn import InPlaceABNSync as BatchNorm2d
|
||
# BatchNorm2d = nn.BatchNorm2d
|
||
|
||
#modelSize=(360,640) ##(W,H)
|
||
#print('######Attention model input(H,W):',modelSize)
|
||
class ConvBNReLU(nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d and a (non-inplace) ReLU.

    Args:
        in_chan: number of input channels.
        out_chan: number of output channels.
        ks: convolution kernel size.
        stride: convolution stride.
        padding: convolution zero-padding.
        *args, **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
        super().__init__()
        # No conv bias: the BatchNorm2d that follows has its own affine shift.
        self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride,
                              padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU()
        self.init_weight()

    def forward(self, x):
        """Apply conv -> batch-norm -> ReLU to ``x``."""
        return self.relu(self.bn(self.conv(x)))

    def init_weight(self):
        """Kaiming-normal (a=1) init for each direct Conv2d child; zero any bias."""
        conv_children = (m for m in self.children() if isinstance(m, nn.Conv2d))
        for conv in conv_children:
            nn.init.kaiming_normal_(conv.weight, a=1)
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)


class BiSeNetOutput(nn.Module):
    """Prediction head: 3x3 ConvBNReLU to ``mid_chan``, then a bias-free 1x1 conv to ``n_classes``.

    Args:
        in_chan: channels of the incoming feature map.
        mid_chan: channels after the 3x3 ConvBNReLU.
        n_classes: channels of the output logits.
        *args, **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
        super().__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
        self.init_weight()

    def forward(self, x):
        """Return ``n_classes``-channel logits at the spatial size of ``x``."""
        hidden = self.conv(x)
        return self.conv_out(hidden)

    def init_weight(self):
        """Kaiming-normal (a=1) init for each direct Conv2d child; zero any bias."""
        for child in self.children():
            if not isinstance(child, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(child.weight, a=1)
            if child.bias is not None:
                nn.init.constant_(child.bias, 0)

    def get_params(self):
        """Split parameters into ``(weight_decay, no_weight_decay)`` lists.

        Every submodule is visited in registration order. Conv2d/Linear weights
        go to the first list. Their biases and all BatchNorm2d parameters go to
        the second.
        """
        decay, no_decay = [], []
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                decay.append(module.weight)
                if module.bias is not None:
                    no_decay.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                no_decay.extend(module.parameters())
        return decay, no_decay


class AttentionRefinementModule_static(nn.Module):
    """Attention refinement module whose average-pool kernel is fixed at construction.

    ``feat = ConvBNReLU(x)``. A channel gate is built by average-pooling
    ``feat`` with ``avg_pool2d_kernel_size``, then applying a 1x1 conv,
    BatchNorm and sigmoid. ``feat`` is multiplied by that gate. The pooled
    map is 1x1 (a global average) only when the kernel equals ``feat``'s
    spatial size.

    Args:
        in_chan: channels of ``x``.
        out_chan: channels of ``feat`` and of the output.
        avg_pool2d_kernel_size: kernel for ``F.avg_pool2d`` (stride defaults to the kernel).
        *args, **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, in_chan, out_chan, avg_pool2d_kernel_size, *args, **kwargs):
        super().__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()
        # Plain attribute (not a buffer), so it is not saved in the state_dict.
        self.avg_pool2d_kernel_size = avg_pool2d_kernel_size
        self.init_weight()

    def forward(self, x):
        """Return ``feat`` scaled channel-wise by the sigmoid attention gate."""
        feat = self.conv(x)
        pooled = F.avg_pool2d(feat, self.avg_pool2d_kernel_size)
        gate = self.sigmoid_atten(self.bn_atten(self.conv_atten(pooled)))
        return feat * gate

    def init_weight(self):
        """Kaiming-normal (a=1) init for each direct Conv2d child; zero any bias."""
        for child in filter(lambda m: isinstance(m, nn.Conv2d), self.children()):
            nn.init.kaiming_normal_(child.weight, a=1)
            if child.bias is not None:
                nn.init.constant_(child.bias, 0)

class AttentionRefinementModule(nn.Module):
    """Attention refinement module using a global average pool over the runtime feature size.

    ``feat = ConvBNReLU(x)``. ``feat`` is averaged over its full spatial
    extent to a 1x1 map, passed through a 1x1 conv, BatchNorm and sigmoid,
    and the result gates ``feat`` channel-wise.

    Args:
        in_chan: channels of ``x``.
        out_chan: channels of ``feat`` and of the output.
        *args, **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super().__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()
        self.init_weight()

    def forward(self, x):
        """Return ``feat`` scaled channel-wise by its global-context attention gate."""
        feat = self.conv(x)
        spatial = feat.size()[2:]
        # Kernel == full spatial size -> one averaged value per channel.
        context = F.avg_pool2d(feat, spatial)
        gate = self.sigmoid_atten(self.bn_atten(self.conv_atten(context)))
        return feat * gate

    def init_weight(self):
        """Kaiming-normal (a=1) init for each direct Conv2d child; zero any bias."""
        for child in self.children():
            if not isinstance(child, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(child.weight, a=1)
            if child.bias is not None:
                nn.init.constant_(child.bias, 0)


class ContextPath_static(nn.Module):
    """STDC context path whose average-pool kernels are fixed from ``modelSize``.

    Same topology as the runtime-size context path, but the global average
    pool and both attention-refinement modules use kernels computed once in
    ``__init__`` from ``modelSize``. The pooled maps are therefore 1x1 only
    when the input matches ``modelSize``.

    Args:
        backbone: 'STDCNet1446' or 'STDCNet813'. Anything else raises
            ValueError (note the default 'CatNetSmall' is not accepted).
        pretrain_model: forwarded to the backbone constructor.
        use_conv_last: forwarded to the backbone constructor.
        modelSize: (H, W) of the expected input. Index 0 feeds the first
            (height) dimension of the pool kernels.
        *args, **kwargs: accepted for call-site compatibility and ignored.

    Raises:
        ValueError: if ``backbone`` is not one of the supported names.
    """

    def __init__(self, backbone='CatNetSmall', pretrain_model='', use_conv_last=False, modelSize=(360, 640), *args, **kwargs):
        super(ContextPath_static, self).__init__()

        self.backbone_name = backbone
        self.modelSize = modelSize

        # int(x + 0.999) acts as ceil(x) for integer sizes. This presumably
        # matches the spatial size of the backbone's /16 and /32 features
        # (TODO confirm against stdcnet).
        self.avg_pool_kernel_size_32 = [int(modelSize[0] / 32 + 0.999), int(modelSize[1] / 32 + 0.999)]
        self.avg_pool_kernel_size_16 = [int(modelSize[0] / 16 + 0.999), int(modelSize[1] / 16 + 0.999)]

        if backbone == 'STDCNet1446':
            self.backbone = STDCNet1446(pretrain_model=pretrain_model, use_conv_last=use_conv_last)
        elif backbone == 'STDCNet813':
            self.backbone = STDCNet813(pretrain_model=pretrain_model, use_conv_last=use_conv_last)
        else:
            # Raise instead of print + exit(0): exiting with status 0 reported
            # success while terminating the caller's process.
            raise ValueError(
                f"unsupported backbone {backbone!r}; expected 'STDCNet1446' or 'STDCNet813'")

        # Head layout is identical for both backbones: feat16 is expected to
        # have 512 channels and feat32 1024, whether or not use_conv_last is set.
        inplanes = 1024
        # Both backbones use the fixed-kernel ARMs, so the whole path pools
        # with kernels derived from modelSize rather than the runtime size.
        self.arm16 = AttentionRefinementModule_static(512, 128, self.avg_pool_kernel_size_16)
        self.arm32 = AttentionRefinementModule_static(inplanes, 128, self.avg_pool_kernel_size_32)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(inplanes, 128, ks=1, stride=1, padding=0)

        self.init_weight()

    def forward(self, x):
        """Run the backbone and fuse its two deepest features.

        Returns:
            ``(feat2, feat4, feat8, feat16, feat16_up, feat32_up)``. The first
            four are backbone outputs passed through. ``feat16_up`` (128 ch) is
            resized to ``feat8``'s spatial size, and ``feat32_up`` (128 ch) to
            ``feat16``'s.
        """
        feat2, feat4, feat8, feat16, feat32 = self.backbone(x)
        H8, W8 = feat8.size()[2:]
        H16, W16 = feat16.size()[2:]
        H32, W32 = feat32.size()[2:]

        # Global context: pool feat32 with the fixed kernel, project to 128
        # channels, and broadcast back to feat32's spatial size.
        avg = F.avg_pool2d(feat32, self.avg_pool_kernel_size_32)
        avg = self.conv_avg(avg)
        avg_up = F.interpolate(avg, (H32, W32), mode='nearest')

        feat32_sum = self.arm32(feat32) + avg_up
        feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
        feat32_up = self.conv_head32(feat32_up)

        feat16_sum = self.arm16(feat16) + feat32_up
        feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
        feat16_up = self.conv_head16(feat16_up)

        return feat2, feat4, feat8, feat16, feat16_up, feat32_up

    def init_weight(self):
        """Kaiming-normal (a=1) init for each direct Conv2d child; zero any bias.

        None of the direct children built in ``__init__`` are bare ``nn.Conv2d``,
        so this is effectively a no-op here. The submodules initialise themselves.
        """
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Split all parameters (backbone included) into ``(weight_decay, no_weight_decay)``.

        Conv2d/Linear weights are decayed. Their biases and all BatchNorm2d
        parameters are not.
        """
        wd_params, nowd_params = [], []
        for _, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params

class ContextPath(nn.Module):
    """STDC context path whose global pooling uses the runtime feature size.

    Refines the backbone's ``feat16`` and ``feat32`` with attention-refinement
    modules. It adds a global-average context term to ``feat32`` and
    progressively upsamples the sum.

    Args:
        backbone: 'STDCNet1446' or 'STDCNet813'. Anything else raises
            ValueError (note the default 'CatNetSmall' is not accepted).
        pretrain_model: forwarded to the backbone constructor.
        use_conv_last: forwarded to the backbone constructor.
        *args, **kwargs: accepted for call-site compatibility and ignored.

    Raises:
        ValueError: if ``backbone`` is not one of the supported names.
    """

    def __init__(self, backbone='CatNetSmall', pretrain_model='', use_conv_last=False, *args, **kwargs):
        super(ContextPath, self).__init__()

        self.backbone_name = backbone

        if backbone == 'STDCNet1446':
            self.backbone = STDCNet1446(pretrain_model=pretrain_model, use_conv_last=use_conv_last)
        elif backbone == 'STDCNet813':
            self.backbone = STDCNet813(pretrain_model=pretrain_model, use_conv_last=use_conv_last)
        else:
            # Raise instead of print + exit(0): exiting with status 0 reported
            # success while terminating the caller's process.
            raise ValueError(
                f"unsupported backbone {backbone!r}; expected 'STDCNet1446' or 'STDCNet813'")

        # Head layout is identical for both backbones: feat16 is expected to
        # have 512 channels and feat32 1024, whether or not use_conv_last is set.
        inplanes = 1024
        self.arm16 = AttentionRefinementModule(512, 128)
        self.arm32 = AttentionRefinementModule(inplanes, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(inplanes, 128, ks=1, stride=1, padding=0)

        self.init_weight()

    def forward(self, x):
        """Run the backbone and fuse its two deepest features.

        Returns:
            ``(feat2, feat4, feat8, feat16, feat16_up, feat32_up)``. The first
            four are backbone outputs passed through. ``feat16_up`` (128 ch) is
            resized to ``feat8``'s spatial size, and ``feat32_up`` (128 ch) to
            ``feat16``'s.
        """
        feat2, feat4, feat8, feat16, feat32 = self.backbone(x)
        H8, W8 = feat8.size()[2:]
        H16, W16 = feat16.size()[2:]
        H32, W32 = feat32.size()[2:]

        # Global context: average over all of feat32, project to 128 channels,
        # and broadcast back to feat32's spatial size.
        avg = F.avg_pool2d(feat32, feat32.size()[2:])
        avg = self.conv_avg(avg)
        avg_up = F.interpolate(avg, (H32, W32), mode='nearest')

        feat32_sum = self.arm32(feat32) + avg_up
        feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
        feat32_up = self.conv_head32(feat32_up)

        feat16_sum = self.arm16(feat16) + feat32_up
        feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
        feat16_up = self.conv_head16(feat16_up)

        return feat2, feat4, feat8, feat16, feat16_up, feat32_up

    def init_weight(self):
        """Kaiming-normal (a=1) init for each direct Conv2d child; zero any bias.

        None of the direct children built in ``__init__`` are bare ``nn.Conv2d``,
        so this is effectively a no-op here. The submodules initialise themselves.
        """
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Split all parameters (backbone included) into ``(weight_decay, no_weight_decay)``.

        Conv2d/Linear weights are decayed. Their biases and all BatchNorm2d
        parameters are not.
        """
        wd_params, nowd_params = [], []
        for _, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class FeatureFusionModule_static(nn.Module):
    """Feature fusion with a channel-attention residual, using a fixed pool kernel.

    ``fsp`` and ``fcp`` are concatenated along channels and projected by a
    1x1 ConvBNReLU to ``feat``. A gate is computed by average-pooling ``feat``
    with a kernel derived from ``modelSize``, then applying a 1x1 conv
    (out_chan -> out_chan // 4), ReLU, a 1x1 conv back to out_chan, and
    sigmoid. The output is ``feat * gate + feat``.

    Args:
        in_chan: channels of ``fsp`` + ``fcp`` combined.
        out_chan: channels of the fused output.
        modelSize: (H, W) of the network input. The pool kernel is ceil(size / 8)
            for integer sizes.
        *args, **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, in_chan, out_chan, modelSize, *args, **kwargs):
        super().__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        # int(x + 0.9999) acts as ceil(x) for integer sizes.
        height_k = int(modelSize[0] / 8 + 0.9999)
        width_k = int(modelSize[1] / 8 + 0.9999)
        self.avg_pool_kernel_size = [height_k, width_k]
        squeezed = out_chan // 4
        self.conv1 = nn.Conv2d(out_chan, squeezed, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(squeezed, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def forward(self, fsp, fcp):
        """Fuse ``fsp`` and ``fcp`` and return ``feat * gate + feat``."""
        fused = self.convblk(torch.cat((fsp, fcp), dim=1))
        weights = F.avg_pool2d(fused, kernel_size=self.avg_pool_kernel_size)
        weights = self.sigmoid(self.conv2(self.relu(self.conv1(weights))))
        return fused * weights + fused

    def init_weight(self):
        """Kaiming-normal (a=1) init for each direct Conv2d child; zero any bias."""
        for child in self.children():
            if not isinstance(child, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(child.weight, a=1)
            if child.bias is not None:
                nn.init.constant_(child.bias, 0)

    def get_params(self):
        """Split parameters into ``(weight_decay, no_weight_decay)`` lists.

        Conv2d/Linear weights are decayed. Their biases and all BatchNorm2d
        parameters are not.
        """
        decay, no_decay = [], []
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                decay.append(module.weight)
                if module.bias is not None:
                    no_decay.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                no_decay.extend(module.parameters())
        return decay, no_decay

class FeatureFusionModule(nn.Module):
    """Feature fusion with a channel-attention residual, pooling over the runtime size.

    ``fsp`` and ``fcp`` are concatenated along channels and projected by a
    1x1 ConvBNReLU to ``feat``. ``feat`` is globally averaged to 1x1, then
    passed through a 1x1 conv (out_chan -> out_chan // 4), ReLU, a 1x1 conv
    back to out_chan, and sigmoid. The output is ``feat * gate + feat``.

    Args:
        in_chan: channels of ``fsp`` + ``fcp`` combined.
        out_chan: channels of the fused output.
        *args, **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super().__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        bottleneck = out_chan // 4
        self.conv1 = nn.Conv2d(out_chan, bottleneck, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(bottleneck, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def forward(self, fsp, fcp):
        """Fuse ``fsp`` and ``fcp`` and return ``feat * gate + feat``."""
        fused = self.convblk(torch.cat((fsp, fcp), dim=1))
        # Kernel == full spatial size -> one averaged value per channel.
        squeeze = F.avg_pool2d(fused, fused.size()[2:])
        excite = self.conv2(self.relu(self.conv1(squeeze)))
        weights = self.sigmoid(excite)
        return fused * weights + fused

    def init_weight(self):
        """Kaiming-normal (a=1) init for each direct Conv2d child; zero any bias."""
        convs = [m for m in self.children() if isinstance(m, nn.Conv2d)]
        for conv in convs:
            nn.init.kaiming_normal_(conv.weight, a=1)
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)

    def get_params(self):
        """Split parameters into ``(weight_decay, no_weight_decay)`` lists.

        Conv2d/Linear weights are decayed. Their biases and all BatchNorm2d
        parameters are not.
        """
        decay, no_decay = [], []
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                decay.append(module.weight)
                if module.bias is not None:
                    no_decay.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                no_decay.extend(module.parameters())
        return decay, no_decay


class BiSeNet_STDC(nn.Module):
    """BiSeNet-style segmentation network built on an STDC context path.

    ``forward`` fuses the backbone's ``feat8`` with the context path's
    ``feat16_up`` and predicts ``n_classes`` logits. The logits are resized
    bilinearly to the input's spatial size.

    Args:
        backbone: 'STDCNet1446' or 'STDCNet813'. Anything else raises ValueError.
        n_classes: number of output channels of the main prediction.
        pretrain_model: forwarded to the context path / backbone.
        use_boundary_2, use_boundary_4, use_boundary_8, use_boundary_16:
            stored as attributes. ``forward`` does not read them.
        use_conv_last: forwarded to the context path / backbone.
        **kwargs: ``modelSize=(H, W)`` (if truthy) selects the "_static"
            context path and fusion module, whose pool kernels are fixed from
            that size. Other keys are ignored.

    Raises:
        ValueError: if ``backbone`` is not one of the supported names.
    """

    def __init__(self, backbone, n_classes, pretrain_model='', use_boundary_2=False, use_boundary_4=False,
                 use_boundary_8=False, use_boundary_16=False, use_conv_last=False, **kwargs):
        super(BiSeNet_STDC, self).__init__()

        # Validate up front instead of print + exit(0), which terminated the
        # caller's process with a success status.
        if backbone not in ('STDCNet1446', 'STDCNet813'):
            raise ValueError(
                f"unsupported backbone {backbone!r}; expected 'STDCNet1446' or 'STDCNet813'")

        modelSize = kwargs.get('modelSize')

        self.use_boundary_2 = use_boundary_2
        self.use_boundary_4 = use_boundary_4
        self.use_boundary_8 = use_boundary_8
        self.use_boundary_16 = use_boundary_16

        if modelSize:
            self.cp = ContextPath_static(backbone, pretrain_model, use_conv_last=use_conv_last, modelSize=modelSize)
        else:
            self.cp = ContextPath(backbone, pretrain_model, use_conv_last=use_conv_last)

        # Channel widths are the same for both supported backbones.
        conv_out_inplanes = 128
        sp2_inplanes = 32
        sp4_inplanes = 64
        sp8_inplanes = 256
        sp16_inplanes = 512
        inplane = sp8_inplanes + conv_out_inplanes

        if modelSize:
            self.ffm = FeatureFusionModule_static(inplane, 256, modelSize)
        else:
            self.ffm = FeatureFusionModule(inplane, 256)
        self.conv_out = BiSeNetOutput(256, 256, n_classes)
        # Auxiliary heads: built (so their parameters exist in checkpoints) but
        # not used by forward().
        self.conv_out16 = BiSeNetOutput(conv_out_inplanes, 64, n_classes)
        self.conv_out32 = BiSeNetOutput(conv_out_inplanes, 64, n_classes)
        self.conv_out_sp16 = BiSeNetOutput(sp16_inplanes, 64, 1)
        self.conv_out_sp8 = BiSeNetOutput(sp8_inplanes, 64, 1)
        self.conv_out_sp4 = BiSeNetOutput(sp4_inplanes, 64, 1)
        self.conv_out_sp2 = BiSeNetOutput(sp2_inplanes, 64, 1)

        self.init_weight()

    def forward(self, x):
        """Return ``n_classes``-channel logits at the same spatial size as ``x``."""
        H, W = x.size()[2:]
        _, _, feat_res8, _, feat_cp8, _ = self.cp(x)
        feat_fuse = self.ffm(feat_res8, feat_cp8)
        feat_out = self.conv_out(feat_fuse)
        feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
        return feat_out

    def init_weight(self):
        """Kaiming-normal (a=1) init for each direct Conv2d child; zero any bias.

        No direct child built in ``__init__`` is a bare ``nn.Conv2d``, so this
        is effectively a no-op. The submodules initialise themselves.
        """
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if ly.bias is not None:
                    nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Group parameters for the optimizer.

        Returns:
            ``(wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params)``.
            Parameters of the fusion module (either variant) and of every
            BiSeNetOutput head go into the ``lr_mul_*`` lists. Everything else
            (the context path) goes into the first two.
        """
        # Both fusion-module variants belong in the lr-multiplier group;
        # FeatureFusionModule_static is used whenever modelSize was given.
        lr_mul_types = (FeatureFusionModule, FeatureFusionModule_static, BiSeNetOutput)
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
        for _, child in self.named_children():
            child_wd_params, child_nowd_params = child.get_params()
            if isinstance(child, lr_mul_types):
                lr_mul_wd_params += child_wd_params
                lr_mul_nowd_params += child_nowd_params
            else:
                wd_params += child_wd_params
                nowd_params += child_nowd_params
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params


if __name__ == "__main__":
    # Build a two-class model on the STDCNet813 backbone. Adding
    # modelSize=[360, 640] (H, W) would select the fixed-kernel "_static"
    # context path and fusion module instead.
    demo_config = dict(
        backbone='STDCNet813',
        n_classes=2,
        use_boundary_2=False,
        use_boundary_4=False,
        use_boundary_8=True,
        use_boundary_16=False,
        use_conv_last=False,
    )
    model = BiSeNet_STDC(**demo_config)
    print()
    # torch.save(model.state_dict(), 'STDCNet813.pth')