You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

364 lines
13KB

  1. import torch
  2. import torch.nn.functional as F
  3. from torch import nn
  4. import os,sys
  5. #print( os.path.abspath( os.path.dirname(os.path.dirname(__file__) )) )
  6. sys.path.append(os.path.abspath( os.path.dirname(os.path.dirname(__file__) )) )
  7. from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
  8. accuracy, get_world_size, interpolate,
  9. is_dist_avail_and_initialized)
  10. from .backbone import build_backbone
  11. from .matcher import build_matcher_crowd
  12. import numpy as np
  13. import time
  14. # the network frmawork of the regression branch
  15. class RegressionModel(nn.Module):
  16. def __init__(self, num_features_in, num_anchor_points=4, feature_size=256):
  17. super(RegressionModel, self).__init__()
  18. self.conv1 = nn.Conv2d(num_features_in, feature_size, kernel_size=3, padding=1)
  19. self.act1 = nn.ReLU()
  20. self.conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
  21. self.act2 = nn.ReLU()
  22. self.conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
  23. self.act3 = nn.ReLU()
  24. self.conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
  25. self.act4 = nn.ReLU()
  26. self.output = nn.Conv2d(feature_size, num_anchor_points * 2, kernel_size=3, padding=1)
  27. # sub-branch forward
  28. def forward(self, x):
  29. out = self.conv1(x)
  30. out = self.act1(out)
  31. out = self.conv2(out)
  32. out = self.act2(out)
  33. out = self.output(out)
  34. out = out.permute(0, 2, 3, 1)
  35. return out.contiguous().view(out.shape[0], -1, 2)
  36. # the network frmawork of the classification branch
  37. class ClassificationModel(nn.Module):
  38. def __init__(self, num_features_in, num_anchor_points=4, num_classes=80, prior=0.01, feature_size=256):
  39. super(ClassificationModel, self).__init__()
  40. self.num_classes = num_classes
  41. self.num_anchor_points = num_anchor_points
  42. self.conv1 = nn.Conv2d(num_features_in, feature_size, kernel_size=3, padding=1)
  43. self.act1 = nn.ReLU()
  44. self.conv2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
  45. self.act2 = nn.ReLU()
  46. self.conv3 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
  47. self.act3 = nn.ReLU()
  48. self.conv4 = nn.Conv2d(feature_size, feature_size, kernel_size=3, padding=1)
  49. self.act4 = nn.ReLU()
  50. self.output = nn.Conv2d(feature_size, num_anchor_points * num_classes, kernel_size=3, padding=1)
  51. self.output_act = nn.Sigmoid()
  52. # sub-branch forward
  53. def forward(self, x):
  54. out = self.conv1(x)
  55. out = self.act1(out)
  56. out = self.conv2(out)
  57. out = self.act2(out)
  58. out = self.output(out)
  59. out1 = out.permute(0, 2, 3, 1)
  60. batch_size, width, height, _ = out1.shape
  61. out2 = out1.view(batch_size, width, height, self.num_anchor_points, self.num_classes)
  62. return out2.contiguous().view(x.shape[0], -1, self.num_classes)
  63. # generate the reference points in grid layout
  64. def generate_anchor_points(stride=16, row=3, line=3):
  65. row_step = stride / row
  66. line_step = stride / line
  67. shift_x = (np.arange(1, line + 1) - 0.5) * line_step - stride / 2
  68. shift_y = (np.arange(1, row + 1) - 0.5) * row_step - stride / 2
  69. shift_x, shift_y = np.meshgrid(shift_x, shift_y)
  70. anchor_points = np.vstack((
  71. shift_x.ravel(), shift_y.ravel()
  72. )).transpose()
  73. return anchor_points
  74. # shift the meta-anchor to get an acnhor points
  75. def shift(shape, stride, anchor_points):
  76. shift_x = (np.arange(0, shape[1]) + 0.5) * stride
  77. shift_y = (np.arange(0, shape[0]) + 0.5) * stride
  78. shift_x, shift_y = np.meshgrid(shift_x, shift_y)
  79. shifts = np.vstack((
  80. shift_x.ravel(), shift_y.ravel()
  81. )).transpose()
  82. A = anchor_points.shape[0]
  83. K = shifts.shape[0]
  84. all_anchor_points = (anchor_points.reshape((1, A, 2)) + shifts.reshape((1, K, 2)).transpose((1, 0, 2)))
  85. all_anchor_points = all_anchor_points.reshape((K * A, 2))
  86. return all_anchor_points
  87. # this class generate all reference points on all pyramid levels
  88. class AnchorPoints(nn.Module):
  89. def __init__(self, pyramid_levels=None, strides=None, row=3, line=3):
  90. super(AnchorPoints, self).__init__()
  91. if pyramid_levels is None:
  92. self.pyramid_levels = [3, 4, 5, 6, 7]
  93. else:
  94. self.pyramid_levels = pyramid_levels
  95. if strides is None:
  96. self.strides = [2 ** x for x in self.pyramid_levels]
  97. self.row = row
  98. self.line = line
  99. def forward(self, image):
  100. image_shape = image.shape[2:]
  101. image_shape = np.array(image_shape)
  102. image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in self.pyramid_levels]
  103. all_anchor_points = np.zeros((0, 2)).astype(np.float32)
  104. # get reference points for each level
  105. for idx, p in enumerate(self.pyramid_levels):
  106. anchor_points = generate_anchor_points(2**p, row=self.row, line=self.line)
  107. shifted_anchor_points = shift(image_shapes[idx], self.strides[idx], anchor_points)
  108. all_anchor_points = np.append(all_anchor_points, shifted_anchor_points, axis=0)
  109. all_anchor_points = np.expand_dims(all_anchor_points, axis=0)
  110. # send reference points to device
  111. if torch.cuda.is_available():
  112. return torch.from_numpy(all_anchor_points.astype(np.float32)).cuda()
  113. else:
  114. return torch.from_numpy(all_anchor_points.astype(np.float32))
  115. class Decoder(nn.Module):
  116. def __init__(self, C3_size, C4_size, C5_size, feature_size=256):
  117. super(Decoder, self).__init__()
  118. # upsample C5 to get P5 from the FPN paper
  119. self.P5_1 = nn.Conv2d(C5_size, feature_size, kernel_size=1, stride=1, padding=0)
  120. self.P5_upsampled = nn.Upsample(scale_factor=2, mode='nearest')
  121. self.P5_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)
  122. # add P5 elementwise to C4
  123. self.P4_1 = nn.Conv2d(C4_size, feature_size, kernel_size=1, stride=1, padding=0)
  124. self.P4_upsampled = nn.Upsample(scale_factor=2, mode='nearest')
  125. self.P4_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)
  126. # add P4 elementwise to C3
  127. self.P3_1 = nn.Conv2d(C3_size, feature_size, kernel_size=1, stride=1, padding=0)
  128. self.P3_upsampled = nn.Upsample(scale_factor=2, mode='nearest')
  129. self.P3_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)
  130. def forward(self, inputs):
  131. C3, C4, C5 = inputs
  132. P5_x = self.P5_1(C5)
  133. P5_upsampled_x = self.P5_upsampled(P5_x)
  134. P5_x = self.P5_2(P5_x)
  135. P4_x = self.P4_1(C4)
  136. P4_x = P5_upsampled_x + P4_x
  137. P4_upsampled_x = self.P4_upsampled(P4_x)
  138. P4_x = self.P4_2(P4_x)
  139. P3_x = self.P3_1(C3)
  140. P3_x = P3_x + P4_upsampled_x
  141. P3_x = self.P3_2(P3_x)
  142. return [P3_x, P4_x, P5_x]
  143. # the defenition of the P2PNet model
  144. class P2PNet(nn.Module):
  145. def __init__(self, backbone, row=2, line=2,anchorFlag=True):
  146. super().__init__()
  147. self.backbone = backbone
  148. self.num_classes = 2
  149. self.anchorFlag = anchorFlag
  150. # the number of all anchor points
  151. num_anchor_points = row * line
  152. self.regression = RegressionModel(num_features_in=256, num_anchor_points=num_anchor_points)
  153. self.classification = ClassificationModel(num_features_in=256, \
  154. num_classes=self.num_classes, \
  155. num_anchor_points=num_anchor_points)
  156. if self.anchorFlag:
  157. self.anchor_points = AnchorPoints(pyramid_levels=[3,], row=row, line=line)
  158. self.fpn = Decoder(256, 512, 512)
  159. def forward(self, samples: NestedTensor):
  160. # get the backbone features
  161. features = self.backbone(samples)
  162. # forward the feature pyramid
  163. features_fpn = self.fpn([features[1], features[2], features[3]])
  164. batch_size = features[0].shape[0]
  165. # print("line227", batch_size)
  166. # run the regression and classification branch
  167. regression = self.regression(features_fpn[1]) * 100 # 8x
  168. classification = self.classification(features_fpn[1])
  169. if self.anchorFlag:
  170. anchor_points = self.anchor_points(samples).repeat(batch_size, 1, 1)
  171. #decode the points as prediction
  172. output_coord = regression + anchor_points
  173. else:
  174. output_coord = regression
  175. output_class = classification
  176. out = {'pred_logits': output_class, 'pred_points': output_coord}
  177. return out
  178. class SetCriterion_Crowd(nn.Module):
  179. def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
  180. """ Create the criterion.
  181. Parameters:
  182. num_classes: number of object categories, omitting the special no-object category
  183. matcher: module able to compute a matching between targets and proposals
  184. weight_dict: dict containing as key the names of the losses and as values their relative weight.
  185. eos_coef: relative classification weight applied to the no-object category
  186. losses: list of all the losses to be applied. See get_loss for list of available losses.
  187. """
  188. super().__init__()
  189. self.num_classes = num_classes
  190. self.matcher = matcher
  191. self.weight_dict = weight_dict
  192. self.eos_coef = eos_coef
  193. self.losses = losses
  194. empty_weight = torch.ones(self.num_classes + 1)
  195. empty_weight[0] = self.eos_coef
  196. self.register_buffer('empty_weight', empty_weight)
  197. def loss_labels(self, outputs, targets, indices, num_points):
  198. """Classification loss (NLL)
  199. targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
  200. """
  201. assert 'pred_logits' in outputs
  202. src_logits = outputs['pred_logits']
  203. idx = self._get_src_permutation_idx(indices)
  204. target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
  205. target_classes = torch.full(src_logits.shape[:2], 0,
  206. dtype=torch.int64, device=src_logits.device)
  207. target_classes[idx] = target_classes_o
  208. loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
  209. losses = {'loss_ce': loss_ce}
  210. return losses
  211. def loss_points(self, outputs, targets, indices, num_points):
  212. assert 'pred_points' in outputs
  213. idx = self._get_src_permutation_idx(indices)
  214. src_points = outputs['pred_points'][idx]
  215. target_points = torch.cat([t['point'][i] for t, (_, i) in zip(targets, indices)], dim=0)
  216. loss_bbox = F.mse_loss(src_points, target_points, reduction='none')
  217. losses = {}
  218. losses['loss_point'] = loss_bbox.sum() / num_points
  219. return losses
  220. def _get_src_permutation_idx(self, indices):
  221. # permute predictions following indices
  222. batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
  223. src_idx = torch.cat([src for (src, _) in indices])
  224. return batch_idx, src_idx
  225. def _get_tgt_permutation_idx(self, indices):
  226. # permute targets following indices
  227. batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
  228. tgt_idx = torch.cat([tgt for (_, tgt) in indices])
  229. return batch_idx, tgt_idx
  230. def get_loss(self, loss, outputs, targets, indices, num_points, **kwargs):
  231. loss_map = {
  232. 'labels': self.loss_labels,
  233. 'points': self.loss_points,
  234. }
  235. assert loss in loss_map, f'do you really want to compute {loss} loss?'
  236. return loss_map[loss](outputs, targets, indices, num_points, **kwargs)
  237. def forward(self, outputs, targets):
  238. """ This performs the loss computation.
  239. Parameters:
  240. outputs: dict of tensors, see the output specification of the model for the format
  241. targets: list of dicts, such that len(targets) == batch_size.
  242. The expected keys in each dict depends on the losses applied, see each loss' doc
  243. """
  244. output1 = {'pred_logits': outputs['pred_logits'], 'pred_points': outputs['pred_points']}
  245. indices1 = self.matcher(output1, targets)
  246. num_points = sum(len(t["labels"]) for t in targets)
  247. num_points = torch.as_tensor([num_points], dtype=torch.float, device=next(iter(output1.values())).device)
  248. if is_dist_avail_and_initialized():
  249. torch.distributed.all_reduce(num_points)
  250. num_boxes = torch.clamp(num_points / get_world_size(), min=1).item()
  251. losses = {}
  252. for loss in self.losses:
  253. losses.update(self.get_loss(loss, output1, targets, indices1, num_boxes))
  254. return losses
  255. # create the P2PNet model
  256. def build(args, training):
  257. # treats persons as a single class
  258. num_classes = 1
  259. backbone = build_backbone(args)
  260. model = P2PNet(backbone, args.row, args.line,anchorFlag=args.anchorFlag)
  261. if not training:
  262. return model
  263. weight_dict = {'loss_ce': 1, 'loss_points': args.point_loss_coef}
  264. losses = ['labels', 'points']
  265. matcher = build_matcher_crowd(args)
  266. criterion = SetCriterion_Crowd(num_classes, \
  267. matcher=matcher, weight_dict=weight_dict, \
  268. eos_coef=args.eos_coef, losses=losses)
  269. return model, criterion