소스 검색

segDemo

master
thsw 2 년 전
부모
커밋
086d1fe635
15개의 변경된 파일1012개의 추가작업 그리고 1개의 파일을 삭제
  1. BIN
      images/examples/00000004.jpg
  2. BIN
      images/examples/00000008.jpg
  3. BIN
      images/examples/00000032.jpg
  4. BIN
      images/examples/00000033.jpg
  5. BIN
      images/examples/20220524_北十里长沟中支11_7905_3283.jpg
  6. BIN
      images/examples/20220524_北十里长沟中支11_7920_7690.jpg
  7. BIN
      images/examples/20220624_响水河_12300_1621.jpg
  8. BIN
      images/examples/20220624_响水河_12315_5080.jpg
  9. +0
    -0
      models_725/__init__.py
  10. +298
    -0
      models_725/bisenet.py
  11. +3
    -0
      models_725/class_dict.csv
  12. +337
    -0
      models_725/model_stages.py
  13. +69
    -0
      models_725/segWaterBuilding.py
  14. +302
    -0
      models_725/stdcnet.py
  15. +3
    -1
      predict.py

BIN
images/examples/00000004.jpg 파일 보기

Before After
Width: 4000  |  Height: 3000  |  Size: 4.6MB

BIN
images/examples/00000008.jpg 파일 보기

Before After
Width: 4000  |  Height: 2250  |  Size: 3.7MB

BIN
images/examples/00000032.jpg 파일 보기

Before After
Width: 4000  |  Height: 2250  |  Size: 3.6MB

BIN
images/examples/00000033.jpg 파일 보기

Before After
Width: 4000  |  Height: 3000  |  Size: 6.3MB

BIN
images/examples/20220524_北十里长沟中支11_7905_3283.jpg 파일 보기

Before After
Width: 1920  |  Height: 1080  |  Size: 379KB

BIN
images/examples/20220524_北十里长沟中支11_7920_7690.jpg 파일 보기

Before After
Width: 1920  |  Height: 1080  |  Size: 421KB

BIN
images/examples/20220624_响水河_12300_1621.jpg 파일 보기

Before After
Width: 1920  |  Height: 1080  |  Size: 746KB

BIN
images/examples/20220624_响水河_12315_5080.jpg 파일 보기

Before After
Width: 1920  |  Height: 1080  |  Size: 753KB

+ 0
- 0
models_725/__init__.py 파일 보기


+ 298
- 0
models_725/bisenet.py 파일 보기

