城管三模型代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

139 lines
5.0KB

  1. import os
  2. import time
  3. import cv2
  4. import numpy as np
  5. import torch
  6. from DMPRUtils.DMPR_process import DMPR_process, plot_points
  7. from DMPRUtils.model.detector import DirectionalPointDetector
  8. from DMPRUtils.yolo_net import Model
  9. from DMPR_YOLO.jointUtil import dmpr_yolo
  10. from STDCUtils.STDC_process import STDC_process
  11. from STDCUtils.models.model_stages import BiSeNet
  12. from STDC_YOLO.yolo_stdc_joint import stdc_yolo
  13. from conf import config
  14. from models.experimental import attempt_load
  15. from models.yolo_process import yolo_process
  16. from utils.plots import plot_one_box
  17. from utils.torch_utils import select_device
  # BGR colours used to draw boxes, indexed by class id (wrapped with modulo).
  _RAINBOWS = [[0, 0, 255], [0, 255, 0], [255, 0, 0], [255, 0, 255], [255, 255, 0], [255, 129, 0], [255, 0, 127],
               [127, 255, 0], [0, 255, 127], [0, 127, 255], [127, 0, 255], [255, 127, 255], [255, 255, 127],
               [127, 255, 255], [0, 255, 255], [255, 127, 255], [127, 255, 255], [0, 127, 0], [0, 0, 127],
               [0, 255, 255]]


  def _color(cls):
      """Return the drawing colour for class id *cls*; ids past the palette wrap around."""
      return _RAINBOWS[int(cls) % len(_RAINBOWS)]


  def _load_models(device, half, args, det_weights, dmpr_weights):
      """Load the YOLO detector, the STDC segmenter and the DMPR model onto *device*.

      Args:
          device: torch device returned by select_device.
          half: convert the YOLO model to fp16 (only meaningful on CUDA).
          args: parsed inference arguments (STDC options, respth, cfg).
          det_weights: path to the YOLOv5 weights.
          dmpr_weights: path to the DMPR state dict.

      Returns:
          Tuple (yolo_model, stdc_model, dmpr_model).
      """
      # YOLOv5 model
      yolo_model = attempt_load(det_weights, map_location=device)
      if half:
          yolo_model.half()

      # STDC (BiSeNet) segmentation model, configured from the command line
      stdc_model = BiSeNet(backbone=args.backbone, n_classes=args.n_classes,
                           use_boundary_2=args.use_boundary_2, use_boundary_4=args.use_boundary_4,
                           use_boundary_8=args.use_boundary_8, use_boundary_16=args.use_boundary_16,
                           use_conv_last=args.use_conv_last).to(device)
      # map_location lets checkpoints saved on a GPU be loaded when device is the CPU.
      stdc_model.load_state_dict(torch.load(args.respth, map_location=device))
      stdc_model.eval()

      # DMPR model built from args.cfg on the yolo_net Model class
      dmpr_model = Model(args.cfg, ch=3).to(device)
      dmpr_model.load_state_dict(torch.load(dmpr_weights, map_location=device))
      # NOTE(review): unlike stdc_model, dmpr_model is never switched to eval() here;
      # confirm DMPR_process puts it in the mode it expects.
      return yolo_model, stdc_model, dmpr_model


  def _blend_mask(img, seg):
      """Overlay 0.3 * *seg* on every channel of *img* and return a uint8 image.

      The sum can exceed 255, so it is rounded and clipped back to [0, 255]
      explicitly instead of relying on cv2.imwrite's implicit conversion.
      """
      mask = seg[..., np.newaxis].repeat(3, 2)
      blended = 0.3 * mask + img
      return np.clip(np.rint(blended), 0, 255).astype(np.uint8)


  def _process_image(img0, models, device, args, half):
      """Run YOLO, STDC and DMPR on one image and return the annotated, mask-blended result.

      *img0* is drawn on in place (DMPR points and the kept boxes) before blending.
      Stage timings are printed to stdout.
      """
      yolo_model, stdc_model, dmpr_model = models
      filter_cls = 0  # box class that dmpr_yolo is asked to filter

      t_start = time.time()
      # YOLO detection
      det0 = yolo_process(img0, yolo_model, device, args, half)
      det0 = det0.cpu().detach().numpy()
      t_yolo = time.time()
      print(f't_yolo. ({t_yolo - t_start:.3f}s)')

      # STDC segmentation; pixels labelled 1 are set to 255 before the mask is
      # passed to stdc_yolo and blended into the output.
      t_stdc = time.time()
      det2 = STDC_process(img0, stdc_model, device, args.stdc_new_hw)
      det2[det2 == 1] = 255
      t_stdc_inf = time.time()
      print(f't_stdc_inf. ({t_stdc_inf - t_stdc:.3f}s)')

      # Joint STDC/YOLO step: stdc_yolo returns the updated detections.
      det0 = stdc_yolo(det2, det0)
      t_stdc_yolo = time.time()
      print(f't_stdc_joint. ({t_stdc_yolo - t_stdc_inf:.3f}s)')

      # DMPR inference, timed on its own.
      t_dmpr_start = time.time()
      det1 = DMPR_process(img0, dmpr_model, device, args)
      det1 = det1.cpu().detach().numpy()
      t_dmpr = time.time()
      print(f't_dmpr. ({t_dmpr - t_dmpr_start:.3f}s)')

      # Draw the DMPR points, then combine DMPR and YOLO results.
      plot_points(img0, det1)
      joint_det, dilate_box = dmpr_yolo(det1, det0, img0.shape, filter_cls, args.scale_ratio, args.border)
      t_joint = time.time()
      print(f't_joint. ({t_joint - t_dmpr:.3f}s)')
      t_end = time.time()
      print(f'Done. ({t_end - t_start:.3f}s)')

      # Draw the dilated boxes.
      for *xyxy, _flag in dilate_box:
          plot_one_box(xyxy, img0, color=_color(filter_cls), line_thickness=2)
      # Per the original author: boxes inside a dilated box whose angle difference is
      # below 90 degrees are removed; only entries with flag == 0 are drawn.
      for *xyxy, _conf, box_cls, flag in reversed(joint_det):
          if flag == 0:
              plot_one_box(xyxy, img0, label=None, color=_color(box_cls), line_thickness=2)

      return _blend_mask(img0, det2)


  def main(impth='/home/thsw/WJ/zjc/AI/images/pic2', outpth='images/debug_out', device_='0'):
      """Run the YOLO + STDC + DMPR pipeline on every file in *impth* and save the results.

      Args:
          impth: input directory; every entry is read with cv2.imread, in sorted order.
          outpth: output directory (created if missing); outputs keep the input file name.
          device_: device string for select_device, e.g. 'cpu', '0', '1'.

      Raises:
          FileNotFoundError: if an entry of *impth* cannot be read as an image.
      """
      # Weight files are fixed for the current model release.
      det_weights = 'weights/urbanManagement/yolo/best1201.pt'
      dmpr_weights = "weights/urbanManagement/DMPR/dp_detector_372_1204.pth"

      device = select_device(device_)
      half = device.type != 'cpu'  # half precision only supported on CUDA
      args = config.get_parser_for_inference().parse_args()
      models = _load_models(device, half, args, det_weights, dmpr_weights)

      # cv2.imwrite does not create missing directories; it reports failure via its return value.
      os.makedirs(outpth, exist_ok=True)
      with torch.no_grad():  # inference only: no autograd graph needed
          for file in sorted(os.listdir(impth)):
              imgpath = os.path.join(impth, file)
              img0 = cv2.imread(imgpath)
              if img0 is None:
                  raise FileNotFoundError('Image Not Found ' + imgpath)
              img_seg = _process_image(img0, models, device, args, half)
              save_path = os.path.join(outpth, file)
              if not cv2.imwrite(save_path, img_seg):
                  print(f'Failed to write {save_path}')
  if __name__ == '__main__':
      # Script entry point: run main() with its built-in input/output paths.
      main()