236 lines
8.4 KiB
Python
236 lines
8.4 KiB
Python
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import sys
|
|
sys.path.extend(['/home/thsw2/WJ/src/yolov5/segutils/','../..','..' ])
|
|
from core.models.base_models.vgg import vgg16
|
|
|
|
# Public API of this module: the generic factories first, then the
# VGG16 / Pascal VOC convenience wrappers.
__all__ = [
    'get_fcn32s',
    'get_fcn16s',
    'get_fcn8s',
    'get_fcn32s_vgg16_voc',
    'get_fcn16s_vgg16_voc',
    'get_fcn8s_vgg16_voc',
]
|
|
|
|
|
|
class FCN32s(nn.Module):
    """FCN-32s style segmentation network on top of VGG16 features.

    Class scores are predicted from the last backbone feature map and
    bilinearly resized straight back to the spatial size of the input.
    The original author notes that this implementation differs in places
    from the original FCN.

    Args:
        nclass: number of output channels produced by the prediction heads.
        backbone: backbone name; only ``'vgg16'`` is accepted.
        aux: when true, a second ``_FCNHead`` is attached to the same
            features and its prediction is returned as well.
        pretrained_base: forwarded to ``vgg16(pretrained=...)``.
        norm_layer: normalisation layer constructor handed to ``_FCNHead``.
        **kwargs: accepted and ignored.

    Raises:
        RuntimeError: if ``backbone`` is not ``'vgg16'``.
    """

    def __init__(self, nclass, backbone='vgg16', aux=False, pretrained_base=True,
                 norm_layer=nn.BatchNorm2d, **kwargs):
        super().__init__()
        self.aux = aux
        if backbone != 'vgg16':
            raise RuntimeError('unknown backbone: {}'.format(backbone))
        self.pretrained = vgg16(pretrained=pretrained_base).features
        self.head = _FCNHead(512, nclass, norm_layer)
        if aux:
            self.auxlayer = _FCNHead(512, nclass, norm_layer)

        # Names of the submodules added on top of the backbone. Not read in
        # this file -- presumably consumed by external training code (verify).
        self.exclusive = ['head', 'auxlayer'] if aux else ['head']

    def forward(self, x):
        """Return ``(out,)`` or ``(out, auxout)``, each resized to ``x``'s H x W."""
        target_hw = x.size()[2:]
        features = self.pretrained(x)

        def _predict(branch):
            # Score the shared features, then resize back to the input size.
            return F.interpolate(branch(features), target_hw, mode='bilinear', align_corners=True)

        if self.aux:
            return _predict(self.head), _predict(self.auxlayer)
        return (_predict(self.head),)