@@ -0,0 +1,298 @@
"""Bilateral Segmentation Network"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from core.models.base_models.resnet import resnet18,resnet50
from core.nn import _ConvBNReLU

__all__ = ['BiSeNet', 'get_bisenet', 'get_bisenet_resnet18_citys']


class BiSeNet(nn.Module):
def __init__(self, nclass, backbone='resnet18', aux=False, jpu=False, pretrained_base=True, **kwargs):
super(BiSeNet, self).__init__()
self.aux = aux
self.spatial_path = SpatialPath(3, 128, **kwargs)
self.context_path = ContextPath(backbone, pretrained_base, **kwargs)
self.ffm = FeatureFusion(256, 256, 4, **kwargs)
self.head = _BiSeHead(256, 64, nclass, **kwargs)
if aux:
self.auxlayer1 = _BiSeHead(128, 256, nclass, **kwargs)
self.auxlayer2 = _BiSeHead(128, 256, nclass, **kwargs)

self.__setattr__('exclusive',
['spatial_path', 'context_path', 'ffm', 'head', 'auxlayer1', 'auxlayer2'] if aux else [
'spatial_path', 'context_path', 'ffm', 'head'])

def forward(self, x,outsize=None,test_flag=False):
size = x.size()[2:]
spatial_out = self.spatial_path(x)
context_out = self.context_path(x)
fusion_out = self.ffm(spatial_out, context_out[-1])
outputs = []
x = self.head(fusion_out)
x = F.interpolate(x, size, mode='bilinear', align_corners=True)
if outsize:
print('######using torch resize#######',outsize)
x = F.interpolate(x, outsize, mode='bilinear', align_corners=True)
outputs.append(x)

if self.aux:
auxout1 = self.auxlayer1(context_out[0])
auxout1 = F.interpolate(auxout1, size, mode='bilinear', align_corners=True)
outputs.append(auxout1)
auxout2 = self.auxlayer2(context_out[1])
auxout2 = F.interpolate(auxout2, size, mode='bilinear', align_corners=True)
outputs.append(auxout2)
if test_flag:
outputs = [torch.argmax(outputx ,axis=1) for outputx in outputs]
#return tuple(outputs)
return outputs[0]
class BiSeNet_MultiOutput(nn.Module):
def __init__(self, nclass, backbone='resnet18', aux=False, jpu=False, pretrained_base=True, **kwargs):
super(BiSeNet_MultiOutput, self).__init__()
self.aux = aux
self.spatial_path = SpatialPath(3, 128, **kwargs)
self.context_path = ContextPath(backbone, pretrained_base, **kwargs)
self.ffm = FeatureFusion(256, 256, 4, **kwargs)
assert isinstance(nclass,list)
self.outCnt = len(nclass)
for ii,nclassii in enumerate(nclass):
setattr(self,'head%d'%(ii) , _BiSeHead(256, 64, nclassii, **kwargs))
if aux:
self.auxlayer1 = _BiSeHead(128, 256, nclass, **kwargs)
self.auxlayer2 = _BiSeHead(128, 256, nclass, **kwargs)

self.__setattr__('exclusive',
['spatial_path', 'context_path', 'ffm', 'head', 'auxlayer1', 'auxlayer2'] if aux else [
'spatial_path', 'context_path', 'ffm', 'head'])

def forward(self, x,outsize=None,test_flag=False,smooth_kernel=0):
size = x.size()[2:]
spatial_out = self.spatial_path(x)
context_out = self.context_path(x)
fusion_out = self.ffm(spatial_out, context_out[-1])
outputs = []
for ii in range(self.outCnt):
x = getattr(self,'head%d'%(ii))(fusion_out)
x = F.interpolate(x, size, mode='bilinear', align_corners=True)
outputs.append(x)
if self.aux:
auxout1 = self.auxlayer1(context_out[0])
auxout1 = F.interpolate(auxout1, size, mode='bilinear', align_corners=True)
outputs.append(auxout1)
auxout2 = self.auxlayer2(context_out[1])
auxout2 = F.interpolate(auxout2, size, mode='bilinear', align_corners=True)
outputs.append(auxout2)
if test_flag:
outputs = [torch.argmax(outputx ,axis=1) for outputx in outputs]
if smooth_kernel>0:
gaussian_kernel = torch.from_numpy(np.ones((1,1,smooth_kernel,smooth_kernel)) )
pad = int((smooth_kernel - 1)/2)
if not gaussian_kernel.is_cuda:
gaussian_kernel = gaussian_kernel.to(x.device)
#print(gaussian_kernel.dtype,gaussian_kernel,outputs[0].dtype)
outputs = [ x.unsqueeze(1).double() for x in outputs]
outputs = [torch.conv2d(x, gaussian_kernel, padding=pad) for x in outputs ]
outputs = [ x.squeeze(1).long() for x in outputs]
#return tuple(outputs)
return outputs
class _BiSeHead(nn.Module):
def __init__(self, in_channels, inter_channels, nclass, norm_layer=nn.BatchNorm2d, **kwargs):
super(_BiSeHead, self).__init__()
self.block = nn.Sequential(
_ConvBNReLU(in_channels, inter_channels, 3, 1, 1, norm_layer=norm_layer),
nn.Dropout(0.1),
nn.Conv2d(inter_channels, nclass, 1)
)

def forward(self, x):
x = self.block(x)
return x


class SpatialPath(nn.Module):
"""Spatial path"""

def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d, **kwargs):
super(SpatialPath, self).__init__()
inter_channels = 64
self.conv7x7 = _ConvBNReLU(in_channels, inter_channels, 7, 2, 3, norm_layer=norm_layer)
self.conv3x3_1 = _ConvBNReLU(inter_channels, inter_channels, 3, 2, 1, norm_layer=norm_layer)
self.conv3x3_2 = _ConvBNReLU(inter_channels, inter_channels, 3, 2, 1, norm_layer=norm_layer)
self.conv1x1 = _ConvBNReLU(inter_channels, out_channels, 1, 1, 0, norm_layer=norm_layer)

def forward(self, x):
x = self.conv7x7(x)
x = self.conv3x3_1(x)
x = self.conv3x3_2(x)
x = self.conv1x1(x)

return x


class _GlobalAvgPooling(nn.Module):
def __init__(self, in_channels, out_channels, norm_layer, **kwargs):
super(_GlobalAvgPooling, self).__init__()
self.gap = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, 1, bias=False),
norm_layer(out_channels),
nn.ReLU(True)
)

def forward(self, x):
size = x.size()[2:]
pool = self.gap(x)
out = F.interpolate(pool, size, mode='bilinear', align_corners=True)
return out


class AttentionRefinmentModule(nn.Module):
def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d, **kwargs):
super(AttentionRefinmentModule, self).__init__()
self.conv3x3 = _ConvBNReLU(in_channels, out_channels, 3, 1, 1, norm_layer=norm_layer)
self.channel_attention = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
_ConvBNReLU(out_channels, out_channels, 1, 1, 0, norm_layer=norm_layer),
nn.Sigmoid()
)

def forward(self, x):
x = self.conv3x3(x)
attention = self.channel_attention(x)
x = x * attention
return x


class ContextPath(nn.Module):
def __init__(self, backbone='resnet18', pretrained_base=True, norm_layer=nn.BatchNorm2d, **kwargs):
super(ContextPath, self).__init__()
if backbone == 'resnet18':
pretrained = resnet18(pretrained=pretrained_base, **kwargs)
elif backbone=='resnet50':
pretrained = resnet50(pretrained=pretrained_base, **kwargs)
else:
raise RuntimeError('unknown backbone: {}'.format(backbone))
self.conv1 = pretrained.conv1
self.bn1 = pretrained.bn1
self.relu = pretrained.relu
self.maxpool = pretrained.maxpool
self.layer1 = pretrained.layer1
self.layer2 = pretrained.layer2
self.layer3 = pretrained.layer3
self.layer4 = pretrained.layer4

inter_channels = 128
self.global_context = _GlobalAvgPooling(512, inter_channels, norm_layer)

self.arms = nn.ModuleList(
[AttentionRefinmentModule(512, inter_channels, norm_layer, **kwargs),
AttentionRefinmentModule(256, inter_channels, norm_layer, **kwargs)]
)
self.refines = nn.ModuleList(
[_ConvBNReLU(inter_channels, inter_channels, 3, 1, 1, norm_layer=norm_layer),
_ConvBNReLU(inter_channels, inter_channels, 3, 1, 1, norm_layer=norm_layer)]
)

def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)

context_blocks = []
context_blocks.append(x)
x = self.layer2(x)
context_blocks.append(x)
c3 = self.layer3(x)
context_blocks.append(c3)
c4 = self.layer4(c3)
context_blocks.append(c4)
context_blocks.reverse()

global_context = self.global_context(c4)
last_feature = global_context
context_outputs = []
for i, (feature, arm, refine) in enumerate(zip(context_blocks[:2], self.arms, self.refines)):
feature = arm(feature)
feature += last_feature
last_feature = F.interpolate(feature, size=context_blocks[i + 1].size()[2:],
mode='bilinear', align_corners=True)
last_feature = refine(last_feature)
context_outputs.append(last_feature)

return context_outputs


class FeatureFusion(nn.Module):
def __init__(self, in_channels, out_channels, reduction=1, norm_layer=nn.BatchNorm2d, **kwargs):
super(FeatureFusion, self).__init__()
self.conv1x1 = _ConvBNReLU(in_channels, out_channels, 1, 1, 0, norm_layer=norm_layer, **kwargs)
self.channel_attention = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
_ConvBNReLU(out_channels, out_channels // reduction, 1, 1, 0, norm_layer=norm_layer),
_ConvBNReLU(out_channels // reduction, out_channels, 1, 1, 0, norm_layer=norm_layer),
nn.Sigmoid()
)

def forward(self, x1, x2):
fusion = torch.cat([x1, x2], dim=1)
out = self.conv1x1(fusion)
attention = self.channel_attention(out)
out = out + out * attention
return out


def get_bisenet(dataset='citys', backbone='resnet18', pretrained=False, root='~/.torch/models',
pretrained_base=True, **kwargs):
acronyms = {
'pascal_voc': 'pascal_voc',
'pascal_aug': 'pascal_aug',
'ade20k': 'ade',
'coco': 'coco',
'citys': 'citys',
}
from ..data.dataloader import datasets
model = BiSeNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('bisenet_%s_%s' % (backbone, acronyms[dataset]), root=root),
map_location=device))
return model


def get_bisenet_resnet18_citys(**kwargs):
return get_bisenet('citys', 'resnet18', **kwargs)


if __name__ == '__main__':
# img = torch.randn(2, 3, 224, 224)
# model = BiSeNet(19, backbone='resnet18')
# print(model.exclusive)
input = torch.rand(2, 3, 224, 224)
model = BiSeNet(4, pretrained_base=True)
# target = torch.zeros(4, 512, 512).cuda()
# model.eval()
# print(model)
loss = model(input)
print(loss, loss.shape)

# from torchsummary import summary
#
# summary(model, (3, 224, 224)) # 打印表格,按顺序输出每层的输出形状和参数
import torch
from thop import profile
from torchsummary import summary

flop, params = profile(model, input_size=(1, 3, 512, 512))
print('flops:{:.3f}G\nparams:{:.3f}M'.format(flop / 1e9, params / 1e6))

+ 3
- 0
models_725/class_dict.csv 파일 보기

@@ -0,0 +1,3 @@
name,r,g,b
0,0,0,0
1,255,255,255

+ 337
- 0
models_725/model_stages.py 파일 보기

@@ -0,0 +1,337 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import time
from models_725.stdcnet import STDCNet1446, STDCNet813
#from models_725.bn import InPlaceABNSync as BatchNorm2d
# BatchNorm2d = nn.BatchNorm2d

class ConvBNReLU(nn.Module):
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(in_chan,
out_chan,
kernel_size = ks,
stride = stride,
padding = padding,
bias = False)
# self.bn = BatchNorm2d(out_chan)
# self.bn = BatchNorm2d(out_chan, activation='none')
self.bn = nn.BatchNorm2d(out_chan)
self.relu = nn.ReLU()
self.init_weight()

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x

def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)


class BiSeNetOutput(nn.Module):
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
super(BiSeNetOutput, self).__init__()
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
self.init_weight()

def forward(self, x):
x = self.conv(x)
x = self.conv_out(x)
return x

def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)

def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):######################1
nowd_params += list(module.parameters())
return wd_params, nowd_params


class AttentionRefinementModule(nn.Module):
def __init__(self, in_chan, out_chan, *args, **kwargs):
super(AttentionRefinementModule, self).__init__()
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
# self.bn_atten = nn.BatchNorm2d(out_chan)
# self.bn_atten = BatchNorm2d(out_chan, activation='none')
self.bn_atten = nn.BatchNorm2d(out_chan)########################2

self.sigmoid_atten = nn.Sigmoid()
self.init_weight()

def forward(self, x):
feat = self.conv(x)
atten = F.avg_pool2d(feat, feat.size()[2:])
atten = self.conv_atten(atten)
atten = self.bn_atten(atten)
atten = self.sigmoid_atten(atten)
out = torch.mul(feat, atten)
return out

def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)


class ContextPath(nn.Module):
def __init__(self, backbone='CatNetSmall', pretrain_model='', use_conv_last=False, *args, **kwargs):
super(ContextPath, self).__init__()
self.backbone_name = backbone
if backbone == 'STDCNet1446':
self.backbone = STDCNet1446(pretrain_model=pretrain_model, use_conv_last=use_conv_last)
self.arm16 = AttentionRefinementModule(512, 128)
inplanes = 1024
if use_conv_last:
inplanes = 1024
self.arm32 = AttentionRefinementModule(inplanes, 128)
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_avg = ConvBNReLU(inplanes, 128, ks=1, stride=1, padding=0)

elif backbone == 'STDCNet813':
self.backbone = STDCNet813(pretrain_model=pretrain_model, use_conv_last=use_conv_last)
self.arm16 = AttentionRefinementModule(512, 128)
inplanes = 1024
if use_conv_last:
inplanes = 1024
self.arm32 = AttentionRefinementModule(inplanes, 128)
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_avg = ConvBNReLU(inplanes, 128, ks=1, stride=1, padding=0)
else:
print("backbone is not in backbone lists")
exit(0)

self.init_weight()

def forward(self, x):
H0, W0 = x.size()[2:]

feat2, feat4, feat8, feat16, feat32 = self.backbone(x)
H8, W8 = feat8.size()[2:]
H16, W16 = feat16.size()[2:]
H32, W32 = feat32.size()[2:]
avg = F.avg_pool2d(feat32, feat32.size()[2:])

avg = self.conv_avg(avg)
avg_up = F.interpolate(avg, (H32, W32), mode='nearest')

feat32_arm = self.arm32(feat32)
feat32_sum = feat32_arm + avg_up
feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
feat32_up = self.conv_head32(feat32_up)

feat16_arm = self.arm16(feat16)
feat16_sum = feat16_arm + feat32_up
feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
feat16_up = self.conv_head16(feat16_up)
return feat2, feat4, feat8, feat16, feat16_up, feat32_up # x8, x16

# return feat8, feat16_up # x8, x16

def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)

def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):#################3
nowd_params += list(module.parameters())
return wd_params, nowd_params


class FeatureFusionModule(nn.Module):
def __init__(self, in_chan, out_chan, *args, **kwargs):
super(FeatureFusionModule, self).__init__()
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
self.conv1 = nn.Conv2d(out_chan,
out_chan//4,
kernel_size = 1,
stride = 1,
padding = 0,
bias = False)
self.conv2 = nn.Conv2d(out_chan//4,
out_chan,
kernel_size = 1,
stride = 1,
padding = 0,
bias = False)
self.relu = nn.ReLU(inplace=True)
self.sigmoid = nn.Sigmoid()
self.init_weight()

def forward(self, fsp, fcp):
fcat = torch.cat([fsp, fcp], dim=1)
feat = self.convblk(fcat)
atten = F.avg_pool2d(feat, feat.size()[2:])
atten = self.conv1(atten)
atten = self.relu(atten)
atten = self.conv2(atten)
atten = self.sigmoid(atten)
feat_atten = torch.mul(feat, atten)
feat_out = feat_atten + feat
return feat_out

def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)

def get_params(self):
wd_params, nowd_params = [], []
for name, module in self.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
wd_params.append(module.weight)
if not module.bias is None:
nowd_params.append(module.bias)
elif isinstance(module, nn.BatchNorm2d):##################4
nowd_params += list(module.parameters())
return wd_params, nowd_params


class BiSeNet(nn.Module):
def __init__(self, backbone, n_classes, pretrain_model='', use_boundary_2=False, use_boundary_4=False,
use_boundary_8=False, use_boundary_16=False, use_conv_last=False):
super(BiSeNet, self).__init__()
self.use_boundary_2 = use_boundary_2
self.use_boundary_4 = use_boundary_4
self.use_boundary_8 = use_boundary_8
self.use_boundary_16 = use_boundary_16
# self.heat_map = heat_map
self.cp = ContextPath(backbone, pretrain_model, use_conv_last=use_conv_last)
if backbone == 'STDCNet1446':
conv_out_inplanes = 128
sp2_inplanes = 32
sp4_inplanes = 64
sp8_inplanes = 256
sp16_inplanes = 512
inplane = sp8_inplanes + conv_out_inplanes

elif backbone == 'STDCNet813':
conv_out_inplanes = 128
sp2_inplanes = 32
sp4_inplanes = 64
sp8_inplanes = 256
sp16_inplanes = 512
inplane = sp8_inplanes + conv_out_inplanes

else:
print("backbone is not in backbone lists")
exit(0)

self.ffm = FeatureFusionModule(inplane, 256)
self.conv_out = BiSeNetOutput(256, 256, n_classes)
self.conv_out16 = BiSeNetOutput(conv_out_inplanes, 64, n_classes)
self.conv_out32 = BiSeNetOutput(conv_out_inplanes, 64, n_classes)

self.conv_out_sp16 = BiSeNetOutput(sp16_inplanes, 64, 1)
self.conv_out_sp8 = BiSeNetOutput(sp8_inplanes, 64, 1)
self.conv_out_sp4 = BiSeNetOutput(sp4_inplanes, 64, 1)
self.conv_out_sp2 = BiSeNetOutput(sp2_inplanes, 64, 1)
self.init_weight()

def forward(self, x):
H, W = x.size()[2:]
# time_0 = time.time()
# feat_res2, feat_res4, feat_res8, feat_res16, feat_cp8, feat_cp16 = self.cp(x)
feat_res2, feat_res4, feat_res8, feat_res16, feat_cp8, feat_cp16 = self.cp(x)
# print('----backbone', (time.time() - time_0) * 1000)
# feat_out_sp2 = self.conv_out_sp2(feat_res2)
#
# feat_out_sp4 = self.conv_out_sp4(feat_res4)
#
# feat_out_sp8 = self.conv_out_sp8(feat_res8)
#
# feat_out_sp16 = self.conv_out_sp16(feat_res16)
# time_1 = time.time()
feat_fuse = self.ffm(feat_res8, feat_cp8)
# print('----ffm', (time.time() - time_1) * 1000)
# time_2 = time.time()
feat_out = self.conv_out(feat_fuse)
# feat_out16 = self.conv_out16(feat_cp8)
# feat_out32 = self.conv_out32(feat_cp16)

feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
# print('----conv_out', (time.time() - time_2) * 1000)
# feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
# feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)


# if self.use_boundary_2 and self.use_boundary_4 and self.use_boundary_8:
# return feat_out, feat_out16, feat_out32, feat_out_sp2, feat_out_sp4, feat_out_sp8
#
# if (not self.use_boundary_2) and self.use_boundary_4 and self.use_boundary_8:
# return feat_out, feat_out16, feat_out32, feat_out_sp4, feat_out_sp8
#
# if (not self.use_boundary_2) and (not self.use_boundary_4) and self.use_boundary_8:
return feat_out
# if (not self.use_boundary_2) and (not self.use_boundary_4) and (not self.use_boundary_8):
# return feat_out, feat_out16, feat_out32

def init_weight(self):
for ly in self.children():
if isinstance(ly, nn.Conv2d):
nn.init.kaiming_normal_(ly.weight, a=1)
if not ly.bias is None: nn.init.constant_(ly.bias, 0)

def get_params(self):
wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
for name, child in self.named_children():
child_wd_params, child_nowd_params = child.get_params()
if isinstance(child, (FeatureFusionModule, BiSeNetOutput)):
lr_mul_wd_params += child_wd_params
lr_mul_nowd_params += child_nowd_params
else:
wd_params += child_wd_params
nowd_params += child_nowd_params
return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params


if __name__ == "__main__":
net = BiSeNet('STDCNet813', 19)
net.cuda()
net.eval()
in_ten = torch.randn(1, 3, 768, 1536).cuda()
out, out16, out32 = net(in_ten)
print(out.shape)
# torch.save(net.state_dict(), 'STDCNet813.pth')###


+ 69
- 0
models_725/segWaterBuilding.py 파일 보기

@@ -0,0 +1,69 @@
import torch
from torchvision import transforms
import cv2
import numpy as np
import matplotlib.pyplot as plt
from models_725.model_stages import BiSeNet
import torch.nn.functional as F


class SegModel(object):
def __init__(self, nclass=2, device='cuda:0',
respth='./model_maxmIOU75_1720_0.946_360640.pth',
multiOutput=False):

self.model = BiSeNet(backbone='STDCNet813', n_classes=nclass,
use_boundary_2=False, use_boundary_4=False,
use_boundary_8=True, use_boundary_16=False,
use_conv_last=False)
self.model.load_state_dict(torch.load(respth))
self.device = device
self.multiOutput = multiOutput
self.model= self.model.to(self.device)
self.to_tensor = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
self.mean = (0.485, 0.456, 0.406)
self.std = (0.229, 0.224, 0.225)

def eval(self, image=None):
H, W, _ = image.shape
img = self.preprocess_image(image)

imgs = img.cuda()
self.model.eval()
logits = self.model(imgs)

logits = F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=True)
probs = torch.softmax(logits, dim=1)
pred = torch.argmax(probs, dim=1)

return pred, probs[0][1]

def preprocess_image(self, image):
image = cv2.resize(image, (640,360), interpolation=cv2.INTER_LINEAR)
image = image.astype(np.float32)
image /= 255.0

image[:, :, 0] -= self.mean[0]
image[:, :, 1] -= self.mean[1]
image[:, :, 2] -= self.mean[2]

image[:, :, 0] /= self.std[0]
image[:, :, 1] /= self.std[1]
image[:, :, 2] /= self.std[2]

image = np.transpose(image, (2, 0, 1))
image = torch.from_numpy(image).float()
image = image.unsqueeze(0)

return image



+ 302
- 0
models_725/stdcnet.py 파일 보기

@@ -0,0 +1,302 @@
import torch
import torch.nn as nn
from torch.nn import init
import math

class ConvX(nn.Module):
def __init__(self, in_planes, out_planes, kernel=3, stride=1):
super(ConvX, self).__init__()
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel, stride=stride, padding=kernel//2, bias=False)
self.bn = nn.BatchNorm2d(out_planes)
self.relu = nn.ReLU(inplace=True)

def forward(self, x):
out = self.relu(self.bn(self.conv(x)))
return out


class AddBottleneck(nn.Module):
def __init__(self, in_planes, out_planes, block_num=3, stride=1):
super(AddBottleneck, self).__init__()
assert block_num > 1, print("block number should be larger than 1.")
self.conv_list = nn.ModuleList()
self.stride = stride
if stride == 2:
self.avd_layer = nn.Sequential(
nn.Conv2d(out_planes//2, out_planes//2, kernel_size=3, stride=2, padding=1, groups=out_planes//2, bias=False),
nn.BatchNorm2d(out_planes//2),
)
self.skip = nn.Sequential(
nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=2, padding=1, groups=in_planes, bias=False),
nn.BatchNorm2d(in_planes),
nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False),
nn.BatchNorm2d(out_planes),
)
stride = 1

for idx in range(block_num):
if idx == 0:
self.conv_list.append(ConvX(in_planes, out_planes//2, kernel=1))
elif idx == 1 and block_num == 2:
self.conv_list.append(ConvX(out_planes//2, out_planes//2, stride=stride))
elif idx == 1 and block_num > 2:
self.conv_list.append(ConvX(out_planes//2, out_planes//4, stride=stride))
elif idx < block_num - 1:
self.conv_list.append(ConvX(out_planes//int(math.pow(2, idx)), out_planes//int(math.pow(2, idx+1))))
else:
self.conv_list.append(ConvX(out_planes//int(math.pow(2, idx)), out_planes//int(math.pow(2, idx))))
def forward(self, x):
out_list = []
out = x

for idx, conv in enumerate(self.conv_list):
if idx == 0 and self.stride == 2:
out = self.avd_layer(conv(out))
else:
out = conv(out)
out_list.append(out)

if self.stride == 2:
x = self.skip(x)

return torch.cat(out_list, dim=1) + x



class CatBottleneck(nn.Module):
def __init__(self, in_planes, out_planes, block_num=3, stride=1):
super(CatBottleneck, self).__init__()
assert block_num > 1, print("block number should be larger than 1.")
self.conv_list = nn.ModuleList()
self.stride = stride
if stride == 2:
self.avd_layer = nn.Sequential(
nn.Conv2d(out_planes//2, out_planes//2, kernel_size=3, stride=2, padding=1, groups=out_planes//2, bias=False),
nn.BatchNorm2d(out_planes//2),
)
self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
stride = 1

for idx in range(block_num):
if idx == 0:
self.conv_list.append(ConvX(in_planes, out_planes//2, kernel=1))
elif idx == 1 and block_num == 2:
self.conv_list.append(ConvX(out_planes//2, out_planes//2, stride=stride))
elif idx == 1 and block_num > 2:
self.conv_list.append(ConvX(out_planes//2, out_planes//4, stride=stride))
elif idx < block_num - 1:
self.conv_list.append(ConvX(out_planes//int(math.pow(2, idx)), out_planes//int(math.pow(2, idx+1))))
else:
self.conv_list.append(ConvX(out_planes//int(math.pow(2, idx)), out_planes//int(math.pow(2, idx))))
def forward(self, x):
out_list = []
out1 = self.conv_list[0](x)

for idx, conv in enumerate(self.conv_list[1:]):
if idx == 0:
if self.stride == 2:
out = conv(self.avd_layer(out1))
else:
out = conv(out1)
else:
out = conv(out)
out_list.append(out)

if self.stride == 2:
out1 = self.skip(out1)
out_list.insert(0, out1)

out = torch.cat(out_list, dim=1)
return out

#STDC2Net
class STDCNet1446(nn.Module):
def __init__(self, base=64, layers=[4,5,3], block_num=4, type="cat", num_classes=1000, dropout=0.20, pretrain_model='', use_conv_last=False):
super(STDCNet1446, self).__init__()
if type == "cat":
block = CatBottleneck
elif type == "add":
block = AddBottleneck
self.use_conv_last = use_conv_last
self.features = self._make_layers(base, layers, block_num, block)
self.conv_last = ConvX(base*16, max(1024, base*16), 1, 1)
self.gap = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(max(1024, base*16), max(1024, base*16), bias=False)
self.bn = nn.BatchNorm1d(max(1024, base*16))
self.relu = nn.ReLU(inplace=True)
self.dropout = nn.Dropout(p=dropout)
self.linear = nn.Linear(max(1024, base*16), num_classes, bias=False)

self.x2 = nn.Sequential(self.features[:1])
self.x4 = nn.Sequential(self.features[1:2])
self.x8 = nn.Sequential(self.features[2:6])
self.x16 = nn.Sequential(self.features[6:11])
self.x32 = nn.Sequential(self.features[11:])

if pretrain_model:
print('use pretrain model {}'.format(pretrain_model))
self.init_weight(pretrain_model)
else:
self.init_params()

def init_weight(self, pretrain_model):
state_dict = torch.load(pretrain_model)["state_dict"]
self_state_dict = self.state_dict()
for k, v in state_dict.items():
self_state_dict.update({k: v})
self.load_state_dict(self_state_dict)

def init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.001)
if m.bias is not None:
init.constant_(m.bias, 0)

def _make_layers(self, base, layers, block_num, block):
features = []
features += [ConvX(3, base//2, 3, 2)]
features += [ConvX(base//2, base, 3, 2)]

for i, layer in enumerate(layers):
for j in range(layer):
if i == 0 and j == 0:
features.append(block(base, base*4, block_num, 2))
elif j == 0:
features.append(block(base*int(math.pow(2,i+1)), base*int(math.pow(2,i+2)), block_num, 2))
else:
features.append(block(base*int(math.pow(2,i+2)), base*int(math.pow(2,i+2)), block_num, 1))

return nn.Sequential(*features)

def forward(self, x):
feat2 = self.x2(x)
feat4 = self.x4(feat2)
feat8 = self.x8(feat4)
feat16 = self.x16(feat8)
feat32 = self.x32(feat16)
if self.use_conv_last:
feat32 = self.conv_last(feat32)

return feat2, feat4, feat8, feat16, feat32

def forward_impl(self, x):
out = self.features(x)
out = self.conv_last(out).pow(2)
out = self.gap(out).flatten(1)
out = self.fc(out)
# out = self.bn(out)
out = self.relu(out)
# out = self.relu(self.bn(self.fc(out)))
out = self.dropout(out)
out = self.linear(out)
return out

# STDC1Net
class STDCNet813(nn.Module):
def __init__(self, base=64, layers=[2,2,2], block_num=4, type="cat", num_classes=1000, dropout=0.20, pretrain_model='', use_conv_last=False):
super(STDCNet813, self).__init__()
if type == "cat":
block = CatBottleneck
elif type == "add":
block = AddBottleneck
self.use_conv_last = use_conv_last
self.features = self._make_layers(base, layers, block_num, block)
self.conv_last = ConvX(base*16, max(1024, base*16), 1, 1)
self.gap = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(max(1024, base*16), max(1024, base*16), bias=False)
self.bn = nn.BatchNorm1d(max(1024, base*16))
self.relu = nn.ReLU(inplace=True)
self.dropout = nn.Dropout(p=dropout)
self.linear = nn.Linear(max(1024, base*16), num_classes, bias=False)

self.x2 = nn.Sequential(self.features[:1])
self.x4 = nn.Sequential(self.features[1:2])
self.x8 = nn.Sequential(self.features[2:4])
self.x16 = nn.Sequential(self.features[4:6])
self.x32 = nn.Sequential(self.features[6:])

if pretrain_model:
print('use pretrain model {}'.format(pretrain_model))
self.init_weight(pretrain_model)
else:
self.init_params()

def init_weight(self, pretrain_model):
state_dict = torch.load(pretrain_model)["state_dict"]
self_state_dict = self.state_dict()
for k, v in state_dict.items():
self_state_dict.update({k: v})
self.load_state_dict(self_state_dict)

def init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.001)
if m.bias is not None:
init.constant_(m.bias, 0)

def _make_layers(self, base, layers, block_num, block):
features = []
features += [ConvX(3, base//2, 3, 2)]
features += [ConvX(base//2, base, 3, 2)]

for i, layer in enumerate(layers):
for j in range(layer):
if i == 0 and j == 0:
features.append(block(base, base*4, block_num, 2))
elif j == 0:
features.append(block(base*int(math.pow(2,i+1)), base*int(math.pow(2,i+2)), block_num, 2))
else:
features.append(block(base*int(math.pow(2,i+2)), base*int(math.pow(2,i+2)), block_num, 1))

return nn.Sequential(*features)

def forward(self, x):
feat2 = self.x2(x)
feat4 = self.x4(feat2)
feat8 = self.x8(feat4)
feat16 = self.x16(feat8)
feat32 = self.x32(feat16)
if self.use_conv_last:
feat32 = self.conv_last(feat32)

return feat2, feat4, feat8, feat16, feat32

def forward_impl(self, x):
out = self.features(x)
out = self.conv_last(out).pow(2)
out = self.gap(out).flatten(1)
out = self.fc(out)
# out = self.bn(out)
out = self.relu(out)
# out = self.relu(self.bn(self.fc(out)))
out = self.dropout(out)
out = self.linear(out)
return out

if __name__ == "__main__":
model = STDCNet813(num_classes=1000, dropout=0.00, block_num=4)
model.eval()
x = torch.randn(1,3,224,224)
y = model(x)
torch.save(model.state_dict(), 'cat.pth')
print(y.size())

+ 3
- 1
predict.py 파일 보기

@@ -21,6 +21,7 @@ def predict_lunkuo(impth=None):

if __name__ == '__main__':
impth = 'images/examples'
outpth= 'images/results'
folders = os.listdir(impth)
segmodel = SegModel()
for i in range(len(folders)):
@@ -29,5 +30,6 @@ if __name__ == '__main__':
img = Image.open(imgpath).convert('RGB')
img = np.array(img)
time11 = time.time()
predict_lunkuo(impth=img)
img=predict_lunkuo(impth=img)
cv2.imwrite( os.path.join( outpth,folders[i] ) ,img )
print('----all_process', (time.time() - time11) * 1000)

Loading…
취소
저장