高速公路违停检测
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

125 lines
4.6KB

  1. #!/usr/bin/python
  2. # -*- encoding: utf-8 -*-
  3. import os
  4. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  5. from logger import setup_logger
  6. from models.model_stages import BiSeNet
  7. from predict_city.heliushuju2 import Heliushuju
  8. import cv2
  9. import torch
  10. import torch.nn as nn
  11. from torch.utils.data import DataLoader
  12. import torch.nn.functional as F
  13. import torch.distributed as dist
  14. import os.path as osp
  15. import logging
  16. import time
  17. import numpy as np
  18. from tqdm import tqdm
  19. import math
  20. import pandas as pd
  21. # from cv2 import getTickCount, getTickFrequency # 原始
  22. class MscEvalV0(object):
  23. def __init__(self, scale=0.5, ignore_label=255):
  24. self.ignore_label = ignore_label
  25. self.scale = scale
  26. def __call__(self, net, dl, n_classes):
  27. ## evaluate
  28. label_info = get_label_info('./class_dict.csv')
  29. hist = torch.zeros(n_classes, n_classes).cuda().detach()
  30. diter = enumerate(tqdm(dl))
  31. for i, (imgs, _, img_tt) in diter:
  32. loop_start = cv2.getTickCount()
  33. #N, _, H, W = label.shape
  34. #label = label.squeeze(1).cuda()
  35. #size = label.size()[-2:]
  36. imgs = imgs.cuda()
  37. N, C, H, W = imgs.size()
  38. new_hw = [int(H*self.scale), int(W*self.scale)]
  39. print(new_hw)
  40. imgs = F.interpolate(imgs, new_hw, mode='bilinear', align_corners=True)
  41. logits = net(imgs)[0]
  42. logits = F.interpolate(logits, size=imgs.size()[-2:],
  43. mode='bilinear', align_corners=True)
  44. probs = torch.softmax(logits, dim=1)
  45. preds = torch.argmax(probs, dim=1)
  46. preds_squeeze = preds.squeeze(0)
  47. preds_squeeze_predict = colour_code_segmentation(np.array(preds_squeeze.cpu()), label_info)
  48. print(preds_squeeze_predict.shape)
  49. preds_squeeze_predict = cv2.resize(np.uint8(preds_squeeze_predict), (W,H))
  50. save_path = './demo/' + img_tt[0] + '.png'
  51. # save_path = '/host/data/ChangeDetection/niushoushanhe_demo/19_/' + img_tt[0] + '.png'
  52. cv2.imwrite(save_path, cv2.cvtColor(np.uint8(preds_squeeze_predict), cv2.COLOR_RGB2BGR))
  53. loop_time = cv2.getTickCount() - loop_start
  54. tool_time = loop_time/(cv2.getTickFrequency())
  55. running_fps = int(1/tool_time)
  56. print('running_fps:', running_fps)
  57. def colour_code_segmentation(image, label_values):
  58. label_values = [label_values[key] for key in label_values]
  59. colour_codes = np.array(label_values)
  60. x = colour_codes[image.astype(int)]
  61. return x
  62. def get_label_info(csv_path):
  63. ann = pd.read_csv(csv_path)
  64. label = {}
  65. for iter, row in ann.iterrows():
  66. label_name = row['name']
  67. r = row['r']
  68. g = row['g']
  69. b = row['b']
  70. label[label_name] = [int(r), int(g), int(b)]
  71. return label
  72. def evaluatev0(respth='', dspth='', backbone='', scale=0.75, use_boundary_2=False,
  73. use_boundary_4=False, use_boundary_8=False, use_boundary_16=False, use_conv_last=False):
  74. print('scale', scale)
  75. ## dataset
  76. batchsize = 1
  77. n_workers = 0
  78. # dsval = Heliushuju(dspth, mode='val') # 原始
  79. dsval = Heliushuju(dspth, mode='test') # 改动
  80. dl = DataLoader(dsval,
  81. batch_size = batchsize,
  82. shuffle = False,
  83. num_workers = n_workers,
  84. drop_last = False)
  85. n_classes = 2####################################################################
  86. print("backbone:", backbone)
  87. net = BiSeNet(backbone=backbone, n_classes=n_classes,
  88. use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
  89. use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
  90. use_conv_last=use_conv_last)
  91. net.load_state_dict(torch.load(respth))
  92. net.cuda()
  93. net.eval()
  94. with torch.no_grad():
  95. single_scale = MscEvalV0(scale=scale)
  96. single_scale(net, dl, 2)
  97. if __name__ == "__main__":
  98. #STDC2-Seg75 mIoU 0.7704
  99. # 原始
  100. # evaluatev0('/host/data/segmentation/Buildings_checkpoints/checkpoints3/wurenji_train_STDC1-Seg/pths/model_maxmIOU75.pth',
  101. # dspth='/host/data/segmentation/Buildings3/images_12/', backbone='STDCNet813', scale=0.75,
  102. # use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)
  103. # 改动
  104. evaluatev0('./checkpoints_1720/wurenji_train_STDC1-Seg/pths/model_maxmIOU75.pth',
  105. dspth='./data/segmentation/shuiyufenge_1720/', backbone='STDCNet813', scale=0.75,
  106. use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)