# AIlib2/segutils/core/models/fcn.py
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.extend(['/home/thsw2/WJ/src/yolov5/segutils/','../..','..' ])
from core.models.base_models.vgg import vgg16
# Public factory API of this module.
__all__ = [
    'get_fcn32s', 'get_fcn16s', 'get_fcn8s',
    'get_fcn32s_vgg16_voc', 'get_fcn16s_vgg16_voc', 'get_fcn8s_vgg16_voc',
]
class FCN32s(nn.Module):
    """FCN head on top of the full VGG16 feature stack (stride-32 output).

    There are some differences from the original FCN paper: upsampling is
    done with bilinear interpolation instead of learned deconvolution.
    """

    def __init__(self, nclass, backbone='vgg16', aux=False, pretrained_base=True,
                 norm_layer=nn.BatchNorm2d, **kwargs):
        super(FCN32s, self).__init__()
        self.aux = aux
        if backbone != 'vgg16':
            raise RuntimeError('unknown backbone: {}'.format(backbone))
        self.pretrained = vgg16(pretrained=pretrained_base).features
        self.head = _FCNHead(512, nclass, norm_layer)
        if aux:
            self.auxlayer = _FCNHead(512, nclass, norm_layer)
        # 'exclusive' names the freshly initialized (non-backbone) submodules,
        # presumably so an optimizer can give them a separate learning rate —
        # TODO confirm against the training code.
        self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head'])

    def forward(self, x):
        full_size = x.size()[2:]
        pool5 = self.pretrained(x)
        results = [
            F.interpolate(self.head(pool5), full_size, mode='bilinear', align_corners=True)
        ]
        if self.aux:
            aux_score = self.auxlayer(pool5)
            results.append(
                F.interpolate(aux_score, full_size, mode='bilinear', align_corners=True)
            )
        # Main prediction first, auxiliary prediction (if any) second.
        return tuple(results)
class FCN16s(nn.Module):
    """FCN with one skip connection from VGG16's pool4 (stride-16 fusion)."""

    def __init__(self, nclass, backbone='vgg16', aux=False, pretrained_base=True,
                 norm_layer=nn.BatchNorm2d, **kwargs):
        super(FCN16s, self).__init__()
        self.aux = aux
        if backbone != 'vgg16':
            raise RuntimeError('unknown backbone: {}'.format(backbone))
        self.pretrained = vgg16(pretrained=pretrained_base).features
        # Split the VGG feature stack at layer 24 so the pool4 activation
        # can be tapped for the skip connection.
        self.pool4 = nn.Sequential(*self.pretrained[:24])
        self.pool5 = nn.Sequential(*self.pretrained[24:])
        self.head = _FCNHead(512, nclass, norm_layer)
        self.score_pool4 = nn.Conv2d(512, nclass, 1)
        if aux:
            self.auxlayer = _FCNHead(512, nclass, norm_layer)
        self.__setattr__('exclusive',
                         ['head', 'score_pool4', 'auxlayer'] if aux else ['head', 'score_pool4'])

    def forward(self, x):
        full_size = x.size()[2:]
        pool4 = self.pool4(x)
        pool5 = self.pool5(pool4)
        coarse = self.head(pool5)
        skip = self.score_pool4(pool4)
        # Upsample the coarse score map to pool4 resolution and fuse.
        fused = skip + F.interpolate(coarse, skip.size()[2:], mode='bilinear', align_corners=True)
        out = F.interpolate(fused, full_size, mode='bilinear', align_corners=True)
        if self.aux:
            # Computed but discarded below — kept so the original's behavior
            # (including dropout RNG consumption in training) is unchanged.
            auxout = F.interpolate(self.auxlayer(pool5), full_size,
                                   mode='bilinear', align_corners=True)
        # NOTE(review): unlike FCN32s/FCN8s this forward returns a single
        # tensor; the tuple return was deliberately disabled upstream and the
        # __main__ demo relies on it (calls .shape on the result).
        return out
class FCN8s(nn.Module):
    """FCN with skip connections from VGG16's pool3 and pool4 (stride-8 fusion)."""

    def __init__(self, nclass, backbone='vgg16', aux=False, pretrained_base=True,
                 norm_layer=nn.BatchNorm2d, **kwargs):
        super(FCN8s, self).__init__()
        self.aux = aux
        if backbone != 'vgg16':
            raise RuntimeError('unknown backbone: {}'.format(backbone))
        self.pretrained = vgg16(pretrained=pretrained_base).features
        # Split the VGG feature stack so pool3 and pool4 activations can be
        # tapped for the two skip connections.
        self.pool3 = nn.Sequential(*self.pretrained[:17])
        self.pool4 = nn.Sequential(*self.pretrained[17:24])
        self.pool5 = nn.Sequential(*self.pretrained[24:])
        self.head = _FCNHead(512, nclass, norm_layer)
        self.score_pool3 = nn.Conv2d(256, nclass, 1)
        self.score_pool4 = nn.Conv2d(512, nclass, 1)
        if aux:
            self.auxlayer = _FCNHead(512, nclass, norm_layer)
        self.__setattr__('exclusive',
                         ['head', 'score_pool3', 'score_pool4', 'auxlayer'] if aux
                         else ['head', 'score_pool3', 'score_pool4'])

    def forward(self, x):
        full_size = x.size()[2:]
        pool3 = self.pool3(x)
        pool4 = self.pool4(pool3)
        pool5 = self.pool5(pool4)
        coarse = self.head(pool5)
        skip4 = self.score_pool4(pool4)
        skip3 = self.score_pool3(pool3)
        # Fuse at pool4 resolution, then again at pool3 resolution.
        fused4 = skip4 + F.interpolate(coarse, skip4.size()[2:], mode='bilinear', align_corners=True)
        fused3 = skip3 + F.interpolate(fused4, skip3.size()[2:], mode='bilinear', align_corners=True)
        results = [
            F.interpolate(fused3, full_size, mode='bilinear', align_corners=True)
        ]
        if self.aux:
            aux_score = self.auxlayer(pool5)
            results.append(
                F.interpolate(aux_score, full_size, mode='bilinear', align_corners=True)
            )
        return tuple(results)
