城管三模型代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

202 lines
8.2KB

  1. import math
  2. import os
  3. import time
  4. from collections import namedtuple
  5. import cv2
  6. import numpy as np
  7. import torch
  8. from torchvision.transforms import ToTensor
  9. from DMPRUtils.model import DirectionalPointDetector
  10. from conf import config
  11. from utils.datasets import letterbox
  12. from utils.general import clip_coords
  13. from utils.torch_utils import select_device
  14. MarkingPoint = namedtuple('MarkingPoint', ['x', 'y', 'direction', 'shape'])
  15. def plot_points(image, pred_points, line_thickness=3):
  16. """Plot marking points on the image."""
  17. if pred_points.size:
  18. tl = line_thickness or round(0.002 * (image.shape[0] + image.shape[1]) / 2) + 1 # line/font thickness
  19. tf = max(tl - 1, 1) # font thickness
  20. for conf, *point in pred_points:
  21. p0_x, p0_y = int(point[0]), int(point[1])
  22. cos_val = math.cos(point[2])
  23. sin_val = math.sin(point[2])
  24. p1_x = int(p0_x + 20 * cos_val * tl)
  25. p1_y = int(p0_y + 20 * sin_val * tl)
  26. p2_x = int(p0_x - 10 * sin_val * tl)
  27. p2_y = int(p0_y + 10 * cos_val * tl)
  28. p3_x = int(p0_x + 10 * sin_val * tl)
  29. p3_y = int(p0_y - 10 * cos_val * tl)
  30. cv2.line(image, (p0_x, p0_y), (p1_x, p1_y), (0, 0, 255), thickness=tl)
  31. cv2.putText(image, str(float(conf)), (p0_x, p0_y), cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 0), thickness=tf)
  32. if point[3] > 0.5:
  33. cv2.line(image, (p0_x, p0_y), (p2_x, p2_y), (0, 0, 255), thickness=tl)
  34. else:
  35. cv2.line(image, (p2_x, p2_y), (p3_x, p3_y), (0, 0, 255), thickness=tf)
  36. def preprocess_image(image):
  37. """Preprocess numpy image to torch tensor."""
  38. if image.shape[0] != 640 or image.shape[1] != 640:
  39. image = cv2.resize(image, (640, 640))
  40. return torch.unsqueeze(ToTensor()(image), 0)
  41. def non_maximum_suppression(pred_points):
  42. """Perform non-maxmum suppression on marking points."""
  43. suppressed = [False] * len(pred_points)
  44. for i in range(len(pred_points) - 1):
  45. for j in range(i + 1, len(pred_points)):
  46. i_x = pred_points[i][1].x
  47. i_y = pred_points[i][1].y
  48. j_x = pred_points[j][1].x
  49. j_y = pred_points[j][1].y
  50. # 0.0625 = 1 / 16
  51. if abs(j_x - i_x) < 0.0625 and abs(j_y - i_y) < 0.0625:
  52. idx = i if pred_points[i][0] < pred_points[j][0] else j
  53. suppressed[idx] = True
  54. if any(suppressed):
  55. unsupres_pred_points = []
  56. for i, supres in enumerate(suppressed):
  57. if not supres:
  58. unsupres_pred_points.append(pred_points[i])
  59. return unsupres_pred_points
  60. return pred_points
  61. def get_predicted_points(prediction, thresh):
  62. """Get marking points from one predicted feature map."""
  63. assert isinstance(prediction, torch.Tensor)
  64. predicted_points = []
  65. prediction = prediction.detach().cpu().numpy()
  66. for i in range(prediction.shape[1]):
  67. for j in range(prediction.shape[2]):
  68. if prediction[0, i, j] >= thresh:
  69. xval = (j + prediction[2, i, j]) / prediction.shape[2]
  70. yval = (i + prediction[3, i, j]) / prediction.shape[1]
  71. # if not (config.BOUNDARY_THRESH <= xval <= 1-config.BOUNDARY_THRESH
  72. # and config.BOUNDARY_THRESH <= yval <= 1-config.BOUNDARY_THRESH):
  73. # continue
  74. cos_value = prediction[4, i, j]
  75. sin_value = prediction[5, i, j]
  76. direction = math.atan2(sin_value, cos_value)
  77. marking_point = MarkingPoint(
  78. xval, yval, direction, prediction[1, i, j])
  79. predicted_points.append((prediction[0, i, j], marking_point))
  80. return non_maximum_suppression(predicted_points)
  81. def get_predicted_points2(prediction, thresh):
  82. """Get marking points from one predicted feature map."""
  83. assert isinstance(prediction, torch.Tensor)
  84. # predicted_points = []
  85. # prediction = prediction.detach().cpu().numpy()
  86. # for i in range(prediction.shape[1]):
  87. # for j in range(prediction.shape[2]):
  88. # if prediction[0, i, j] >= thresh:
  89. # xval = (j + prediction[2, i, j]) / prediction.shape[2]
  90. # yval = (i + prediction[3, i, j]) / prediction.shape[1]
  91. # # if not (config.BOUNDARY_THRESH <= xval <= 1-config.BOUNDARY_THRESH
  92. # # and config.BOUNDARY_THRESH <= yval <= 1-config.BOUNDARY_THRESH):
  93. # # continue
  94. # cos_value = prediction[4, i, j]
  95. # sin_value = prediction[5, i, j]
  96. # direction = math.atan2(sin_value, cos_value)
  97. # marking_point = MarkingPoint(
  98. # xval, yval, direction, prediction[1, i, j])
  99. # predicted_points.append((prediction[0, i, j], marking_point))
  100. prediction = prediction.permute(1, 2, 0).contiguous() # prediction (20, 20, 6)
  101. height = prediction.shape[0]
  102. width = prediction.shape[1]
  103. j = torch.arange(prediction.shape[1], device=prediction.device).float().repeat(prediction.shape[0], 1).unsqueeze(dim=2)
  104. i = torch.arange(prediction.shape[0], device=prediction.device).float().view(prediction.shape[0], 1).repeat(1, prediction.shape[1]).unsqueeze(dim=2)
  105. prediction = torch.cat((prediction, j, i), dim=2)
  106. # 过滤小于thresh的置信度
  107. prediction = prediction[prediction[..., 0] > thresh]
  108. prediction[..., 2] = (prediction[..., 2] + prediction[..., 6]) / width
  109. prediction[..., 3] = (prediction[..., 3] + prediction[..., 7]) / height
  110. direction = torch.atan2(prediction[..., 5], prediction[..., 4])
  111. prediction = torch.stack((prediction[..., 0], prediction[..., 2], prediction[..., 3], direction, prediction[..., 1]), dim=1)
  112. # return non_maximum_suppression(predicted_points)
  113. return prediction
  114. def detect_marking_points(detector, image, thresh, device):
  115. """Given image read from opencv, return detected marking points."""
  116. # t1 = time.time()
  117. # torch.cuda.synchronize(device)
  118. prediction = detector(preprocess_image(image).to(device))
  119. # torch.cuda.synchronize(device)
  120. # t2 = time.time()
  121. # print(f'detector: {t2 - t1:.3f}s')
  122. return get_predicted_points2(prediction[0], thresh)
  123. def scale_coords2(img1_shape, coords, img0_shape, ratio_pad=None):
  124. # Rescale coords (xy) from img1_shape to img0_shape
  125. if ratio_pad is None: # calculate from img0_shape
  126. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  127. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  128. else:
  129. gain = ratio_pad[0][0]
  130. pad = ratio_pad[1]
  131. # 百分比x,y转换为实际x,y
  132. height, width = img1_shape
  133. coords[:, 0] = torch.round(width * coords[:, 0] - 0.5)
  134. coords[:, 1] = torch.round(height * coords[:, 1] - 0.5)
  135. coords[:, 0] -= pad[0] # x padding
  136. coords[:, 1] -= pad[1] # y padding
  137. coords[:, :3] /= gain
  138. #恢复成原始图片尺寸
  139. coords[:, 0].clamp_(0, img0_shape[1])
  140. coords[:, 1].clamp_(0, img0_shape[0])
  141. return coords
  142. def DMPR_process(img0, model, device, args):
  143. height, width, _ = img0.shape
  144. img, ratio, (dw, dh) = letterbox(img0, args.dmprimg_size, auto=False)
  145. det = detect_marking_points(model, img, args.dmpr_thresh, device)
  146. # if not pred:
  147. # return torch.tensor([])
  148. # # 由list转为tensor
  149. # det = torch.tensor([[conf, *tup] for conf, tup in pred])
  150. if len(det):
  151. det[:, 1:3] = scale_coords2(img.shape[:2], det[:, 1:3], img0.shape)
  152. # conf, x, y, θ, shape
  153. return det
  154. if __name__ == '__main__':
  155. impath = r'I:\zjc\weiting1\Images'
  156. file = 'DJI_0001_8.jpg'
  157. imgpath = os.path.join(impath, file)
  158. img0 = cv2.imread(imgpath)
  159. device_ = '0'
  160. device = select_device(device_)
  161. args = config.get_parser_for_inference().parse_args()
  162. model = DirectionalPointDetector(3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
  163. weights = r"E:\pycharmProject\DMPR-PS\weights\dp_detector_499.pth"
  164. model.load_state_dict(torch.load(weights))
  165. det = DMPR_process(img0, model, device, args)
  166. plot_points(img0, det)
  167. cv2.imwrite(file, img0, [int(cv2.IMWRITE_JPEG_QUALITY), 100])