"""Base Model for Semantic Segmentation""" import torch.nn as nn from ..nn import JPU from .base_models.resnetv1b import resnet50_v1s, resnet101_v1s, resnet152_v1s __all__ = ['SegBaseModel'] class SegBaseModel(nn.Module): r"""Base Model for Semantic Segmentation Parameters ---------- backbone : string Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 'resnet101' or 'resnet152'). """ def __init__(self, nclass, aux, backbone='resnet50', jpu=False, pretrained_base=True, **kwargs): super(SegBaseModel, self).__init__() dilated = False if jpu else True self.aux = aux self.nclass = nclass if backbone == 'resnet50': self.pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) elif backbone == 'resnet101': self.pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) elif backbone == 'resnet152': self.pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) else: raise RuntimeError('unknown backbone: {}'.format(backbone)) self.jpu = JPU([512, 1024, 2048], width=512, **kwargs) if jpu else None def base_forward(self, x): """forwarding pre-trained network""" x = self.pretrained.conv1(x) x = self.pretrained.bn1(x) x = self.pretrained.relu(x) x = self.pretrained.maxpool(x) c1 = self.pretrained.layer1(x) c2 = self.pretrained.layer2(c1) c3 = self.pretrained.layer3(c2) c4 = self.pretrained.layer4(c3) if self.jpu: return self.jpu(c1, c2, c3, c4) else: return c1, c2, c3, c4 #返回的是layer1,2,3,4的输出 def evaluate(self, x): """evaluating network with inputs and targets""" return self.forward(x)[0] def demo(self, x): pred = self.forward(x) if self.aux: pred = pred[0] return pred