城管三模型代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

110 lines
3.7KB

  1. import os
  2. import time
  3. import cv2
  4. import torch
  5. from DMPRUtils.DMPR_process import DMPR_process, plot_points
  6. from DMPRUtils.model.detector import DirectionalPointDetector
  7. from DMPR_YOLO.jointUtil import dmpr_yolo
  8. from conf import config
  9. from models.experimental import attempt_load
  10. from models.yolo_process import yolo_process
  11. from utils.plots import plot_one_box
  12. from utils.torch_utils import select_device
  13. def main():
  14. ##预先设置的参数
  15. device_ = '0' ##选定模型,可选 cpu,'0','1'
  16. ##以下参数目前不可改
  17. Detweights = 'weights/urbanManagement/yolo/best.pt'
  18. seg_nclass = 2
  19. DMPRweights = "weights/urbanManagement/DMPR/dp_detector_499.pth"
  20. conf_thres, iou_thres, classes = 0.25, 0.45, 3
  21. labelnames = "weights/yolov5/class5/labelnames.json"
  22. rainbows = [[0, 0, 255], [0, 255, 0], [255, 0, 0], [255, 0, 255], [255, 255, 0], [255, 129, 0], [255, 0, 127],
  23. [127, 255, 0], [0, 255, 127], [0, 127, 255], [127, 0, 255], [255, 127, 255], [255, 255, 127],
  24. [127, 255, 255], [0, 255, 255], [255, 127, 255], [127, 255, 255], [0, 127, 0], [0, 0, 127],
  25. [0, 255, 255]]
  26. allowedList = [0, 1, 2, 3]
  27. ##加载模型,准备好显示字符
  28. device = select_device(device_)
  29. half = device.type != 'cpu' # half precision only supported on CUDA
  30. # yolov5 model
  31. model = attempt_load(Detweights, map_location=device)
  32. if half:
  33. model.half()
  34. # DMPR model
  35. args = config.get_parser_for_inference().parse_args()
  36. DMPRmodel = DirectionalPointDetector(3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
  37. DMPRmodel.load_state_dict(torch.load(DMPRweights))
  38. # 图像测试
  39. # impth = 'images/input'
  40. impth = 'images/debug'
  41. # outpth = 'images/output'
  42. outpth = 'images/debug_out'
  43. folders = os.listdir(impth)
  44. for file in folders:
  45. imgpath = os.path.join(impth, file)
  46. img0 = cv2.imread(imgpath)
  47. assert img0 is not None, 'Image Not Found ' + imgpath
  48. # t_start = time.time()
  49. # yolo process
  50. det0 = yolo_process(img0, model, device, args, half)
  51. det0 = det0.cpu().detach().numpy()
  52. # t_yolo = time.time()
  53. # print(f't_yolo. ({t_yolo - t_start:.3f}s)')
  54. # plot所有box
  55. # for *xyxy, conf, cls in reversed(det0):
  56. # label = f'{int(cls)} {conf:.2f}'
  57. # plot_one_box(xyxy, img0, label=label, color=rainbows[int(cls)], line_thickness=2)
  58. # DMPR process
  59. det1 = DMPR_process(img0, DMPRmodel, device, args)
  60. det1 = det1.cpu().detach().numpy()
  61. # t_dmpr = time.time()
  62. # print(f't_dmpr. ({t_dmpr - t_yolo:.3f}s)')
  63. # 绘制角点
  64. plot_points(img0, det1)
  65. # save
  66. # cv2.imwrite(file, img0)
  67. # yolo joint DMPR
  68. cls = 0 #需要过滤的box类别
  69. joint_det, dilate_box = dmpr_yolo(det1, det0, img0.shape, cls)
  70. # t_joint = time.time()
  71. # print(f't_joint. ({t_joint - t_dmpr:.3f}s)')
  72. # t_end = time.time()
  73. # print(f'Done. ({t_end - t_start:.3f}s)')
  74. # 绘制膨胀box
  75. for *xyxy, flag in dilate_box:
  76. plot_one_box(xyxy, img0, color=rainbows[int(cls)], line_thickness=2)
  77. #
  78. # # 绘制删除满足 在膨胀框内 && 角度差小于90度 的box
  79. for *xyxy, conf, cls, flag in reversed(joint_det):
  80. if flag == 0:
  81. # label = f'{int(cls)} {conf:.2f}'
  82. label = None
  83. plot_one_box(xyxy, img0, label=label, color=rainbows[int(cls)], line_thickness=2)
  84. # save
  85. save_path = os.path.join(outpth, file)
  86. cv2.imwrite(save_path, img0)
  87. if __name__ == '__main__':
  88. main()