You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

231 lines
9.2KB

  1. import math
  2. import os
  3. import time
  4. from collections import namedtuple
  5. import cv2
  6. import numpy as np
  7. import torch
  8. from torchvision.transforms import ToTensor
  9. from DMPRUtils.model import DirectionalPointDetector
  10. from utils.datasets import letterbox
  11. from utils.general import clip_coords
  12. from utils.torch_utils import select_device
  13. #from DMPRUtils.trtUtils import TrtForwardCase
  14. #import segutils.trtUtils.segTrtForward as TrtForwardCase
  15. from segutils.trtUtils import segTrtForward
  16. MarkingPoint = namedtuple('MarkingPoint', ['x', 'y', 'direction', 'shape'])
  17. def plot_points(image, pred_points, line_thickness=3):
  18. """Plot marking points on the image."""
  19. if len(pred_points):
  20. tl = line_thickness or round(0.002 * (image.shape[0] + image.shape[1]) / 2) + 1 # line/font thickness
  21. tf = max(tl - 1, 1) # font thickness
  22. for conf, *point in pred_points:
  23. p0_x, p0_y = int(point[0]), int(point[1])
  24. cos_val = math.cos(point[2])
  25. sin_val = math.sin(point[2])
  26. p1_x = int(p0_x + 20 * cos_val * tl)
  27. p1_y = int(p0_y + 20 * sin_val * tl)
  28. p2_x = int(p0_x - 10 * sin_val * tl)
  29. p2_y = int(p0_y + 10 * cos_val * tl)
  30. p3_x = int(p0_x + 10 * sin_val * tl)
  31. p3_y = int(p0_y - 10 * cos_val * tl)
  32. cv2.line(image, (p0_x, p0_y), (p1_x, p1_y), (0, 0, 255), thickness=tl)
  33. cv2.putText(image, str(float(conf)), (p0_x, p0_y), cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 0), thickness=tf)
  34. if point[3] > 0.5:
  35. cv2.line(image, (p0_x, p0_y), (p2_x, p2_y), (0, 0, 255), thickness=tl)
  36. else:
  37. cv2.line(image, (p2_x, p2_y), (p3_x, p3_y), (0, 0, 255), thickness=tf)
  38. def preprocess_image(image):
  39. """Preprocess numpy image to torch tensor."""
  40. if image.shape[0] != 640 or image.shape[1] != 640:
  41. image = cv2.resize(image, (640, 640))
  42. return torch.unsqueeze(ToTensor()(image), 0)
  43. def non_maximum_suppression(pred_points):
  44. """Perform non-maxmum suppression on marking points."""
  45. t1 = time.time()
  46. suppressed = [False] * len(pred_points)
  47. for i in range(len(pred_points) - 1):
  48. for j in range(i + 1, len(pred_points)):
  49. i_x = pred_points[i][1].x
  50. i_y = pred_points[i][1].y
  51. j_x = pred_points[j][1].x
  52. j_y = pred_points[j][1].y
  53. # 0.0625 = 1 / 16
  54. if abs(j_x - i_x) < 0.0625 and abs(j_y - i_y) < 0.0625:
  55. idx = i if pred_points[i][0] < pred_points[j][0] else j
  56. suppressed[idx] = True
  57. if any(suppressed):
  58. unsupres_pred_points = []
  59. for i, supres in enumerate(suppressed):
  60. if not supres:
  61. unsupres_pred_points.append(pred_points[i])
  62. return unsupres_pred_points
  63. t2 = time.time()
  64. print(f'nms: {t2 - t1:.3f}s')
  65. return pred_points
  66. def ms(t2,t1):
  67. return ('%.1f '%( (t2-t1)*1000 ) )
  68. def get_predicted_points(prediction, thresh):
  69. """Get marking points from one predicted feature map."""
  70. t1 = time.time()
  71. assert isinstance(prediction, torch.Tensor)
  72. prediction = prediction.permute(1, 2, 0).contiguous() # prediction (20, 20, 6)
  73. height = prediction.shape[0]
  74. width = prediction.shape[1]
  75. j = torch.arange(prediction.shape[1], device=prediction.device).float().repeat(prediction.shape[0], 1).unsqueeze(dim=2)
  76. i = torch.arange(prediction.shape[0], device=prediction.device).float().view(prediction.shape[0], 1).repeat(1,prediction.shape[1]).unsqueeze(dim=2)
  77. prediction = torch.cat((prediction, j, i), dim=2).view(-1, 8).contiguous()
  78. t2 = time.time()
  79. # 过滤小于thresh的置信度
  80. mask = prediction[..., 0] > thresh
  81. t3 = time.time()
  82. prediction = prediction[mask]
  83. t4 = time.time()
  84. prediction[..., 2] = (prediction[..., 2] + prediction[..., 6]) / width
  85. prediction[..., 3] = (prediction[..., 3] + prediction[..., 7]) / height
  86. direction = torch.atan2(prediction[..., 5], prediction[..., 4])
  87. prediction = torch.stack((prediction[..., 0], prediction[..., 2], prediction[..., 3], direction, prediction[..., 1]), dim=1)
  88. t5 = time.time()
  89. timeInfo = 'rerange:%s scoreFilter:%s , getMask:%s stack:%s '%( ms(t2,t1),ms(t3,t2),ms(t4,t3),ms(t5,t4) )
  90. #print('-'*20,timeInfo)
  91. return prediction,timeInfo
  92. def get_predicted_points_np(prediction, thresh):
  93. """Get marking points from one predicted feature map."""
  94. t1 = time.time()
  95. prediction = prediction.permute(1, 2, 0).contiguous() # prediction (20, 20, 6)
  96. t1_1 = time.time()
  97. prediction = prediction.cpu().detach().numpy()
  98. t1_2 = time.time()
  99. height,width = prediction.shape[0:2]
  100. i,j = np.mgrid[0:height, 0:width]
  101. i = np.expand_dims(i,axis=2);j = np.expand_dims(j,axis=2)
  102. #print('##line112:',i.shape,j.shape,prediction.shape)
  103. prediction = np.concatenate( (prediction,i,j),axis=2 )
  104. prediction = prediction.reshape(-1,8)
  105. t2 = time.time()
  106. mask = prediction[..., 0] > thresh
  107. t3 = time.time()
  108. prediction = prediction[mask]
  109. t4 = time.time()
  110. prediction[..., 2] = (prediction[..., 2] + prediction[..., 6]) / width
  111. prediction[..., 3] = (prediction[..., 3] + prediction[..., 7]) / height
  112. direction = np.arctan(prediction[..., 5:6], prediction[..., 4:5])
  113. #print('-'*20,prediction.shape,direction.shape)
  114. prediction = np.hstack((prediction[:, 0:1], prediction[:, 2:3], prediction[:, 3:4], direction, prediction[:, 1:2]))
  115. #print('-line126:','-'*20,type(prediction),prediction.shape)
  116. t5 = time.time()
  117. timeInfo = 'permute:%s Tocpu:%s rerange:%s scoreFilter:%s , getMask:%s stack:%s '%( ms(t1_1,t1) , ms(t1_2,t1_1),ms(t2,t1_2),ms(t3,t2),ms(t4,t3),ms(t5,t4) )
  118. print('-'*20,timeInfo,prediction.shape)
  119. return prediction
  120. def detect_marking_points(detector, image, thresh, device,modelType='pth'):
  121. """Given image read from opencv, return detected marking points."""
  122. t1 = time.time()
  123. image_preprocess = preprocess_image(image).to(device)
  124. if modelType=='pth':
  125. prediction = detector(image_preprocess)
  126. #print(prediction)
  127. elif modelType=='trt':
  128. a=0
  129. prediction = segTrtForward(detector,[image_preprocess ])
  130. #print(prediction)
  131. torch.cuda.synchronize(device)
  132. t2 = time.time()
  133. rets,timeInfo = get_predicted_points(prediction[0], thresh)
  134. string_t2 = ' infer:%s postprocess:%s'%(ms(t2,t1),timeInfo)
  135. return rets
  136. def scale_coords2(img1_shape, coords, img0_shape, ratio_pad=None):
  137. # Rescale coords (xy) from img1_shape to img0_shape
  138. if ratio_pad is None: # calculate from img0_shape
  139. gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
  140. pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
  141. else:
  142. gain = ratio_pad[0][0]
  143. pad = ratio_pad[1]
  144. # 百分比x,y转换为实际x,y
  145. height, width = img1_shape
  146. if isinstance(coords, torch.Tensor):
  147. coords[:, 0] = torch.round(width * coords[:, 0] - 0.5)
  148. coords[:, 1] = torch.round(height * coords[:, 1] - 0.5)
  149. else:
  150. coords[:, 0] = (width * coords[:, 0] + 0.5).astype(np.int32)
  151. coords[:, 1] = (height * coords[:, 1] + 0.5).astype(np.int32)
  152. coords[:, 0] -= pad[0] # x padding
  153. coords[:, 1] -= pad[1] # y padding
  154. coords[:, :3] /= gain
  155. #恢复成原始图片尺寸
  156. if isinstance(coords, torch.Tensor):
  157. coords[:, 0].clamp_(0, img0_shape[1])
  158. coords[:, 1].clamp_(0, img0_shape[0])
  159. else:
  160. coords[:, 0] = np.clip( coords[:, 0], 0,img0_shape[1] )
  161. coords[:, 1] = np.clip( coords[:, 1], 0,img0_shape[0] )
  162. return coords
  163. def DMPR_process(img0, model, device, DMPRmodelPar):
  164. t0 = time.time()
  165. height, width, _ = img0.shape
  166. img, ratio, (dw, dh) = letterbox(img0, DMPRmodelPar['dmprimg_size'], auto=False)
  167. t1 = time.time()
  168. #print('###line188:', height, width, img.shape)
  169. det = detect_marking_points(model, img, DMPRmodelPar['dmpr_thresh'], device,modelType=DMPRmodelPar['modelType'])
  170. t2 = time.time()
  171. if len(det):
  172. det[:, 1:3] = scale_coords2(img.shape[:2], det[:, 1:3], img0.shape)
  173. t3 = time.time()
  174. timeInfos = 'dmpr:%1.f (lettbox:%.1f dectect:%.1f scaleBack:%.1f) '%( (t3-t0)*1000,(t1-t0)*1000,(t2-t1)*1000,(t3-t2)*1000, )
  175. return det,timeInfos
  176. if __name__ == '__main__':
  177. impath = r'I:\zjc\weiting1\Images'
  178. file = 'DJI_0001_8.jpg'
  179. imgpath = os.path.join(impath, file)
  180. img0 = cv2.imread(imgpath)
  181. device_ = '0'
  182. device = select_device(device_)
  183. args = config.get_parser_for_inference().parse_args()
  184. model = DirectionalPointDetector(3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
  185. weights = r"E:\pycharmProject\DMPR-PS\weights\dp_detector_499.pth"
  186. model.load_state_dict(torch.load(weights))
  187. det = DMPR_process(img0, model, device, args)
  188. plot_points(img0, det)
  189. cv2.imwrite(file, img0, [int(cv2.IMWRITE_JPEG_QUALITY), 100])