61 lines
2.0 KiB
Python
61 lines
2.0 KiB
Python
"""Base Model for Semantic Segmentation"""
|
|
import torch.nn as nn
|
|
|
|
from ..nn import JPU
|
|
from .base_models.resnetv1b import resnet50_v1s, resnet101_v1s, resnet152_v1s
|
|
|
|
__all__ = ['SegBaseModel']
|
|
|
|
|
|
class SegBaseModel(nn.Module):
    r"""Base Model for Semantic Segmentation

    Builds a ResNet-v1s backbone (stored as ``self.pretrained``) and,
    optionally, a JPU module on top of its stage outputs. Subclasses are
    expected to define ``forward``; ``evaluate`` and ``demo`` call it.

    Parameters
    ----------
    nclass : int
        Number of classes. Only stored as ``self.nclass`` for subclasses;
        this base class does not read it.
    aux : bool
        Stored as ``self.aux``. When truthy, ``demo`` returns only element 0
        of ``forward``'s output.
    backbone : string
        Pre-trained backbone network type (default:'resnet50'; 'resnet50',
        'resnet101' or 'resnet152').
    jpu : bool
        If True, the backbone is built with ``dilated=False`` and a JPU
        module is attached; otherwise the backbone is built with
        ``dilated=True`` and no JPU is created (default: False).
    pretrained_base : bool
        Forwarded as ``pretrained`` to the backbone constructor
        (default: True).
    **kwargs
        Forwarded to the backbone constructor and, when ``jpu`` is True,
        to ``JPU`` as well.

    Raises
    ------
    RuntimeError
        If ``backbone`` is not one of the supported names.
    """

    def __init__(self, nclass, aux, backbone='resnet50', jpu=False, pretrained_base=True, **kwargs):
        super(SegBaseModel, self).__init__()
        # The backbone is dilated exactly when no JPU is attached.
        dilated = not jpu
        self.aux = aux
        self.nclass = nclass

        backbones = {
            'resnet50': resnet50_v1s,
            'resnet101': resnet101_v1s,
            'resnet152': resnet152_v1s,
        }
        if backbone not in backbones:
            raise RuntimeError('unknown backbone: {}'.format(backbone))
        self.pretrained = backbones[backbone](pretrained=pretrained_base, dilated=dilated, **kwargs)

        # NOTE(review): [512, 1024, 2048] presumably matches the channel counts
        # of layer2..layer4 of these ResNet variants; confirm against JPU.
        self.jpu = JPU([512, 1024, 2048], width=512, **kwargs) if jpu else None

    def base_forward(self, x):
        """Run ``x`` through the backbone stem and its four residual stages.

        Returns ``(c1, c2, c3, c4)``, the outputs of ``layer1``..``layer4``,
        when no JPU is attached; otherwise returns whatever
        ``self.jpu(c1, c2, c3, c4)`` returns.
        """
        backbone = self.pretrained
        x = backbone.conv1(x)
        x = backbone.bn1(x)
        x = backbone.relu(x)
        x = backbone.maxpool(x)
        c1 = backbone.layer1(x)
        c2 = backbone.layer2(c1)
        c3 = backbone.layer3(c2)
        c4 = backbone.layer4(c3)

        if self.jpu is not None:
            return self.jpu(c1, c2, c3, c4)
        # Outputs of layer1, layer2, layer3 and layer4.
        return c1, c2, c3, c4

    def evaluate(self, x):
        """Run ``forward`` on ``x`` and return element 0 of its output."""
        return self.forward(x)[0]

    def demo(self, x):
        """Run ``forward`` on ``x`` for inference.

        When ``self.aux`` is truthy, only element 0 of ``forward``'s output
        is returned; otherwise the output is returned unchanged.
        """
        pred = self.forward(x)
        if self.aux:
            pred = pred[0]
        return pred
|