交通事故检测代码 (Traffic accident detection code)
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

142 lines
6.6KB

  1. from models.model_stages import BiSeNet
  2. from predict_city.heliushuju import Heliushuju
  3. from torch.utils.data import DataLoader
  4. import pandas as pd
  5. import numpy as np
  6. import os
  7. import argparse
  8. import cv2
  9. import torch
  10. import torchvision.transforms as transforms
  11. from trafficDetectionUtils import PostProcessing
  12. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  13. class MscEvalV0(object):
  14. def __init__(self, scaleH=1/3, scaleW=1/3, ignore_label=255):
  15. self.ignore_label = ignore_label
  16. self.scaleH = scaleH
  17. self.scaleW = scaleW
  18. self.to_tensor = transforms.Compose([
  19. transforms.ToTensor(),
  20. transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
  21. ])
  22. def __call__(self, net, dl, n_classes):
  23. # evaluate
  24. label_info = get_label_info('./class_dict.csv')
  25. maskPath = '../trafficDetectionTestData/trafficAccidentTest/masks'
  26. testImagePath = '../trafficDetectionTestData/trafficAccidentTest/images'
  27. File1 = os.listdir(testImagePath)
  28. for file in File1:
  29. print('####beg to :', file)
  30. txtRowContent = []
  31. VehicleCoordinate = []
  32. singleTxtContent = []
  33. txtPath = '../trafficDetectionTestData/trafficAccidentTest/detections' + os.sep + file[:-4] + '.txt'
  34. testImage = testImagePath + os.sep + file
  35. testImageArray = cv2.imread(testImage)
  36. txtContent = open(txtPath, 'r', encoding='utf-8')
  37. content = txtContent.readlines()
  38. for line in content:
  39. if line[0].isnumeric():
  40. e = line.splitlines(False)[0]
  41. f = e.split(',', -1)
  42. if float(f[0]) == 0:
  43. for i in range(len(f)):
  44. txtRowContent.append(float(f[i]))
  45. VehicleCoordinate.append((int(txtRowContent[1]), int(txtRowContent[2])))
  46. VehicleCoordinate.append((int(txtRowContent[3]), int(txtRowContent[4])))
  47. singleTxtContent.append(txtRowContent)
  48. txtRowContent = []
  49. # print("line53, singleTxtContent: ", singleTxtContent)
  50. txtContent.close()
  51. # 字典形式传参数
  52. traffic_dict = {'label_info': label_info, 'RoadArea': 16000, 'roadVehicleAngle': 15,
  53. 'vehicleCOOR': VehicleCoordinate, 'roundness': 0.7, 'cls': 0, 'vehicleFactor': 0.6, 'det': singleTxtContent,
  54. 'modelSize': (640, 360), 'testImageName': file, 'radius': 50, 'distanceFlag': False, 'vehicleFlag': False}
  55. save_path = './demo/' + file
  56. mask = cv2.imread(maskPath + os.sep + file[:-4] + '_mask.png')
  57. mask = mask[:, :, 0]
  58. targetList, time_infos = PostProcessing(mask, testImageArray, traffic_dict, file)
  59. # print("line64", time_infos, mask.shape, traffic_dict)
  60. print("line64", time_infos)
  61. # print("line65", targetList)
  62. # 在测试图片上画出检测框
  63. for i in range(len(targetList)):
  64. if targetList[i][10] != 3:
  65. X1 = targetList[i][1]
  66. Y1 = targetList[i][2]
  67. X2 = targetList[i][3]
  68. Y2 = targetList[i][4]
  69. cv2.rectangle(testImageArray, (int(X1), int(Y1)), (int(X2), int(Y2)), (0, 0, 255), thickness=3, lineType=cv2.LINE_AA)
  70. font = cv2.FONT_HERSHEY_SIMPLEX
  71. cv2.putText(testImageArray, str(format(targetList[i][9], ".2f")), (int(X1) + 4, int(Y1 - 1)), font, 1, (0, 255, 0), 2, cv2.LINE_AA)
  72. cv2.imwrite(save_path, testImageArray)
  73. def get_label_info(csv_path):
  74. ann = pd.read_csv(csv_path)
  75. label = {}
  76. for iter, row in ann.iterrows():
  77. label_name = row['name']
  78. r = row['r']
  79. g = row['g']
  80. b = row['b']
  81. label[label_name] = [int(r), int(g), int(b)]
  82. return label
  83. def evaluatev0(respth='', dspth='', backbone='', scaleH=1/3, scaleW=1/3,use_boundary_2=False,
  84. use_boundary_4=False, use_boundary_8=False, use_boundary_16=False, use_conv_last=False):
  85. # dataset
  86. batchsize = 1
  87. n_workers = 0
  88. dsval = Heliushuju(dspth, mode='test')
  89. dl = DataLoader(dsval,
  90. batch_size=batchsize,
  91. shuffle=False,
  92. num_workers=n_workers,
  93. drop_last=False)
  94. n_classes = 3
  95. # print("backbone:", backbone)
  96. net = BiSeNet(backbone=backbone, n_classes=n_classes,
  97. use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
  98. use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
  99. use_conv_last=use_conv_last)
  100. net.load_state_dict(torch.load(respth))
  101. net.cuda()
  102. net.eval()
  103. with torch.no_grad():
  104. single_scale = MscEvalV0(scaleH=scaleH, scaleW=scaleW)
  105. single_scale(net, dl, 3)
  106. if __name__ == "__main__":
  107. parser = argparse.ArgumentParser()
  108. parser.add_argument('--weights', nargs='+', type=str, default='./model_save/pths/best.pt', help='model.pt path(s)')
  109. parser.add_argument('--source', type=str, default='./data/test/images', help='source') # file/folder, 0 for webcam
  110. parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
  111. parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
  112. parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
  113. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  114. parser.add_argument('--view-img', action='store_true', help='display results')
  115. parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
  116. parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
  117. parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
  118. parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
  119. parser.add_argument('--augment', action='store_true', help='augmented inference')
  120. parser.add_argument('--update', action='store_true', help='update all models')
  121. opt = parser.parse_args()
  122. # stdc_360X640_highWay.pth model_final.pth
  123. evaluatev0(respth='./model_save/pths/stdc_360X640_highWay.pth', dspth='../trafficDetectionTestData/trafficAccidentTest/masks', backbone='STDCNet813', scaleH=1 / 3,
  124. scaleW=1 / 3, use_boundary_2=False, use_boundary_4=False, use_boundary_8=False,
  125. use_boundary_16=False, use_conv_last=False)