Traffic accident detection code

#!/usr/bin/python
# -*- encoding: utf-8 -*-
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
from logger import setup_logger
from models.model_stages import BiSeNet
from predict_city.heliushuju import Heliushuju
import cv2
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.distributed as dist
import os.path as osp
import logging
import time
import numpy as np
from tqdm import tqdm
import math
import pandas as pd
import matplotlib.pyplot as plt
# from cv2 import getTickCount, getTickFrequency

class MscEvalV0(object):
    """Single-scale evaluator: runs inference at `scale` and saves colour-coded masks to ./demo/."""

    def __init__(self, scale=0.5, ignore_label=255):
        self.ignore_label = ignore_label
        self.scale = scale

    def __call__(self, net, dl, n_classes):
        # evaluate
        label_info = get_label_info('./class_dict.csv')
        hist = torch.zeros(n_classes, n_classes).cuda().detach()
        diter = enumerate(tqdm(dl))
        for i, (imgs, label, img_tt) in diter:
            loop_start = cv2.getTickCount()
            N, _, H, W = label.shape
            label = label.squeeze(1).cuda()
            size = label.size()[-2:]
            imgs = imgs.cuda()
            N, C, H, W = imgs.size()
            # downscale the input, run the network, then upsample the logits back to label size
            new_hw = [int(H * self.scale), int(W * self.scale)]
            print(new_hw)
            imgs = F.interpolate(imgs, new_hw, mode='bilinear', align_corners=True)
            logits = net(imgs)[0]
            loop_time = cv2.getTickCount() - loop_start
            tool_time = loop_time / cv2.getTickFrequency()
            running_fps = int(1 / tool_time)
            print('running_fps:', running_fps)
            logits = F.interpolate(logits, size=size,
                                   mode='bilinear', align_corners=True)
            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)
            preds_squeeze = preds.squeeze(0)
            # colour-code the prediction and write it out as an RGB image
            preds_squeeze_predict = colour_code_segmentation(np.array(preds_squeeze.cpu()), label_info)
            print(preds_squeeze_predict.shape)
            preds_squeeze_predict = cv2.resize(np.uint8(preds_squeeze_predict), (W, H))
            save_path = './demo/' + img_tt[0] + '.png'
            cv2.imwrite(save_path, cv2.cvtColor(np.uint8(preds_squeeze_predict), cv2.COLOR_RGB2BGR))
            # preds_squeeze_predict = preds_squeeze.cpu().numpy().copy()
            # plt.imshow(preds_squeeze_predict); plt.show()
            # preds_3chs = np.zeros((*preds_squeeze_predict.shape, 3))
            # preds_3chs[..., 0] = preds_squeeze_predict.copy()
            # preds_3chs[..., 1] = preds_squeeze_predict.copy()
            # preds_3chs[..., 2] = preds_squeeze_predict.copy()
            # preds_3chs = (preds_3chs * 255).astype(np.uint8)
            #
            # # print('####line66', preds_squeeze_predict.shape)
            # preds_squeeze_predict = cv2.resize(np.uint(preds_squeeze_predict), (W, H))
            # save_path = './demo/' + img_tt[0] + '.png'
            # cv2.imwrite(save_path, cv2.cvtColor(np.uint8(preds_squeeze_predict), cv2.COLOR_RGB2BGR))
            # print('#####DEBUG#####')
            # sys.exit(0)

def colour_code_segmentation(image, label_values):
    # map each class index in `image` to its RGB colour from `label_values`
    label_values = [label_values[key] for key in label_values]
    colour_codes = np.array(label_values)
    x = colour_codes[image.astype(int)]
    return x


def get_label_info(csv_path):
    # read the class name -> [R, G, B] mapping from a CSV with columns name, r, g, b
    ann = pd.read_csv(csv_path)
    label = {}
    for iter, row in ann.iterrows():
        label_name = row['name']
        r = row['r']
        g = row['g']
        b = row['b']
        label[label_name] = [int(r), int(g), int(b)]
    return label
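
# For reference, get_label_info() expects a CSV shaped like the example below.
# The class names and colours here are only an illustrative guess, not the
# project's actual ./class_dict.csv:
#
#     name,r,g,b
#     background,0,0,0
#     road,128,64,128
#     accident,220,20,60
#
# colour_code_segmentation() then maps every pixel predicted as class 1 to the
# second row's colour [128, 64, 128]; the CSV row order defines the class index.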
  87. def evaluatev0(respth='', dspth='', backbone='', scale=0.75, use_boundary_2=False,
  88. use_boundary_4=False, use_boundary_8=False, use_boundary_16=False, use_conv_last=False):
  89. print('scale', scale)
  90. ## dataset
  91. batchsize = 1
  92. n_workers = 0
  93. # dsval = Heliushuju(dspth, mode='val') # 原始
  94. dsval = Heliushuju(dspth, mode='test') # 改动
  95. dl = DataLoader(dsval,
  96. batch_size = batchsize,
  97. shuffle = False,
  98. num_workers = n_workers,
  99. drop_last = False)
  100. n_classes = 3
  101. print("backbone:", backbone)
  102. net = BiSeNet(backbone=backbone, n_classes=n_classes,
  103. use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
  104. use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
  105. use_conv_last=use_conv_last)
  106. net.load_state_dict(torch.load(respth))
  107. net.cuda()
  108. net.eval()
  109. with torch.no_grad():
  110. single_scale = MscEvalV0(scale=scale)
  111. single_scale(net, dl, 2)

if __name__ == "__main__":
    # STDC2-Seg75 mIoU 0.7704
    # original
    # evaluatev0('/host/data/segmentation/Buildings_checkpoints/checkpoints3/wurenji_train_STDC1-Seg/pths/model_maxmIOU75.pth',
    #            dspth='/host/data/segmentation/Buildings2/images_12/', backbone='STDCNet813', scale=0.75,
    #            use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)
    # changed
    evaluatev0('./checkpoints_1720/wurenji_train_STDC1-Seg/pths/model_final.pth',
               dspth='./data/segmentation/shuiyufenge_1720/', backbone='STDCNet813', scale=0.75,
               use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)
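
Note that the hist tensor allocated in MscEvalV0.__call__ is never updated, so the script only exports colour-coded masks and prints a rough FPS figure; it does not report mIoU. If the dataloader provides ground-truth labels, per-class IoU could be accumulated from a confusion matrix roughly as sketched below. This is a minimal sketch, not part of the repository: the helper names update_hist and compute_miou are invented here, and it assumes label and preds are same-sized integer tensors on the same device, with ignore_label pixels to be skipped.

import torch

def update_hist(hist, label, pred, n_classes, ignore_label=255):
    # accumulate a confusion matrix: rows = ground truth, columns = prediction
    keep = (label != ignore_label)
    idx = label[keep].long() * n_classes + pred[keep].long()
    hist += torch.bincount(idx, minlength=n_classes ** 2).reshape(n_classes, n_classes)
    return hist

def compute_miou(hist):
    # per-class IoU = TP / (TP + FP + FN); mIoU = mean over classes
    tp = hist.diag()
    ious = tp / (hist.sum(dim=0) + hist.sum(dim=1) - tp)
    return ious, ious.mean()

Inside the evaluation loop this would be called as hist = update_hist(hist, label, preds, n_classes) once preds is computed, and compute_miou(hist) after the loop, which is the usual way confusion-matrix-based segmentation evaluators report mIoU.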