You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

207 lines
7.7KB

  1. import torch.nn.functional as F
  2. import torch
  class DecDecoder_test(object):
      """Decode detection-head outputs into a dense ``(batch, K, 12)`` tensor.

      ``ctdet_decode`` keeps the ``K`` strongest heat-map peaks per image and
      returns them as a torch tensor without any confidence filtering.
      ``conf_thresh`` and ``num_classes`` are stored but never read by this
      class.

      Args:
          K: Number of peaks kept per image.
          conf_thresh: Stored on the instance; unused here.
          num_classes: Stored on the instance; unused here.
      """

      # ``cls_theta`` values strictly above this select the eight per-corner
      # offsets (wh channels 0-7); otherwise the corners are derived from
      # wh channels 8 and 9 (presumably box width and height -- confirm
      # against the training target encoding).
      THETA_THRESH = 0.8
      # Channel count the gathered ``wh`` map is reshaped to.
      WH_CHANNELS = 10

      def __init__(self, K, conf_thresh, num_classes):
          self.K = K
          self.conf_thresh = conf_thresh
          self.num_classes = num_classes

      def _topk(self, scores):
          """Pick the ``K`` highest-scoring locations over all classes.

          Args:
              scores: 4-D tensor unpacked as ``(batch, cat, height, width)``.

          Returns:
              Tuple ``(score, inds, clses, ys, xs)``; each entry has shape
              ``(batch, K)``. ``inds`` are flat ``y * width + x`` positions,
              ``clses`` is an int tensor, ``ys``/``xs`` are floats.
          """
          batch, cat, height, width = scores.size()

          # Stage 1: best K locations inside every class plane. The returned
          # indices already lie in [0, height * width), so no wrap-around
          # modulo is required.
          class_scores, class_inds = torch.topk(scores.view(batch, cat, -1), self.K)
          class_ys = (class_inds // width).int().float()
          class_xs = (class_inds % width).int().float()

          # Stage 2: best K among the cat * K candidates. A candidate's position
          # in the flattened (cat, K) list, floor-divided by K, is its class.
          topk_score, topk_ind = torch.topk(class_scores.view(batch, -1), self.K)
          topk_clses = (topk_ind // self.K).int()

          topk_inds = self._gather_feat(class_inds.view(batch, -1, 1), topk_ind).view(batch, self.K)
          topk_ys = self._gather_feat(class_ys.view(batch, -1, 1), topk_ind).view(batch, self.K)
          topk_xs = self._gather_feat(class_xs.view(batch, -1, 1), topk_ind).view(batch, self.K)
          return topk_score, topk_inds, topk_clses, topk_ys, topk_xs

      def _nms(self, heat, kernel=3):
          """Zero every heat-map value that is not the maximum of its window.

          Args:
              heat: Tensor accepted by ``F.max_pool2d``.
              kernel: Side length of the square pooling window.

          Returns:
              Tensor shaped like ``heat`` with non-maximum entries set to 0.
          """
          pad = (kernel - 1) // 2
          local_max = F.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad)
          is_peak = (local_max == heat).float()
          return heat * is_peak

      def _gather_feat(self, feat, ind, mask=None):
          """Select rows of ``feat`` along dim 1 using ``ind``.

          Args:
              feat: Tensor of shape ``(batch, N, dim)``.
              ind: Integer tensor of shape ``(batch, M)`` indexing dim 1.
              mask: Optional boolean tensor of shape ``(batch, M)``. When given,
                  only the selected rows are kept and the result is flattened
                  to ``(num_selected, dim)``, matching ``DecDecoder``.

          Returns:
              ``(batch, M, dim)`` tensor, or ``(num_selected, dim)`` with a mask.
          """
          dim = feat.size(2)
          ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
          feat = feat.gather(1, ind)
          if mask is not None:
              mask = mask.unsqueeze(2).expand_as(feat)
              feat = feat[mask]
              feat = feat.view(-1, dim)
          return feat

      def _tranpose_and_gather_feat(self, feat, ind):
          """Gather per-location channel vectors from a ``(B, C, H, W)`` map.

          Args:
              feat: 4-D tensor; channels are moved last and the two spatial
                  dims are flattened before gathering.
              ind: ``(batch, M)`` flat spatial indices.

          Returns:
              Tensor of shape ``(batch, M, C)``.
          """
          feat = feat.permute(0, 2, 3, 1).contiguous()
          feat = feat.view(feat.size(0), -1, feat.size(3))
          return self._gather_feat(feat, ind)

      def ctdet_decode(self, pr_decs):
          """Decode the head outputs into ``(batch, K, 12)`` detections.

          Args:
              pr_decs: Mapping with tensors under ``'hm'``, ``'wh'``, ``'reg'``
                  and ``'cls_theta'``. After gathering they are reshaped to 2,
                  ``WH_CHANNELS`` and 1 channels for ``reg``, ``wh`` and
                  ``cls_theta`` respectively.

          Returns:
              Tensor whose last dim is ordered ``cen_x, cen_y, tt_x, tt_y,
              rr_x, rr_y, bb_x, bb_y, ll_x, ll_y, score, class``.
          """
          heat = pr_decs['hm']
          wh = pr_decs['wh']
          reg = pr_decs['reg']
          cls_theta = pr_decs['cls_theta']

          batch = heat.size(0)
          heat = self._nms(heat)
          scores, inds, clses, ys, xs = self._topk(heat)

          # Refine the integer peak coordinates with the regressed offsets.
          reg = self._tranpose_and_gather_feat(reg, inds).view(batch, self.K, 2)
          xs = xs.view(batch, self.K, 1) + reg[:, :, 0:1]
          ys = ys.view(batch, self.K, 1) + reg[:, :, 1:2]
          clses = clses.view(batch, self.K, 1).float()
          scores = scores.view(batch, self.K, 1)

          wh = self._tranpose_and_gather_feat(wh, inds).view(batch, self.K, self.WH_CHANNELS)
          cls_theta = self._tranpose_and_gather_feat(cls_theta, inds).view(batch, self.K, 1)

          # 1.0 where the per-corner offsets are used, 0.0 where the corners
          # come from the centre and the half extents. The arithmetic blend
          # (rather than torch.where) is kept so results match bit-for-bit.
          use_corners = (cls_theta > self.THETA_THRESH).float().view(batch, self.K, 1)
          use_box = 1. - use_corners
          half_w = wh[..., 8:9] / 2
          half_h = wh[..., 9:10] / 2

          tt_x = (xs + wh[..., 0:1]) * use_corners + xs * use_box
          tt_y = (ys + wh[..., 1:2]) * use_corners + (ys - half_h) * use_box
          rr_x = (xs + wh[..., 2:3]) * use_corners + (xs + half_w) * use_box
          rr_y = (ys + wh[..., 3:4]) * use_corners + ys * use_box
          bb_x = (xs + wh[..., 4:5]) * use_corners + xs * use_box
          bb_y = (ys + wh[..., 5:6]) * use_corners + (ys + half_h) * use_box
          ll_x = (xs + wh[..., 6:7]) * use_corners + (xs - half_w) * use_box
          ll_y = (ys + wh[..., 7:8]) * use_corners + ys * use_box

          detections = torch.cat(
              [
                  xs,  # cen_x
                  ys,  # cen_y
                  tt_x,
                  tt_y,
                  rr_x,
                  rr_y,
                  bb_x,
                  bb_y,
                  ll_x,
                  ll_y,
                  scores,
                  clses,
              ],
              dim=2,
          )
          return detections
  class DecDecoder(object):
      """Decode detection-head outputs and keep confident detections.

      ``ctdet_decode`` keeps the ``K`` strongest heat-map peaks of a single
      image, drops those whose score is not above ``conf_thresh`` and returns
      the remainder as a NumPy array. ``num_classes`` is stored but never read
      by this class.

      Args:
          K: Number of peaks considered per image.
          conf_thresh: Detections are kept only when ``score > conf_thresh``.
          num_classes: Stored on the instance; unused here.
      """

      # ``cls_theta`` values strictly above this select the eight per-corner
      # offsets (wh channels 0-7); otherwise the corners are derived from
      # wh channels 8 and 9 (presumably box width and height -- confirm
      # against the training target encoding).
      THETA_THRESH = 0.8
      # Channel count the gathered ``wh`` map is reshaped to.
      WH_CHANNELS = 10
      # Position of the score inside the 12-value detection vector.
      SCORE_COLUMN = 10

      def __init__(self, K, conf_thresh, num_classes):
          self.K = K
          self.conf_thresh = conf_thresh
          self.num_classes = num_classes

      def _topk(self, scores):
          """Pick the ``K`` highest-scoring locations over all classes.

          Args:
              scores: 4-D tensor unpacked as ``(batch, cat, height, width)``.

          Returns:
              Tuple ``(score, inds, clses, ys, xs)``; each entry has shape
              ``(batch, K)``. ``inds`` are flat ``y * width + x`` positions,
              ``clses`` is an int tensor, ``ys``/``xs`` are floats.
          """
          batch, cat, height, width = scores.size()

          # Stage 1: best K locations inside every class plane. The returned
          # indices already lie in [0, height * width), so no wrap-around
          # modulo is required.
          class_scores, class_inds = torch.topk(scores.view(batch, cat, -1), self.K)
          class_ys = (class_inds // width).int().float()
          class_xs = (class_inds % width).int().float()

          # Stage 2: best K among the cat * K candidates. A candidate's position
          # in the flattened (cat, K) list, floor-divided by K, is its class.
          topk_score, topk_ind = torch.topk(class_scores.view(batch, -1), self.K)
          topk_clses = (topk_ind // self.K).int()

          topk_inds = self._gather_feat(class_inds.view(batch, -1, 1), topk_ind).view(batch, self.K)
          topk_ys = self._gather_feat(class_ys.view(batch, -1, 1), topk_ind).view(batch, self.K)
          topk_xs = self._gather_feat(class_xs.view(batch, -1, 1), topk_ind).view(batch, self.K)
          return topk_score, topk_inds, topk_clses, topk_ys, topk_xs

      def _nms(self, heat, kernel=3):
          """Zero every heat-map value that is not the maximum of its window.

          Args:
              heat: Tensor accepted by ``F.max_pool2d``.
              kernel: Side length of the square pooling window.

          Returns:
              Tensor shaped like ``heat`` with non-maximum entries set to 0.
          """
          pad = (kernel - 1) // 2
          local_max = F.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad)
          is_peak = (local_max == heat).float()
          return heat * is_peak

      def _gather_feat(self, feat, ind, mask=None):
          """Select rows of ``feat`` along dim 1 using ``ind``.

          Args:
              feat: Tensor of shape ``(batch, N, dim)``.
              ind: Integer tensor of shape ``(batch, M)`` indexing dim 1.
              mask: Optional boolean tensor of shape ``(batch, M)``. When given,
                  only the selected rows are kept and the result is flattened
                  to ``(num_selected, dim)``.

          Returns:
              ``(batch, M, dim)`` tensor, or ``(num_selected, dim)`` with a mask.
          """
          dim = feat.size(2)
          ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
          feat = feat.gather(1, ind)
          if mask is not None:
              mask = mask.unsqueeze(2).expand_as(feat)
              feat = feat[mask]
              feat = feat.view(-1, dim)
          return feat

      def _tranpose_and_gather_feat(self, feat, ind):
          """Gather per-location channel vectors from a ``(B, C, H, W)`` map.

          Args:
              feat: 4-D tensor; channels are moved last and the two spatial
                  dims are flattened before gathering.
              ind: ``(batch, M)`` flat spatial indices.

          Returns:
              Tensor of shape ``(batch, M, C)``.
          """
          feat = feat.permute(0, 2, 3, 1).contiguous()
          feat = feat.view(feat.size(0), -1, feat.size(3))
          return self._gather_feat(feat, ind)

      def _decode_dense(self, pr_decs):
          """Build the unfiltered ``(batch, K, 12)`` detection tensor."""
          heat = pr_decs['hm']
          wh = pr_decs['wh']
          reg = pr_decs['reg']
          cls_theta = pr_decs['cls_theta']

          batch = heat.size(0)
          heat = self._nms(heat)
          scores, inds, clses, ys, xs = self._topk(heat)

          # Refine the integer peak coordinates with the regressed offsets.
          reg = self._tranpose_and_gather_feat(reg, inds).view(batch, self.K, 2)
          xs = xs.view(batch, self.K, 1) + reg[:, :, 0:1]
          ys = ys.view(batch, self.K, 1) + reg[:, :, 1:2]
          clses = clses.view(batch, self.K, 1).float()
          scores = scores.view(batch, self.K, 1)

          wh = self._tranpose_and_gather_feat(wh, inds).view(batch, self.K, self.WH_CHANNELS)
          cls_theta = self._tranpose_and_gather_feat(cls_theta, inds).view(batch, self.K, 1)

          # 1.0 where the per-corner offsets are used, 0.0 where the corners
          # come from the centre and the half extents. The arithmetic blend
          # (rather than torch.where) is kept so results match bit-for-bit.
          use_corners = (cls_theta > self.THETA_THRESH).float().view(batch, self.K, 1)
          use_box = 1. - use_corners
          half_w = wh[..., 8:9] / 2
          half_h = wh[..., 9:10] / 2

          tt_x = (xs + wh[..., 0:1]) * use_corners + xs * use_box
          tt_y = (ys + wh[..., 1:2]) * use_corners + (ys - half_h) * use_box
          rr_x = (xs + wh[..., 2:3]) * use_corners + (xs + half_w) * use_box
          rr_y = (ys + wh[..., 3:4]) * use_corners + ys * use_box
          bb_x = (xs + wh[..., 4:5]) * use_corners + xs * use_box
          bb_y = (ys + wh[..., 5:6]) * use_corners + (ys + half_h) * use_box
          ll_x = (xs + wh[..., 6:7]) * use_corners + (xs - half_w) * use_box
          ll_y = (ys + wh[..., 7:8]) * use_corners + ys * use_box

          return torch.cat(
              [
                  xs,  # cen_x
                  ys,  # cen_y
                  tt_x,
                  tt_y,
                  rr_x,
                  rr_y,
                  bb_x,
                  bb_y,
                  ll_x,
                  ll_y,
                  scores,
                  clses,
              ],
              dim=2,
          )

      def ctdet_decode(self, pr_decs):
          """Decode one image and return detections above ``conf_thresh``.

          Args:
              pr_decs: Mapping with tensors under ``'hm'``, ``'wh'``, ``'reg'``
                  and ``'cls_theta'``. The batch dimension of ``'hm'`` must be
                  1, because a per-image filter cannot be stacked into one
                  rectangular array.

          Returns:
              NumPy array of shape ``(1, N, 12)`` whose last dim is ordered
              ``cen_x, cen_y, tt_x, tt_y, rr_x, rr_y, bb_x, bb_y, ll_x, ll_y,
              score, class``; ``N`` is the number of peaks with
              ``score > conf_thresh``.

          Raises:
              ValueError: If the batch size of ``pr_decs['hm']`` is not 1.
          """
          batch = pr_decs['hm'].size(0)
          if batch != 1:
              # The score filter below is defined for a single image only;
              # larger batches previously died with an opaque indexing error.
              raise ValueError(
                  'ctdet_decode supports batch size 1 only, got %d' % batch
              )

          detections = self._decode_dense(pr_decs)
          keep = detections[0, :, self.SCORE_COLUMN] > self.conf_thresh
          detections = detections[:, keep, :]
          # detach() is the supported replacement for the legacy ``.data``.
          return detections.detach().cpu().numpy()