import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F

# Make `core.models.base_models.vgg` importable when this file is used outside its package.
sys.path.extend(['/home/thsw2/WJ/src/yolov5/segutils/', '../..', '..'])
from core.models.base_models.vgg import vgg16

__all__ = ['get_fcn32s', 'get_fcn16s', 'get_fcn8s',
           'get_fcn32s_vgg16_voc', 'get_fcn16s_vgg16_voc', 'get_fcn8s_vgg16_voc']

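# FCN32s scores the stride-32 VGG16 feature map (pool5) and bilinearly upsamples the
# result straight back to the input resolution, so it produces the coarsest
# predictions of the three variants below.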
class FCN32s(nn.Module):
    """There are some differences from the original FCN."""

    def __init__(self, nclass, backbone='vgg16', aux=False, pretrained_base=True,
                 norm_layer=nn.BatchNorm2d, **kwargs):
        super(FCN32s, self).__init__()
        self.aux = aux
        if backbone == 'vgg16':
            self.pretrained = vgg16(pretrained=pretrained_base).features
        else:
            raise RuntimeError('unknown backbone: {}'.format(backbone))
        self.head = _FCNHead(512, nclass, norm_layer)
        if aux:
            self.auxlayer = _FCNHead(512, nclass, norm_layer)

        self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head'])

    def forward(self, x):
        size = x.size()[2:]
        pool5 = self.pretrained(x)

        outputs = []
        out = self.head(pool5)
        out = F.interpolate(out, size, mode='bilinear', align_corners=True)
        outputs.append(out)

        if self.aux:
            auxout = self.auxlayer(pool5)
            auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
            outputs.append(auxout)

        return tuple(outputs)

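# FCN16s adds one skip connection: the pool5 scores are upsampled to pool4 resolution
# and fused with 1x1-conv scores computed from pool4 before the final upsampling.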
class FCN16s(nn.Module):
    def __init__(self, nclass, backbone='vgg16', aux=False, pretrained_base=True,
                 norm_layer=nn.BatchNorm2d, **kwargs):
        super(FCN16s, self).__init__()
        self.aux = aux
        if backbone == 'vgg16':
            self.pretrained = vgg16(pretrained=pretrained_base).features
        else:
            raise RuntimeError('unknown backbone: {}'.format(backbone))
        self.pool4 = nn.Sequential(*self.pretrained[:24])
        self.pool5 = nn.Sequential(*self.pretrained[24:])
        self.head = _FCNHead(512, nclass, norm_layer)
        self.score_pool4 = nn.Conv2d(512, nclass, 1)
        if aux:
            self.auxlayer = _FCNHead(512, nclass, norm_layer)

        self.__setattr__('exclusive', ['head', 'score_pool4', 'auxlayer'] if aux else ['head', 'score_pool4'])

    def forward(self, x):
        pool4 = self.pool4(x)
        pool5 = self.pool5(pool4)

        outputs = []
        score_fr = self.head(pool5)

        score_pool4 = self.score_pool4(pool4)

        upscore2 = F.interpolate(score_fr, score_pool4.size()[2:], mode='bilinear', align_corners=True)
        fuse_pool4 = upscore2 + score_pool4

        out = F.interpolate(fuse_pool4, x.size()[2:], mode='bilinear', align_corners=True)
        outputs.append(out)

        if self.aux:
            auxout = self.auxlayer(pool5)
            auxout = F.interpolate(auxout, x.size()[2:], mode='bilinear', align_corners=True)
            outputs.append(auxout)

        # Unlike FCN32s/FCN8s, only the main prediction is returned (not a tuple).
        # return tuple(outputs)
        return outputs[0]

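# FCN8s adds a second skip connection: the fused pool4 scores are upsampled again and
# combined with 1x1-conv scores from pool3, giving the finest predictions of the three.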
class FCN8s(nn.Module):
    def __init__(self, nclass, backbone='vgg16', aux=False, pretrained_base=True,
                 norm_layer=nn.BatchNorm2d, **kwargs):
        super(FCN8s, self).__init__()
        self.aux = aux
        if backbone == 'vgg16':
            self.pretrained = vgg16(pretrained=pretrained_base).features
        else:
            raise RuntimeError('unknown backbone: {}'.format(backbone))
        self.pool3 = nn.Sequential(*self.pretrained[:17])
        self.pool4 = nn.Sequential(*self.pretrained[17:24])
        self.pool5 = nn.Sequential(*self.pretrained[24:])
        self.head = _FCNHead(512, nclass, norm_layer)
        self.score_pool3 = nn.Conv2d(256, nclass, 1)
        self.score_pool4 = nn.Conv2d(512, nclass, 1)
        if aux:
            self.auxlayer = _FCNHead(512, nclass, norm_layer)

        self.__setattr__('exclusive',
                         ['head', 'score_pool3', 'score_pool4', 'auxlayer'] if aux
                         else ['head', 'score_pool3', 'score_pool4'])

    def forward(self, x):
        pool3 = self.pool3(x)
        pool4 = self.pool4(pool3)
        pool5 = self.pool5(pool4)

        outputs = []
        score_fr = self.head(pool5)

        score_pool4 = self.score_pool4(pool4)
        score_pool3 = self.score_pool3(pool3)

        upscore2 = F.interpolate(score_fr, score_pool4.size()[2:], mode='bilinear', align_corners=True)
        fuse_pool4 = upscore2 + score_pool4

        upscore_pool4 = F.interpolate(fuse_pool4, score_pool3.size()[2:], mode='bilinear', align_corners=True)
        fuse_pool3 = upscore_pool4 + score_pool3

        out = F.interpolate(fuse_pool3, x.size()[2:], mode='bilinear', align_corners=True)
        outputs.append(out)

        if self.aux:
            auxout = self.auxlayer(pool5)
            auxout = F.interpolate(auxout, x.size()[2:], mode='bilinear', align_corners=True)
            outputs.append(auxout)

        return tuple(outputs)

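# Shared prediction head: 3x3 conv -> norm -> ReLU -> Dropout(0.1) -> 1x1 classifier,
# reducing channels to in_channels // 4 before scoring the nclass classes.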
class _FCNHead(nn.Module):
    def __init__(self, in_channels, channels, norm_layer=nn.BatchNorm2d, **kwargs):
        super(_FCNHead, self).__init__()
        inter_channels = in_channels // 4
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            norm_layer(inter_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(inter_channels, channels, 1)
        )

    def forward(self, x):
        return self.block(x)

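# Factory helpers: build a model sized for the chosen dataset and optionally load
# released weights. When pretrained=True, a 'local_rank' entry is expected in kwargs
# so the checkpoint can be mapped onto the right device.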
def get_fcn32s(dataset='pascal_voc', backbone='vgg16', pretrained=False, root='~/.torch/models',
               pretrained_base=True, **kwargs):
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = FCN32s(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        device = torch.device(kwargs['local_rank'])
        model.load_state_dict(torch.load(get_model_file('fcn32s_%s_%s' % (backbone, acronyms[dataset]), root=root),
                                         map_location=device))
    return model


def get_fcn16s(dataset='pascal_voc', backbone='vgg16', pretrained=False, root='~/.torch/models',
               pretrained_base=True, **kwargs):
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = FCN16s(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        device = torch.device(kwargs['local_rank'])
        model.load_state_dict(torch.load(get_model_file('fcn16s_%s_%s' % (backbone, acronyms[dataset]), root=root),
                                         map_location=device))
    return model


def get_fcn8s(dataset='pascal_voc', backbone='vgg16', pretrained=False, root='~/.torch/models',
              pretrained_base=True, **kwargs):
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = FCN8s(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        device = torch.device(kwargs['local_rank'])
        model.load_state_dict(torch.load(get_model_file('fcn8s_%s_%s' % (backbone, acronyms[dataset]), root=root),
                                         map_location=device))
    return model

def get_fcn32s_vgg16_voc(**kwargs):
    return get_fcn32s('pascal_voc', 'vgg16', **kwargs)


def get_fcn16s_vgg16_voc(**kwargs):
    return get_fcn16s('pascal_voc', 'vgg16', **kwargs)


def get_fcn8s_vgg16_voc(**kwargs):
    return get_fcn8s('pascal_voc', 'vgg16', **kwargs)

if __name__ == "__main__":
    # Quick smoke test: forward a dummy batch through FCN16s and report FLOPs/params.
    model = FCN16s(21)
    print(model)
    x = torch.rand(2, 3, 224, 224)
    # target = torch.zeros(4, 512, 512).cuda()
    # model.eval()
    out = model(x)  # FCN16s returns a single tensor (see forward above)
    print(out)
    print(out.shape)

    from thop import profile
    flops, params = profile(model, inputs=(torch.rand(1, 3, 512, 512),))
    print('flops:{:.3f}G\nparams:{:.3f}M'.format(flops / 1e9, params / 1e6))
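
# Usage sketch (not run here): building a model through the factory helpers instead
# of the class constructors. This assumes the file sits inside its original package
# so that the relative `..data.dataloader` import resolves; pretrained_base=False
# skips the ImageNet VGG16 download.
#
#   model = get_fcn8s_vgg16_voc(pretrained_base=False)
#   out = model(torch.rand(1, 3, 320, 320))[0]   # VOC logits of shape (1, 21, 320, 320)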