Traffic accident detection code

#!/usr/bin/python
# -*- encoding: utf-8 -*-
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from logger import setup_logger
from models.model_stages import BiSeNet
from predict_city.heliushuju2 import Heliushuju
import cv2
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.distributed as dist
import os.path as osp
import logging
import time
import numpy as np
from tqdm import tqdm
import math
import pandas as pd
# from cv2 import getTickCount, getTickFrequency


class MscEvalV0(object):
    """Single-scale inference: segment each frame, outline the predicted
    region with contours and save the visualisation to ./demo/."""

    def __init__(self, scale=0.5, ignore_label=255):
        self.ignore_label = ignore_label
        self.scale = scale

    def __call__(self, net, dl, n_classes):
        ## evaluate
        label_info = get_label_info('./class_dict.csv')
        hist = torch.zeros(n_classes, n_classes).cuda().detach()
        diter = enumerate(tqdm(dl))
        for i, (imgs, img_n, img_tt) in diter:
            loop_start = cv2.getTickCount()
            time0 = time.time()
            # image returned by the dataloader, used as the canvas for the contour overlay
            img_n = img_n[0]
            img_n = np.array(img_n.cpu())
            # N, _, H, W = label.shape
            # label = label.squeeze(1).cuda()
            # size = label.size()[-2:]
            imgs = imgs.cuda()
            N, C, H, W = imgs.size()
            # run the network at a reduced resolution (scale * original size)
            new_hw = [int(H * self.scale), int(W * self.scale)]
            print(new_hw)
            imgs = F.interpolate(imgs, new_hw, mode='bilinear', align_corners=True)
            logits = net(imgs)[0]
            time1 = time.time()
            tm = (time1 - time0) * 1000
            print('## process: %.1f ms' % tm)
            loop_time = cv2.getTickCount() - loop_start
            tool_time = loop_time / cv2.getTickFrequency()
            running_fps = int(1 / tool_time)
            print('running_fps:', running_fps)
            # upsample the logits back to the original resolution
            logits = F.interpolate(logits, size=(H, W),
                                   mode='bilinear', align_corners=True)
            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)
            # binarise the prediction: every non-background pixel -> 255
            preds_squeeze = preds.squeeze(0)
            preds_squeeze[preds_squeeze != 0] = 255
            preds_squeeze = np.array(preds_squeeze.cpu())
            preds_squeeze = np.uint8(preds_squeeze)
            # extract the contours of the predicted region and draw them onto the frame
            _, binary = cv2.threshold(preds_squeeze, 220, 255, cv2.THRESH_BINARY)
            contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
            img2 = cv2.drawContours(img_n, contours, -1, (0, 0, 255), 8)
            save_path = './demo/' + img_tt[0] + '.png'
            cv2.imwrite(save_path, img2)
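

# Note: `hist` in MscEvalV0.__call__ above is allocated but never updated, so
# this script only saves contour visualisations and does not compute the mIoU
# mentioned in the __main__ comment. A minimal sketch of how the confusion
# matrix and mIoU are commonly accumulated, assuming ground-truth `label`
# tensors with the same shape as `preds`; this helper is hypothetical, not
# part of the original script, and is never called here:
def _miou_sketch(hist, preds, label, n_classes, ignore_label=255):
    keep = label != ignore_label
    hist += torch.bincount(
        label[keep] * n_classes + preds[keep],
        minlength=n_classes ** 2,
    ).view(n_classes, n_classes)
    ious = hist.diag() / (hist.sum(dim=0) + hist.sum(dim=1) - hist.diag())
    return hist, ious.mean().item()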


def colour_code_segmentation(image, label_values):
    # Map each class index in `image` to its RGB colour from the label dict.
    label_values = [label_values[key] for key in label_values]
    colour_codes = np.array(label_values)
    x = colour_codes[image.astype(int)]
    return x


def get_label_info(csv_path):
    # Parse class_dict.csv into {class_name: [r, g, b]}.
    ann = pd.read_csv(csv_path)
    label = {}
    for iter, row in ann.iterrows():
        label_name = row['name']
        r = row['r']
        g = row['g']
        b = row['b']
        label[label_name] = [int(r), int(g), int(b)]
    return label
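

# get_label_info() expects ./class_dict.csv to contain one row per class with
# 'name', 'r', 'g' and 'b' columns. A minimal sketch for the two-class setup
# used in evaluatev0() below -- the class names here are assumptions, not
# values taken from the original dataset:
#
#     name,r,g,b
#     background,0,0,0
#     accident,255,0,0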


def evaluatev0(respth='', dspth='', backbone='', scale=0.75, use_boundary_2=False,
               use_boundary_4=False, use_boundary_8=False, use_boundary_16=False,
               use_conv_last=False):
    print('scale', scale)
    ## dataset
    batchsize = 1
    n_workers = 0
    dsval = Heliushuju(dspth, mode='val')
    dl = DataLoader(dsval,
                    batch_size=batchsize,
                    shuffle=False,
                    num_workers=n_workers,
                    drop_last=False)
    n_classes = 2  # binary segmentation: background vs. target region
    print("backbone:", backbone)
    ## model
    net = BiSeNet(backbone=backbone, n_classes=n_classes,
                  use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
                  use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
                  use_conv_last=use_conv_last)
    net.load_state_dict(torch.load(respth))
    net.cuda()
    net.eval()
    with torch.no_grad():
        single_scale = MscEvalV0(scale=scale)
        single_scale(net, dl, 2)


if __name__ == "__main__":
    # STDC2-Seg75 mIoU 0.7704
    evaluatev0('./checkpoints_1648/wurenji_train_STDC1-Seg/pths/model_maxmIOU75_0.934_1024.pth',
               dspth='/home/data/lijiwen/wurenjiqifei/images/', backbone='STDCNet813', scale=0.75,
               use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)