高速公路违停检测
您最多可以选择25个主题。主题必须以字母或数字开头，可以包含连字符 (-)，并且长度不得超过35个字符。

150 行
5.3KB

  1. #!/usr/bin/python
  2. # -*- encoding: utf-8 -*-
  3. import os
  4. os.environ['CUDA_VISIBLE_DEVICES'] = '1'
  5. from logger import setup_logger
  6. from models.model_stages import BiSeNet
  7. from predict_city.heliushuju import Heliushuju
  8. import cv2
  9. import sys
  10. import torch
  11. import torch.nn as nn
  12. from torch.utils.data import DataLoader
  13. import torch.nn.functional as F
  14. import torch.distributed as dist
  15. import os.path as osp
  16. import logging
  17. import time
  18. import numpy as np
  19. from tqdm import tqdm
  20. import math
  21. import pandas as pd
  22. import matplotlib.pyplot as plt
  23. # from cv2 import getTickCount, getTickFrequency
  24. class MscEvalV0(object):
  25. def __init__(self, scale=0.5, ignore_label=255):
  26. self.ignore_label = ignore_label
  27. self.scale = scale
  28. def __call__(self, net, dl, n_classes):
  29. # evaluate
  30. label_info = get_label_info('./class_dict.csv')
  31. hist = torch.zeros(n_classes, n_classes).cuda().detach()
  32. diter = enumerate(tqdm(dl))
  33. for i, (imgs, label, img_tt) in diter:
  34. loop_start = cv2.getTickCount()
  35. N, _, H, W = label.shape
  36. label = label.squeeze(1).cuda()
  37. size = label.size()[-2:]
  38. imgs = imgs.cuda()
  39. N, C, H, W = imgs.size()
  40. new_hw = [int(H*self.scale), int(W*self.scale)]
  41. print(new_hw)
  42. imgs = F.interpolate(imgs, new_hw, mode='bilinear', align_corners=True)
  43. logits = net(imgs)[0]
  44. loop_time = cv2.getTickCount() - loop_start
  45. tool_time = loop_time/(cv2.getTickFrequency())
  46. running_fps = int(1/tool_time)
  47. print('running_fps:', running_fps)
  48. logits = F.interpolate(logits, size=size,
  49. mode='bilinear', align_corners=True)
  50. probs = torch.softmax(logits, dim=1)
  51. preds = torch.argmax(probs, dim=1)
  52. preds_squeeze = preds.squeeze(0)
  53. preds_squeeze_predict = colour_code_segmentation(np.array(preds_squeeze.cpu()), label_info)
  54. print(preds_squeeze_predict.shape)
  55. preds_squeeze_predict = cv2.resize(np.uint(preds_squeeze_predict), (W, H))
  56. save_path = './demo/' + img_tt[0] + '.png'
  57. cv2.imwrite(save_path, cv2.cvtColor(np.uint8(preds_squeeze_predict), cv2.COLOR_RGB2BGR))
  58. # preds_squeeze_predict = preds_squeeze.cpu().numpy().copy()
  59. # plt.imshow(preds_squeeze_predict) ;plt.show()
  60. # preds_3chs = np.zeros( (*preds_squeeze_predict.shape,3 ))
  61. # preds_3chs[...,0]=preds_squeeze_predict.copy()
  62. # preds_3chs[...,1]=preds_squeeze_predict.copy()
  63. # preds_3chs[...,2]=preds_squeeze_predict.copy()
  64. # preds_3chs = (preds_3chs*255).astype(np.uint8)
  65. #
  66. # # print('####line66',preds_squeeze_predict.shape)
  67. # preds_squeeze_predict = cv2.resize(np.uint(preds_squeeze_predict), (W,H))
  68. # save_path = './demo/' + img_tt[0] + '.png'
  69. # cv2.imwrite(save_path, cv2.cvtColor(np.uint8(preds_squeeze_predict), cv2.COLOR_RGB2BGR))
  70. # print('#####DEBUG#####')
  71. # sys.exit(0)
  72. def colour_code_segmentation(image, label_values):
  73. label_values = [label_values[key] for key in label_values]
  74. colour_codes = np.array(label_values)
  75. x = colour_codes[image.astype(int)]
  76. return x
  77. def get_label_info(csv_path):
  78. ann = pd.read_csv(csv_path)
  79. label = {}
  80. for iter, row in ann.iterrows():
  81. label_name = row['name']
  82. r = row['r']
  83. g = row['g']
  84. b = row['b']
  85. label[label_name] = [int(r), int(g), int(b)]
  86. return label
  87. def evaluatev0(respth='', dspth='', backbone='', scale=0.75, use_boundary_2=False,
  88. use_boundary_4=False, use_boundary_8=False, use_boundary_16=False, use_conv_last=False):
  89. print('scale', scale)
  90. ## dataset
  91. batchsize = 1
  92. n_workers = 0
  93. # dsval = Heliushuju(dspth, mode='val') # 原始
  94. dsval = Heliushuju(dspth, mode='test') # 改动
  95. dl = DataLoader(dsval,
  96. batch_size = batchsize,
  97. shuffle = False,
  98. num_workers = n_workers,
  99. drop_last = False)
  100. n_classes = 3
  101. print("backbone:", backbone)
  102. net = BiSeNet(backbone=backbone, n_classes=n_classes,
  103. use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
  104. use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
  105. use_conv_last=use_conv_last)
  106. net.load_state_dict(torch.load(respth))
  107. net.cuda()
  108. net.eval()
  109. with torch.no_grad():
  110. single_scale = MscEvalV0(scale=scale)
  111. single_scale(net, dl, 2)
  112. if __name__ == "__main__":
  113. # STDC2-Seg75 mIoU 0.7704
  114. # 原始
  115. # evaluatev0('/host/data/segmentation/Buildings_checkpoints/checkpoints3/wurenji_train_STDC1-Seg/pths/model_maxmIOU75.pth',
  116. # dspth='/host/data/segmentation/Buildings2/images_12/', backbone='STDCNet813', scale=0.75,
  117. # use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)
  118. # 改动
  119. evaluatev0('./checkpoints_1720/wurenji_train_STDC1-Seg/pths/model_final.pth',
  120. dspth='./data/segmentation/shuiyufenge_1720/', backbone='STDCNet813', scale=0.75,
  121. use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)