城管三模型代码
Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

139 lines
5.0KB

  1. import os
  2. import time
  3. import cv2
  4. import numpy as np
  5. import torch
  6. from DMPRUtils.DMPR_process import DMPR_process, plot_points
  7. from DMPRUtils.model.detector import DirectionalPointDetector
  8. from DMPRUtils.yolo_net import Model
  9. from DMPR_YOLO.jointUtil import dmpr_yolo
  10. from STDCUtils.STDC_process import STDC_process
  11. from STDCUtils.models.model_stages import BiSeNet
  12. from STDC_YOLO.yolo_stdc_joint import stdc_yolo
  13. from conf import config
  14. from models.experimental import attempt_load
  15. from models.yolo_process import yolo_process
  16. from utils.plots import plot_one_box
  17. from utils.torch_utils import select_device
  18. def main():
  19. ##预先设置的参数
  20. device_ = '0' ##选定模型,可选 cpu,'0','1'
  21. ##以下参数目前不可改
  22. Detweights = 'weights/urbanManagement/yolo/best1201.pt'
  23. seg_nclass = 2
  24. DMPRweights = "weights/urbanManagement/DMPR/dp_detector_372_1204.pth"
  25. conf_thres, iou_thres, classes = 0.25, 0.45, 3
  26. labelnames = "weights/yolov5/class5/labelnames.json"
  27. rainbows = [[0, 0, 255], [0, 255, 0], [255, 0, 0], [255, 0, 255], [255, 255, 0], [255, 129, 0], [255, 0, 127],
  28. [127, 255, 0], [0, 255, 127], [0, 127, 255], [127, 0, 255], [255, 127, 255], [255, 255, 127],
  29. [127, 255, 255], [0, 255, 255], [255, 127, 255], [127, 255, 255], [0, 127, 0], [0, 0, 127],
  30. [0, 255, 255]]
  31. allowedList = [0, 1, 2, 3]
  32. ##加载模型,准备好显示字符
  33. device = select_device(device_)
  34. half = device.type != 'cpu' # half precision only supported on CUDA
  35. # yolov5 model
  36. model = attempt_load(Detweights, map_location=device)
  37. if half:
  38. model.half()
  39. # load args
  40. args = config.get_parser_for_inference().parse_args()
  41. # STDC model
  42. STDC_model = BiSeNet(backbone=args.backbone, n_classes=args.n_classes,
  43. use_boundary_2=args.use_boundary_2, use_boundary_4=args.use_boundary_4,
  44. use_boundary_8=args.use_boundary_8, use_boundary_16=args.use_boundary_16,
  45. use_conv_last=args.use_conv_last).to(device)
  46. STDC_model.load_state_dict(torch.load(args.respth))
  47. STDC_model.eval()
  48. # DMPR model
  49. # DMPRmodel = DirectionalPointDetector(3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device)
  50. # DMPRmodel.load_state_dict(torch.load(DMPRweights))
  51. DMPRmodel = Model(args.cfg, ch=3).to(device)
  52. DMPRmodel.load_state_dict(torch.load(DMPRweights))
  53. # 图像测试
  54. # impth = 'images/input'
  55. # impth = 'images/debug'
  56. # impth = '/home/thsw/ssd/zjc/cityManagement_test'
  57. impth = '/home/thsw/WJ/zjc/AI/images/pic2'
  58. # impth = '/home/thsw/WJ/zjc/AI_old/images/input_0'
  59. # outpth = 'images/output'
  60. outpth = 'images/debug_out'
  61. folders = os.listdir(impth)
  62. for file in folders:
  63. imgpath = os.path.join(impth, file)
  64. img0 = cv2.imread(imgpath)
  65. assert img0 is not None, 'Image Not Found ' + imgpath
  66. t_start = time.time()
  67. # yolo process
  68. det0 = yolo_process(img0, model, device, args, half)
  69. det0 = det0.cpu().detach().numpy()
  70. t_yolo = time.time()
  71. print(f't_yolo. ({t_yolo - t_start:.3f}s)')
  72. t_stdc = time.time()
  73. # STDC process
  74. det2 = STDC_process(img0, STDC_model, device, args.stdc_new_hw)
  75. det2[det2 == 1] = 255
  76. t_stdc_inf = time.time()
  77. print(f't_stdc_inf. ({t_stdc_inf - t_stdc:.3f}s)')
  78. # STDC joint yolo
  79. det0 = stdc_yolo(det2, det0)
  80. t_stdc_yolo = time.time()
  81. print(f't_stdc_joint. ({t_stdc_yolo - t_stdc_inf:.3f}s)')
  82. # plot所有box
  83. # for *xyxy, conf, cls in reversed(det0):
  84. # label = f'{int(cls)} {conf:.2f}'
  85. # plot_one_box(xyxy, img0, label=label, color=rainbows[int(cls)], line_thickness=2)
  86. # DMPR process
  87. det1 = DMPR_process(img0, DMPRmodel, device, args)
  88. det1 = det1.cpu().detach().numpy()
  89. #
  90. t_dmpr = time.time()
  91. print(f't_dmpr. ({t_dmpr - t_yolo:.3f}s)')
  92. # 绘制角点
  93. plot_points(img0, det1)
  94. # yolo joint DMPR
  95. cls = 0 #需要过滤的box类别
  96. joint_det, dilate_box = dmpr_yolo(det1, det0, img0.shape, cls, args.scale_ratio, args.border)
  97. #
  98. t_joint = time.time()
  99. print(f't_joint. ({t_joint - t_dmpr:.3f}s)')
  100. t_end = time.time()
  101. print(f'Done. ({t_end - t_start:.3f}s)')
  102. # 绘制膨胀box
  103. for *xyxy, flag in dilate_box:
  104. plot_one_box(xyxy, img0, color=rainbows[int(cls)], line_thickness=2)
  105. # #
  106. # # 绘制删除满足 在膨胀框内 && 角度差小于90度 的box
  107. for *xyxy, conf, cls, flag in reversed(joint_det):
  108. if flag == 0:
  109. # label = f'{int(cls)} {conf:.2f}'
  110. label = None
  111. plot_one_box(xyxy, img0, label=label, color=rainbows[int(cls)], line_thickness=2)
  112. # save
  113. mask = det2[..., np.newaxis].repeat(3, 2)
  114. img_seg = 0.3*mask + img0
  115. save_path = os.path.join(outpth, file)
  116. cv2.imwrite(save_path, img_seg)
  117. if __name__ == '__main__':
  118. main()