用Kafka接收消息
Ви не можете вибрати більше 25 тем. Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

61 lines
2.0KB

  1. """Base Model for Semantic Segmentation"""
  2. import torch.nn as nn
  3. from ..nn import JPU
  4. from .base_models.resnetv1b import resnet50_v1s, resnet101_v1s, resnet152_v1s
  5. __all__ = ['SegBaseModel']
  6. class SegBaseModel(nn.Module):
  7. r"""Base Model for Semantic Segmentation
  8. Parameters
  9. ----------
  10. backbone : string
  11. Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
  12. 'resnet101' or 'resnet152').
  13. """
  14. def __init__(self, nclass, aux, backbone='resnet50', jpu=False, pretrained_base=True, **kwargs):
  15. super(SegBaseModel, self).__init__()
  16. dilated = False if jpu else True
  17. self.aux = aux
  18. self.nclass = nclass
  19. if backbone == 'resnet50':
  20. self.pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
  21. elif backbone == 'resnet101':
  22. self.pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
  23. elif backbone == 'resnet152':
  24. self.pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
  25. else:
  26. raise RuntimeError('unknown backbone: {}'.format(backbone))
  27. self.jpu = JPU([512, 1024, 2048], width=512, **kwargs) if jpu else None
  28. def base_forward(self, x):
  29. """forwarding pre-trained network"""
  30. x = self.pretrained.conv1(x)
  31. x = self.pretrained.bn1(x)
  32. x = self.pretrained.relu(x)
  33. x = self.pretrained.maxpool(x)
  34. c1 = self.pretrained.layer1(x)
  35. c2 = self.pretrained.layer2(c1)
  36. c3 = self.pretrained.layer3(c2)
  37. c4 = self.pretrained.layer4(c3)
  38. if self.jpu:
  39. return self.jpu(c1, c2, c3, c4)
  40. else:
  41. return c1, c2, c3, c4 #返回的是layer1,2,3,4的输出
  42. def evaluate(self, x):
  43. """evaluating network with inputs and targets"""
  44. return self.forward(x)[0]
  45. def demo(self, x):
  46. pred = self.forward(x)
  47. if self.aux:
  48. pred = pred[0]
  49. return pred