|
|
|
|
|
|
class FCN16s(nn.Module):
    """FCN-16s style segmentation network on top of VGG16 features.

    The backbone is cut into two stages. Scores predicted from the second
    stage are resized to the first stage's resolution, summed with a 1x1
    scoring of the first stage, and the sum is resized to the input size.

    Args:
        nclass: number of output channels produced by the scoring layers.
        backbone: backbone name; only ``'vgg16'`` is accepted.
        aux: when true, an extra ``_FCNHead`` is created on the second-stage
            features (see the note in ``forward`` about its output).
        pretrained_base: forwarded to ``vgg16(pretrained=...)``.
        norm_layer: normalisation layer constructor handed to ``_FCNHead``.
        **kwargs: accepted and ignored.

    Raises:
        RuntimeError: if ``backbone`` is not ``'vgg16'``.
    """

    def __init__(self, nclass, backbone='vgg16', aux=False, pretrained_base=True, norm_layer=nn.BatchNorm2d, **kwargs):
        super().__init__()
        self.aux = aux
        if backbone != 'vgg16':
            raise RuntimeError('unknown backbone: {}'.format(backbone))
        self.pretrained = vgg16(pretrained=pretrained_base).features
        # The two stages reuse the layer objects of ``self.pretrained``.
        # Presumably index 24 is the first layer after VGG16's 4th max-pool
        # -- confirm against core.models.base_models.vgg.
        self.pool4 = nn.Sequential(*self.pretrained[:24])
        self.pool5 = nn.Sequential(*self.pretrained[24:])
        self.head = _FCNHead(512, nclass, norm_layer)
        self.score_pool4 = nn.Conv2d(512, nclass, 1)
        if aux:
            self.auxlayer = _FCNHead(512, nclass, norm_layer)

        # Names of the submodules added on top of the backbone. Not read in
        # this file -- presumably consumed by external training code (verify).
        added = ['head', 'score_pool4']
        if aux:
            added.append('auxlayer')
        self.exclusive = added

    def forward(self, x):
        """Return the main prediction as a single tensor sized like ``x`` (H x W)."""
        input_hw = x.size()[2:]
        feat4 = self.pool4(x)
        feat5 = self.pool5(feat4)

        coarse_scores = self.head(feat5)
        skip_scores = self.score_pool4(feat4)

        # Bring the coarse scores to the skip resolution before summing.
        fused = F.interpolate(coarse_scores, skip_scores.size()[2:],
                              mode='bilinear', align_corners=True) + skip_scores

        outputs = [F.interpolate(fused, input_hw, mode='bilinear', align_corners=True)]

        if self.aux:
            aux_scores = self.auxlayer(feat5)
            outputs.append(F.interpolate(aux_scores, input_hw, mode='bilinear', align_corners=True))

        # NOTE(review): unlike FCN32s / FCN8s this returns only the main
        # tensor, not a tuple; the auxiliary prediction is computed above
        # but dropped. The __main__ smoke test relies on the tensor return.
        return outputs[0]
|
|
|
|
class FCN8s(nn.Module):
    """FCN-8s style segmentation network on top of VGG16 features.

    The backbone is cut into three stages. Scores from the deepest stage are
    resized and summed with a 1x1 scoring of the middle stage; that sum is
    resized and summed with a 1x1 scoring of the first stage; the result is
    resized to the input size.

    Args:
        nclass: number of output channels produced by the scoring layers.
        backbone: backbone name; only ``'vgg16'`` is accepted.
        aux: when true, an extra ``_FCNHead`` on the deepest features is
            added and its prediction is returned as well.
        pretrained_base: forwarded to ``vgg16(pretrained=...)``.
        norm_layer: normalisation layer constructor handed to ``_FCNHead``.
        **kwargs: accepted and ignored.

    Raises:
        RuntimeError: if ``backbone`` is not ``'vgg16'``.
    """

    def __init__(self, nclass, backbone='vgg16', aux=False, pretrained_base=True, norm_layer=nn.BatchNorm2d, **kwargs):
        super().__init__()
        self.aux = aux
        if backbone != 'vgg16':
            raise RuntimeError('unknown backbone: {}'.format(backbone))
        self.pretrained = vgg16(pretrained=pretrained_base).features
        # The three stages reuse the layer objects of ``self.pretrained``.
        # Presumably the cuts at 17 and 24 fall right after VGG16's 3rd and
        # 4th max-pool -- confirm against core.models.base_models.vgg.
        self.pool3 = nn.Sequential(*self.pretrained[:17])
        self.pool4 = nn.Sequential(*self.pretrained[17:24])
        self.pool5 = nn.Sequential(*self.pretrained[24:])
        self.head = _FCNHead(512, nclass, norm_layer)
        # 1x1 scoring of the skip features (256 and 512 input channels).
        self.score_pool3 = nn.Conv2d(256, nclass, 1)
        self.score_pool4 = nn.Conv2d(512, nclass, 1)
        if aux:
            self.auxlayer = _FCNHead(512, nclass, norm_layer)

        # Names of the submodules added on top of the backbone. Not read in
        # this file -- presumably consumed by external training code (verify).
        added = ['head', 'score_pool3', 'score_pool4']
        if aux:
            added.append('auxlayer')
        self.exclusive = added

    def forward(self, x):
        """Return ``(out,)`` or ``(out, auxout)``, each resized to ``x``'s H x W."""
        input_hw = x.size()[2:]

        def _resize(tensor, hw):
            return F.interpolate(tensor, hw, mode='bilinear', align_corners=True)

        feat3 = self.pool3(x)
        feat4 = self.pool4(feat3)
        feat5 = self.pool5(feat4)

        coarse_scores = self.head(feat5)
        skip4 = self.score_pool4(feat4)
        skip3 = self.score_pool3(feat3)

        # Merge from coarsest to finest, resizing to each skip's resolution.
        fused4 = _resize(coarse_scores, skip4.size()[2:]) + skip4
        fused3 = _resize(fused4, skip3.size()[2:]) + skip3

        results = [_resize(fused3, input_hw)]
        if self.aux:
            results.append(_resize(self.auxlayer(feat5), input_hw))
        return tuple(results)
