  1. """Joint Pyramid Upsampling"""
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. __all__ = ['JPU']
  6. class SeparableConv2d(nn.Module):
  7. def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1,
  8. dilation=1, bias=False, norm_layer=nn.BatchNorm2d):
  9. super(SeparableConv2d, self).__init__()
  10. self.conv = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation, groups=inplanes, bias=bias)
  11. self.bn = norm_layer(inplanes)
  12. self.pointwise = nn.Conv2d(inplanes, planes, 1, bias=bias)
  13. def forward(self, x):
  14. x = self.conv(x)
  15. x = self.bn(x)
  16. x = self.pointwise(x)
  17. return x
  18. # copy from: https://github.com/wuhuikai/FastFCN/blob/master/encoding/nn/customize.py
  19. class JPU(nn.Module):
  20. def __init__(self, in_channels, width=512, norm_layer=nn.BatchNorm2d, **kwargs):
  21. super(JPU, self).__init__()
  22. self.conv5 = nn.Sequential(
  23. nn.Conv2d(in_channels[-1], width, 3, padding=1, bias=False),
  24. norm_layer(width),
  25. nn.ReLU(True))
  26. self.conv4 = nn.Sequential(
  27. nn.Conv2d(in_channels[-2], width, 3, padding=1, bias=False),
  28. norm_layer(width),
  29. nn.ReLU(True))
  30. self.conv3 = nn.Sequential(
  31. nn.Conv2d(in_channels[-3], width, 3, padding=1, bias=False),
  32. norm_layer(width),
  33. nn.ReLU(True))
  34. self.dilation1 = nn.Sequential(
  35. SeparableConv2d(3 * width, width, 3, padding=1, dilation=1, bias=False),
  36. norm_layer(width),
  37. nn.ReLU(True))
  38. self.dilation2 = nn.Sequential(
  39. SeparableConv2d(3 * width, width, 3, padding=2, dilation=2, bias=False),
  40. norm_layer(width),
  41. nn.ReLU(True))
  42. self.dilation3 = nn.Sequential(
  43. SeparableConv2d(3 * width, width, 3, padding=4, dilation=4, bias=False),
  44. norm_layer(width),
  45. nn.ReLU(True))
  46. self.dilation4 = nn.Sequential(
  47. SeparableConv2d(3 * width, width, 3, padding=8, dilation=8, bias=False),
  48. norm_layer(width),
  49. nn.ReLU(True))
  50. def forward(self, *inputs):
  51. feats = [self.conv5(inputs[-1]), self.conv4(inputs[-2]), self.conv3(inputs[-3])]
  52. size = feats[-1].size()[2:]
  53. feats[-2] = F.interpolate(feats[-2], size, mode='bilinear', align_corners=True)
  54. feats[-3] = F.interpolate(feats[-3], size, mode='bilinear', align_corners=True)
  55. feat = torch.cat(feats, dim=1)
  56. feat = torch.cat([self.dilation1(feat), self.dilation2(feat), self.dilation3(feat), self.dilation4(feat)],
  57. dim=1)
  58. return inputs[0], inputs[1], inputs[2], feat
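
For reference, below is a minimal usage sketch that is not part of the original file. The channel counts and spatial sizes are assumptions matching a ResNet-50-style encoder whose c2/c3/c4/c5 features have 256/512/1024/2048 channels at strides 4/8/16/32; only the three deepest features are fed through JPU's projection branches, while the earlier inputs are returned unchanged.

    # Usage sketch (assumed shapes, not from the original file).
    jpu = JPU(in_channels=(512, 1024, 2048), width=512)

    c2 = torch.randn(1, 256, 120, 120)   # stride-4 feature, passed through unchanged
    c3 = torch.randn(1, 512, 60, 60)     # stride-8 feature
    c4 = torch.randn(1, 1024, 30, 30)    # stride-16 feature
    c5 = torch.randn(1, 2048, 15, 15)    # stride-32 feature

    c2_out, c3_out, c4_out, feat = jpu(c2, c3, c4, c5)
    print(feat.shape)  # torch.Size([1, 2048, 60, 60]): 4 * width channels at stride 8

The fused output carries 4 * width channels at the stride-8 resolution, which is what lets FastFCN feed a standard strided backbone into heads that would otherwise require dilated convolutions in the backbone itself.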