class _FCNHead(nn.Module):
def __init__(self, in_channels, channels, norm_layer=nn.BatchNorm2d, **kwargs):
super(_FCNHead, self).__init__()
inter_channels = in_channels // 4
self.block = nn.Sequential(
nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
norm_layer(inter_channels),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
nn.Conv2d(inter_channels, channels, 1)
)
def forward(self, x):
return self.block(x)
def get_fcn32s(dataset='pascal_voc', backbone='vgg16', pretrained=False, root='~/.torch/models',
               pretrained_base=True, **kwargs):
    """Build an FCN32s for a named dataset, optionally loading trained weights.

    Args:
        dataset: key into the project's dataset registry; fixes the class count.
        backbone: backbone name; FCN32s only supports 'vgg16'.
        pretrained: if True, load segmentation weights found under ``root``.
        root: directory searched by ``get_model_file`` for checkpoints.
        pretrained_base: whether the VGG backbone starts from ImageNet weights.
        **kwargs: forwarded to ``FCN32s``; may include ``local_rank``, used as
            the ``map_location`` device when loading weights.

    Returns:
        An ``FCN32s`` instance.
    """
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = FCN32s(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        # Fall back to CPU when no local_rank is supplied; the original indexed
        # kwargs['local_rank'] and raised KeyError for plain pretrained loads.
        device = torch.device(kwargs.get('local_rank', 'cpu'))
        model.load_state_dict(torch.load(get_model_file('fcn32s_%s_%s' % (backbone, acronyms[dataset]), root=root),
                                         map_location=device))
    return model
def get_fcn16s(dataset='pascal_voc', backbone='vgg16', pretrained=False, root='~/.torch/models',
               pretrained_base=True, **kwargs):
    """Build an FCN16s for a named dataset, optionally loading trained weights.

    Args:
        dataset: key into the project's dataset registry; fixes the class count.
        backbone: backbone name; FCN16s only supports 'vgg16'.
        pretrained: if True, load segmentation weights found under ``root``.
        root: directory searched by ``get_model_file`` for checkpoints.
        pretrained_base: whether the VGG backbone starts from ImageNet weights.
        **kwargs: forwarded to ``FCN16s``; may include ``local_rank``, used as
            the ``map_location`` device when loading weights.

    Returns:
        An ``FCN16s`` instance.
    """
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = FCN16s(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        # Fall back to CPU when no local_rank is supplied; the original indexed
        # kwargs['local_rank'] and raised KeyError for plain pretrained loads.
        device = torch.device(kwargs.get('local_rank', 'cpu'))
        model.load_state_dict(torch.load(get_model_file('fcn16s_%s_%s' % (backbone, acronyms[dataset]), root=root),
                                         map_location=device))
    return model
def get_fcn8s(dataset='pascal_voc', backbone='vgg16', pretrained=False, root='~/.torch/models',
              pretrained_base=True, **kwargs):
    """Build an FCN8s for a named dataset, optionally loading trained weights.

    Args:
        dataset: key into the project's dataset registry; fixes the class count.
        backbone: backbone name; FCN8s only supports 'vgg16'.
        pretrained: if True, load segmentation weights found under ``root``.
        root: directory searched by ``get_model_file`` for checkpoints.
        pretrained_base: whether the VGG backbone starts from ImageNet weights.
        **kwargs: forwarded to ``FCN8s``; may include ``local_rank``, used as
            the ``map_location`` device when loading weights.

    Returns:
        An ``FCN8s`` instance.
    """
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = FCN8s(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        # Fall back to CPU when no local_rank is supplied; the original indexed
        # kwargs['local_rank'] and raised KeyError for plain pretrained loads.
        device = torch.device(kwargs.get('local_rank', 'cpu'))
        model.load_state_dict(torch.load(get_model_file('fcn8s_%s_%s' % (backbone, acronyms[dataset]), root=root),
                                         map_location=device))
    return model
def get_fcn32s_vgg16_voc(**kwargs):
    """FCN32s with a VGG16 backbone configured for Pascal VOC."""
    return get_fcn32s(dataset='pascal_voc', backbone='vgg16', **kwargs)
def get_fcn16s_vgg16_voc(**kwargs):
    """FCN16s with a VGG16 backbone configured for Pascal VOC."""
    return get_fcn16s(dataset='pascal_voc', backbone='vgg16', **kwargs)
def get_fcn8s_vgg16_voc(**kwargs):
    """FCN8s with a VGG16 backbone configured for Pascal VOC."""
    return get_fcn8s(dataset='pascal_voc', backbone='vgg16', **kwargs)
if __name__ == "__main__":
    # Smoke test: build an FCN16s, run a forward pass, and report FLOPs/params.
    model = FCN16s(21)
    print(model)
    sample = torch.rand(2, 3, 224, 224)
    out = model(sample)  # FCN16s returns a single tensor
    print(out)
    print(out.shape)
    from thop import profile
    # thop.profile takes the concrete input tensors via ``inputs=``; the
    # original passed a nonexistent ``input_size`` kwarg and raised TypeError.
    flop, params = profile(model, inputs=(torch.rand(1, 3, 512, 512),))
    print('flops:{:.3f}G\nparams:{:.3f}M'.format(flop / 1e9, params / 1e6))