您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

65 行
2.7KB

  1. """Defines the detector network structure."""
  2. import torch
  3. from torch import nn
  4. from DMPRUtils.model.network import define_halve_unit, define_detector_block
  5. class YetAnotherDarknet(nn.modules.Module):
  6. """Yet another darknet, imitating darknet-53 with depth of darknet-19."""
  7. def __init__(self, input_channel_size, depth_factor):
  8. super(YetAnotherDarknet, self).__init__()
  9. layers = []
  10. # 0
  11. layers += [nn.Conv2d(input_channel_size, depth_factor, kernel_size=3,
  12. stride=1, padding=1, bias=False)]
  13. layers += [nn.BatchNorm2d(depth_factor)]
  14. layers += [nn.LeakyReLU(0.1)]
  15. # 1
  16. layers += define_halve_unit(depth_factor)
  17. layers += define_detector_block(depth_factor)
  18. # 2
  19. depth_factor *= 2
  20. layers += define_halve_unit(depth_factor)
  21. layers += define_detector_block(depth_factor)
  22. # 3
  23. depth_factor *= 2
  24. layers += define_halve_unit(depth_factor)
  25. layers += define_detector_block(depth_factor)
  26. layers += define_detector_block(depth_factor)
  27. # 4
  28. depth_factor *= 2
  29. layers += define_halve_unit(depth_factor)
  30. layers += define_detector_block(depth_factor)
  31. layers += define_detector_block(depth_factor)
  32. # 5
  33. depth_factor *= 2
  34. layers += define_halve_unit(depth_factor)
  35. layers += define_detector_block(depth_factor)
  36. self.model = nn.Sequential(*layers)
  37. def forward(self, *x):
  38. return self.model(x[0])
  39. class DirectionalPointDetector(nn.modules.Module):
  40. """Detector for point with direction."""
  41. def __init__(self, input_channel_size, depth_factor, output_channel_size):
  42. super(DirectionalPointDetector, self).__init__()
  43. self.extract_feature = YetAnotherDarknet(input_channel_size,
  44. depth_factor)
  45. layers = []
  46. layers += define_detector_block(16 * depth_factor)
  47. layers += define_detector_block(16 * depth_factor)
  48. layers += [nn.Conv2d(32 * depth_factor, output_channel_size,
  49. kernel_size=1, stride=1, padding=0, bias=False)]
  50. self.predict = nn.Sequential(*layers)
  51. def forward(self, *x):
  52. prediction = self.predict(self.extract_feature(x[0]))
  53. # 4 represents that there are 4 value: confidence, shape, offset_x,
  54. # offset_y, whose range is between [0, 1].
  55. point_pred, angle_pred = torch.split(prediction, 4, dim=1)
  56. point_pred = torch.sigmoid(point_pred)
  57. angle_pred = torch.tanh(angle_pred)
  58. return torch.cat((point_pred, angle_pred), dim=1)