高速公路违停检测
Nevar pievienot vairāk kā 25 tēmas. Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

125 rindas
4.6KB

  1. #!/usr/bin/python
  2. # -*- encoding: utf-8 -*-
  3. import os
  4. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  5. from logger import setup_logger
  6. from models.model_stages import BiSeNet
  7. from predict_city.heliushuju2 import Heliushuju
  8. import cv2
  9. import torch
  10. import torch.nn as nn
  11. from torch.utils.data import DataLoader
  12. import torch.nn.functional as F
  13. import torch.distributed as dist
  14. import os.path as osp
  15. import logging
  16. import time
  17. import numpy as np
  18. from tqdm import tqdm
  19. import math
  20. import pandas as pd
  21. # from cv2 import getTickCount, getTickFrequency # 原始
  22. class MscEvalV0(object):
  23. def __init__(self, scale=0.5, ignore_label=255):
  24. self.ignore_label = ignore_label
  25. self.scale = scale
  26. def __call__(self, net, dl, n_classes):
  27. ## evaluate
  28. label_info = get_label_info('./class_dict.csv')
  29. hist = torch.zeros(n_classes, n_classes).cuda().detach()
  30. diter = enumerate(tqdm(dl))
  31. for i, (imgs, _, img_tt) in diter:
  32. loop_start = cv2.getTickCount()
  33. #N, _, H, W = label.shape
  34. #label = label.squeeze(1).cuda()
  35. #size = label.size()[-2:]
  36. imgs = imgs.cuda()
  37. N, C, H, W = imgs.size()
  38. new_hw = [int(H*self.scale), int(W*self.scale)]
  39. print(new_hw)
  40. imgs = F.interpolate(imgs, new_hw, mode='bilinear', align_corners=True)
  41. logits = net(imgs)[0]
  42. logits = F.interpolate(logits, size=imgs.size()[-2:],
  43. mode='bilinear', align_corners=True)
  44. probs = torch.softmax(logits, dim=1)
  45. preds = torch.argmax(probs, dim=1)
  46. preds_squeeze = preds.squeeze(0)
  47. preds_squeeze_predict = colour_code_segmentation(np.array(preds_squeeze.cpu()), label_info)
  48. print(preds_squeeze_predict.shape)
  49. preds_squeeze_predict = cv2.resize(np.uint8(preds_squeeze_predict), (W,H))
  50. save_path = './demo/' + img_tt[0] + '.png'
  51. # save_path = '/host/data/ChangeDetection/niushoushanhe_demo/19_/' + img_tt[0] + '.png'
  52. cv2.imwrite(save_path, cv2.cvtColor(np.uint8(preds_squeeze_predict), cv2.COLOR_RGB2BGR))
  53. loop_time = cv2.getTickCount() - loop_start
  54. tool_time = loop_time/(cv2.getTickFrequency())
  55. running_fps = int(1/tool_time)
  56. print('running_fps:', running_fps)
  57. def colour_code_segmentation(image, label_values):
  58. label_values = [label_values[key] for key in label_values]
  59. colour_codes = np.array(label_values)
  60. x = colour_codes[image.astype(int)]
  61. return x
  62. def get_label_info(csv_path):
  63. ann = pd.read_csv(csv_path)
  64. label = {}
  65. for iter, row in ann.iterrows():
  66. label_name = row['name']
  67. r = row['r']
  68. g = row['g']
  69. b = row['b']
  70. label[label_name] = [int(r), int(g), int(b)]
  71. return label
  72. def evaluatev0(respth='', dspth='', backbone='', scale=0.75, use_boundary_2=False,
  73. use_boundary_4=False, use_boundary_8=False, use_boundary_16=False, use_conv_last=False):
  74. print('scale', scale)
  75. ## dataset
  76. batchsize = 1
  77. n_workers = 0
  78. # dsval = Heliushuju(dspth, mode='val') # 原始
  79. dsval = Heliushuju(dspth, mode='test') # 改动
  80. dl = DataLoader(dsval,
  81. batch_size = batchsize,
  82. shuffle = False,
  83. num_workers = n_workers,
  84. drop_last = False)
  85. n_classes = 2####################################################################
  86. print("backbone:", backbone)
  87. net = BiSeNet(backbone=backbone, n_classes=n_classes,
  88. use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
  89. use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
  90. use_conv_last=use_conv_last)
  91. net.load_state_dict(torch.load(respth))
  92. net.cuda()
  93. net.eval()
  94. with torch.no_grad():
  95. single_scale = MscEvalV0(scale=scale)
  96. single_scale(net, dl, 2)
  97. if __name__ == "__main__":
  98. #STDC2-Seg75 mIoU 0.7704
  99. # 原始
  100. # evaluatev0('/host/data/segmentation/Buildings_checkpoints/checkpoints3/wurenji_train_STDC1-Seg/pths/model_maxmIOU75.pth',
  101. # dspth='/host/data/segmentation/Buildings3/images_12/', backbone='STDCNet813', scale=0.75,
  102. # use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)
  103. # 改动
  104. evaluatev0('./checkpoints_1720/wurenji_train_STDC1-Seg/pths/model_maxmIOU75.pth',
  105. dspth='./data/segmentation/shuiyufenge_1720/', backbone='STDCNet813', scale=0.75,
  106. use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)