|
|
|
|
|
|
class _FCNHead(nn.Module):
|
|
def __init__(self, in_channels, channels, norm_layer=nn.BatchNorm2d, **kwargs):
|
|
super(_FCNHead, self).__init__()
|
|
inter_channels = in_channels // 4
|
|
self.block = nn.Sequential(
|
|
nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
|
|
norm_layer(inter_channels),
|
|
nn.ReLU(inplace=True),
|
|
nn.Dropout(0.1),
|
|
nn.Conv2d(inter_channels, channels, 1)
|
|
)
|
|
|
|
def forward(self, x):
|
|
return self.block(x)
|
|
|
|
|
|
def get_fcn32s(dataset='pascal_voc', backbone='vgg16', pretrained=False, root='~/.torch/models',
               pretrained_base=True, **kwargs):
    """Build an ``FCN32s`` model for ``dataset``, optionally loading weights.

    Args:
        dataset: key into the project's ``datasets`` registry; the entry's
            ``NUM_CLASS`` sets the number of output classes.
        backbone: backbone name forwarded to ``FCN32s``.
        pretrained: when true, load a checkpoint located via
            ``get_model_file`` into the model.
        root: directory passed to ``get_model_file``.
        pretrained_base: forwarded to ``FCN32s``.
        **kwargs: forwarded to ``FCN32s``. ``local_rank``, when present, is
            also used as the device checkpoints are mapped to.

    Returns:
        The constructed ``FCN32s`` instance.

    Raises:
        KeyError: if ``dataset`` is missing from the registry, or (only when
            ``pretrained`` is true) has no entry in the acronym table.
    """
    # Dataset name -> suffix used in the checkpoint file name.
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = FCN32s(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        # Previously kwargs['local_rank'] raised KeyError whenever the caller
        # did not pass local_rank; fall back to loading onto the CPU instead.
        local_rank = kwargs.get('local_rank')
        device = torch.device('cpu' if local_rank is None else local_rank)
        weights_path = get_model_file('fcn32s_%s_%s' % (backbone, acronyms[dataset]), root=root)
        model.load_state_dict(torch.load(weights_path, map_location=device))
    return model
|
|
|
|
|
|
def get_fcn16s(dataset='pascal_voc', backbone='vgg16', pretrained=False, root='~/.torch/models',
               pretrained_base=True, **kwargs):
    """Build an ``FCN16s`` model for ``dataset``, optionally loading weights.

    Args:
        dataset: key into the project's ``datasets`` registry; the entry's
            ``NUM_CLASS`` sets the number of output classes.
        backbone: backbone name forwarded to ``FCN16s``.
        pretrained: when true, load a checkpoint located via
            ``get_model_file`` into the model.
        root: directory passed to ``get_model_file``.
        pretrained_base: forwarded to ``FCN16s``.
        **kwargs: forwarded to ``FCN16s``. ``local_rank``, when present, is
            also used as the device checkpoints are mapped to.

    Returns:
        The constructed ``FCN16s`` instance.

    Raises:
        KeyError: if ``dataset`` is missing from the registry, or (only when
            ``pretrained`` is true) has no entry in the acronym table.
    """
    # Dataset name -> suffix used in the checkpoint file name.
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = FCN16s(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        # Previously kwargs['local_rank'] raised KeyError whenever the caller
        # did not pass local_rank; fall back to loading onto the CPU instead.
        local_rank = kwargs.get('local_rank')
        device = torch.device('cpu' if local_rank is None else local_rank)
        weights_path = get_model_file('fcn16s_%s_%s' % (backbone, acronyms[dataset]), root=root)
        model.load_state_dict(torch.load(weights_path, map_location=device))
    return model
|
|
|
|
|
|
def get_fcn8s(dataset='pascal_voc', backbone='vgg16', pretrained=False, root='~/.torch/models',
              pretrained_base=True, **kwargs):
    """Build an ``FCN8s`` model for ``dataset``, optionally loading weights.

    Args:
        dataset: key into the project's ``datasets`` registry; the entry's
            ``NUM_CLASS`` sets the number of output classes.
        backbone: backbone name forwarded to ``FCN8s``.
        pretrained: when true, load a checkpoint located via
            ``get_model_file`` into the model.
        root: directory passed to ``get_model_file``.
        pretrained_base: forwarded to ``FCN8s``.
        **kwargs: forwarded to ``FCN8s``. ``local_rank``, when present, is
            also used as the device checkpoints are mapped to.

    Returns:
        The constructed ``FCN8s`` instance.

    Raises:
        KeyError: if ``dataset`` is missing from the registry, or (only when
            ``pretrained`` is true) has no entry in the acronym table.
    """
    # Dataset name -> suffix used in the checkpoint file name.
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = FCN8s(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        # Previously kwargs['local_rank'] raised KeyError whenever the caller
        # did not pass local_rank; fall back to loading onto the CPU instead.
        local_rank = kwargs.get('local_rank')
        device = torch.device('cpu' if local_rank is None else local_rank)
        weights_path = get_model_file('fcn8s_%s_%s' % (backbone, acronyms[dataset]), root=root)
        model.load_state_dict(torch.load(weights_path, map_location=device))
    return model
|
|
|
|
|
|
def get_fcn32s_vgg16_voc(**kwargs):
    """Shortcut for ``get_fcn32s`` with dataset 'pascal_voc' and backbone 'vgg16'."""
    model = get_fcn32s(dataset='pascal_voc', backbone='vgg16', **kwargs)
    return model
|
|
|
|
|
|
def get_fcn16s_vgg16_voc(**kwargs):
    """Shortcut for ``get_fcn16s`` with dataset 'pascal_voc' and backbone 'vgg16'."""
    model = get_fcn16s(dataset='pascal_voc', backbone='vgg16', **kwargs)
    return model
|
|
|
|
|
|
def get_fcn8s_vgg16_voc(**kwargs):
    """Shortcut for ``get_fcn8s`` with dataset 'pascal_voc' and backbone 'vgg16'."""
    model = get_fcn8s(dataset='pascal_voc', backbone='vgg16', **kwargs)
    return model
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: build FCN16s with 21 output classes, run one random
    # batch through it and report its cost as measured by thop.
    model = FCN16s(21)
    print(model)

    # ``images`` rather than ``input`` so the builtin is not shadowed.
    images = torch.rand(2, 3, 224, 224)

    # FCN16s.forward returns a single tensor, so ``.shape`` is available.
    prediction = model(images)
    print(prediction)
    print(prediction.shape)

    from thop import profile

    # NOTE(review): ``input_size=`` is the keyword of early thop releases;
    # newer releases expect ``inputs=(tensor,)``. Confirm against the
    # installed thop version before relying on this call.
    flop, params = profile(model, input_size=(1, 3, 512, 512))
    print('flops:{:.3f}G\nparams:{:.3f}M'.format(flop / 1e9, params / 1e6))
|