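# Highway illegal-parking detection
# (Residue from the repository web viewer: "You may select up to 25 topics; a topic must start with a letter or digit, may contain hyphens (-), and must not exceed 35 characters.")
#
# 156 lines
# 7.7KB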

  1. # 最新版违停检测代码
  2. from models.model_stages import BiSeNet
  3. from predict_city.heliushuju import Heliushuju
  4. from torch.utils.data import DataLoader
  5. import numpy as np
  6. import os
  7. import argparse
  8. import cv2
  9. import torch
  10. import torchvision.transforms as transforms
  11. import matplotlib.pyplot as plt
  12. # from complexIllegalParkingUtilsNewest import mixNoParking_road_postprocess
  13. from complexIllegalParkingUtilsNewest import mixNoParking_road_postprocess
  14. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  15. # print("line15", torch.cuda.is_available())
  16. class MscEvalV0(object):
  17. def __init__(self, scaleH=1 / 3, scaleW=1 / 3, ignore_label=255):
  18. self.ignore_label = ignore_label
  19. self.scaleH = scaleH
  20. self.scaleW = scaleW
  21. self.to_tensor = transforms.Compose([
  22. transforms.ToTensor(),
  23. transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  24. ])
  25. # IllegalParkingTestData
  26. def __call__(self, net, dl, n_classes):
  27. # evaluate
  28. maskPath = '../IllegalParkingTestData/masks'
  29. testImagePath = '../IllegalParkingTestData/images'
  30. File1 = os.listdir(testImagePath)
  31. for file in File1:
  32. print('####beg to :', file)
  33. txtRowContent = []
  34. saveVehicleCoordinate = []
  35. singleTxtContent = []
  36. txtPath = '../IllegalParkingTestData/detections' + os.sep + file[:-4] + '.txt'
  37. testImage = testImagePath + os.sep + file
  38. testImageArray = cv2.imread(testImage)
  39. txtContent = open(txtPath, 'r', encoding='utf-8')
  40. content = txtContent.readlines()
  41. for line in content:
  42. if line[0].isnumeric():
  43. e = line.splitlines(False)[0]
  44. f = e.split(',', -1)
  45. if float(f[-1]) == 0:
  46. for i in range(len(f)):
  47. txtRowContent.append(float(f[i]))
  48. saveVehicleCoordinate.append((int(txtRowContent[1]), int(txtRowContent[2])))
  49. saveVehicleCoordinate.append((int(txtRowContent[3]), int(txtRowContent[4])))
  50. singleTxtContent.append(txtRowContent)
  51. txtRowContent = []
  52. txtContent.close()
  53. # print("line57, singleTxtContent: ", singleTxtContent)
  54. mask = cv2.imread(maskPath + os.sep + file[:-4] + '_mask.png')
  55. imgName = file[:-4] + '.png'
  56. mask = mask[:, :, 0]
  57. # 字典形式传参数
  58. traffic_dict = {'RoadArea': 16000, 'roundness': 0.5, 'laneArea': 2, 'modelSize': (1920, 1080), 'testImageName': file, 'fitOrder': 2}
  59. # print('####line63: det results ', singleTxtContent, mask.shape, np.max(mask),np.min(mask))
  60. save_path = './demo/' + file
  61. # targetList, time_infos, finalLane, lane_line, abc = mixNoParking_road_postprocess(singleTxtContent, mask, traffic_dict)
  62. # targetList, time_infos = mixNoParking_road_postprocess(singleTxtContent, mask, traffic_dict, imgName)
  63. targetList, time_infos = mixNoParking_road_postprocess(singleTxtContent, mask, traffic_dict)
  64. print('####line66:', time_infos)
  65. # print("line65", targetList)
  66. """在测试图片上画出检测框"""
  67. for i in range(len(targetList)):
  68. if targetList[i][7] != 0:
  69. X1 = targetList[i][0]
  70. Y1 = targetList[i][1]
  71. X2 = targetList[i][2]
  72. Y2 = targetList[i][3]
  73. cv2.rectangle(testImageArray, (int(X1), int(Y1)), (int(X2), int(Y2)), (0, 0, 255), thickness=3,
  74. lineType=cv2.LINE_AA)
  75. font = cv2.FONT_HERSHEY_SIMPLEX
  76. cv2.putText(testImageArray, str(format(targetList[i][6], ".2f")), (int(X1) + 4, int(Y1 - 1)), font, 1,
  77. (0, 255, 0), 2, cv2.LINE_AA)
  78. cv2.imwrite(save_path, testImageArray)
  79. # """分别将最左侧和最右侧车道线簇中的点连起来,并显示"""
  80. # for k in range(len(finalLane)):
  81. # for i in range(len(finalLane[k])):
  82. # if i + 1 <= len(finalLane[k]) - 1:
  83. # cv2.line(lane_line, (int(finalLane[k][i][0]), int(finalLane[k][i][1])),
  84. # (int(finalLane[k][i + 1][0]), int(finalLane[k][i + 1][1])), (0, 0, 255), thickness=2,
  85. # lineType=cv2.LINE_AA)
  86. # else:
  87. # break
  88. # cv2.imwrite('./demo/' + 'realLane_' + '{}'.format(file[:-4]) + '.png', lane_line)
  89. # """分别将最左侧和最右侧车道线簇拟合的二次曲线画出来"""
  90. # y = np.array(list(range(0, 1080)))
  91. # x1 = abc[0] * (y ** 2) + abc[1] * y + abc[2]
  92. # x2 = abc[3] * (y ** 2) + abc[4] * y + abc[5]
  93. # plt.plot(x1, y);
  94. # plt.plot(x2, y);
  95. # plt.imshow(lane_line)
  96. # plt.savefig('./demo/' + 'fitLane_' + '{}'.format(file[:-4]) + '.png')
  97. # plt.show()
  98. def evaluatev0(respth='', dspth='', backbone='', scaleH=1 / 3, scaleW=1 / 3, use_boundary_2=False,
  99. use_boundary_4=False, use_boundary_8=False, use_boundary_16=False, use_conv_last=False):
  100. # dataset
  101. batchsize = 1
  102. n_workers = 0
  103. dsval = Heliushuju(dspth, mode='test')
  104. dl = DataLoader(dsval,
  105. batch_size=batchsize,
  106. shuffle=False,
  107. num_workers=n_workers,
  108. drop_last=False)
  109. n_classes = 4
  110. # print("backbone:", backbone)
  111. net = BiSeNet(backbone=backbone, n_classes=n_classes,
  112. use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
  113. use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
  114. use_conv_last=use_conv_last)
  115. net.load_state_dict(torch.load(respth))
  116. net.cuda()
  117. net.eval()
  118. with torch.no_grad():
  119. single_scale = MscEvalV0(scaleH=scaleH, scaleW=scaleW)
  120. single_scale(net, dl, 4)
  121. if __name__ == "__main__":
  122. parser = argparse.ArgumentParser()
  123. parser.add_argument('--weights', nargs='+', type=str, default='./model_save/pths/best.pt', help='model.pt path(s)')
  124. parser.add_argument('--source', type=str, default='./data/test/images', help='source') # file/folder, 0 for webcam
  125. parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
  126. parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
  127. parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
  128. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  129. parser.add_argument('--view-img', action='store_true', help='display results')
  130. parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
  131. parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
  132. parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
  133. parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
  134. parser.add_argument('--augment', action='store_true', help='augmented inference')
  135. parser.add_argument('--update', action='store_true', help='update all models')
  136. opt = parser.parse_args()
  137. evaluatev0(respth='./model_save/pths/stdc_360X640_highWayParking.pth',
  138. dspth='/home/thsw/WJ/zyy/IllegalParkingTestData/masks', backbone='STDCNet813',
  139. scaleH=1 / 3,
  140. scaleW=1 / 3, use_boundary_2=False, use_boundary_4=False, use_boundary_8=False,
  141. use_boundary_16=False, use_conv_last=False)