交通事故检测代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

316 lines
11KB

  1. #!/usr/bin/python
  2. # -*- encoding: utf-8 -*-
  3. import os
  4. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  5. from logger import setup_logger
  6. from models.model_stages import BiSeNet
  7. from heliushuju import Heliushuju
  8. import cv2
  9. import torch
  10. import torch.nn as nn
  11. from torch.utils.data import DataLoader
  12. import torch.nn.functional as F
  13. import torch.distributed as dist
  14. import os.path as osp
  15. import logging
  16. import time
  17. import numpy as np
  18. from tqdm import tqdm
  19. import math
  20. import pandas as pd
  21. class MscEvalV0(object):
  22. def __init__(self, scale=0.5, ignore_label=255):
  23. self.ignore_label = ignore_label
  24. self.scale = scale
  25. def __call__(self, net, dl, n_classes):
  26. ## evaluate
  27. label_info = get_label_info('./class_dict.csv')
  28. hist = torch.zeros(n_classes, n_classes).cuda().detach()
  29. if dist.is_initialized() and dist.get_rank() != 0:
  30. diter = enumerate(dl)
  31. else:
  32. diter = enumerate(tqdm(dl))
  33. for i, (imgs, label) in diter:
  34. # label = torch.argmax(label, dim=4) # 添加
  35. N, _, H, W = label.shape
  36. label = label.squeeze(1).cuda()
  37. size = label.size()[-2:]
  38. imgs = imgs.cuda()
  39. N, C, H, W = imgs.size()
  40. new_hw = [int(H*self.scale), int(W*self.scale)]
  41. imgs = F.interpolate(imgs, new_hw, mode='bilinear', align_corners=True)
  42. logits = net(imgs)[0]
  43. logits = F.interpolate(logits, size=size,
  44. mode='bilinear', align_corners=True)
  45. probs = torch.softmax(logits, dim=1)
  46. preds = torch.argmax(probs, dim=1)
  47. preds_squeeze = preds.squeeze(0)
  48. preds_squeeze_predict = colour_code_segmentation(np.array(preds_squeeze.cpu()), label_info)
  49. preds_squeeze_predict = cv2.resize(np.uint(preds_squeeze_predict), (W,H))
  50. save_path = './demo/predict%d.png' % i
  51. cv2.imwrite(save_path, cv2.cvtColor(np.uint8(preds_squeeze_predict), cv2.COLOR_RGB2BGR))
  52. keep = label != self.ignore_label
  53. hist += torch.bincount(
  54. label[keep] * n_classes + preds[keep],
  55. minlength=n_classes ** 2
  56. ).view(n_classes, n_classes).float()
  57. if dist.is_initialized():
  58. dist.all_reduce(hist, dist.ReduceOp.SUM)
  59. ious = hist.diag() / (hist.sum(dim=0) + hist.sum(dim=1) - hist.diag())
  60. miou = ious.mean()
  61. return miou.item()
  62. def colour_code_segmentation(image, label_values):
  63. label_values = [label_values[key] for key in label_values]
  64. colour_codes = np.array(label_values)
  65. x = colour_codes[image.astype(int)]
  66. return x
  67. def get_label_info(csv_path):
  68. ann = pd.read_csv(csv_path)
  69. label = {}
  70. for iter, row in ann.iterrows():
  71. label_name = row['name']
  72. r = row['r']
  73. g = row['g']
  74. b = row['b']
  75. label[label_name] = [int(r), int(g), int(b)]
  76. return label
  77. def evaluatev0(respth='./pretrained', dspth='./data', backbone='CatNetSmall', scale=0.75, use_boundary_2=False, use_boundary_4=False, use_boundary_8=False, use_boundary_16=False, use_conv_last=False):
  78. print('scale', scale)
  79. print('use_boundary_2', use_boundary_2)
  80. print('use_boundary_4', use_boundary_4)
  81. print('use_boundary_8', use_boundary_8)
  82. print('use_boundary_16', use_boundary_16)
  83. ## dataset
  84. batchsize = 1
  85. n_workers = 0
  86. # dsval = Heliushuju(dspth, mode='val') # 原始
  87. dsval = Heliushuju(dspth, mode='test') # 改动
  88. dl = DataLoader(dsval,
  89. batch_size = batchsize,
  90. shuffle = False,
  91. num_workers = n_workers,
  92. drop_last = False)
  93. n_classes = 2####################################################################
  94. print("backbone:", backbone)
  95. net = BiSeNet(backbone=backbone, n_classes=n_classes,
  96. use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
  97. use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
  98. use_conv_last=use_conv_last)
  99. net.load_state_dict(torch.load(respth))
  100. net.cuda()
  101. net.eval()
  102. with torch.no_grad():
  103. single_scale = MscEvalV0(scale=scale)
  104. mIOU = single_scale(net, dl, 2)
  105. logger = logging.getLogger()
  106. logger.info('mIOU is: %s\n', mIOU)
  107. class MscEval(object):
  108. def __init__(self,
  109. model,
  110. dataloader,
  111. scales = [0.5, 0.75, 1, 1.25, 1.5, 1.75],
  112. n_classes = 2,
  113. lb_ignore = 255,
  114. cropsize = 1024,
  115. flip = True,
  116. *args, **kwargs):
  117. self.scales = scales
  118. self.n_classes = n_classes
  119. self.lb_ignore = lb_ignore
  120. self.flip = flip
  121. self.cropsize = cropsize
  122. ## dataloader
  123. self.dl = dataloader
  124. self.net = model
  125. def pad_tensor(self, inten, size):
  126. N, C, H, W = inten.size()
  127. outten = torch.zeros(N, C, size[0], size[1]).cuda()
  128. outten.requires_grad = False
  129. margin_h, margin_w = size[0]-H, size[1]-W
  130. hst, hed = margin_h//2, margin_h//2+H
  131. wst, wed = margin_w//2, margin_w//2+W
  132. outten[:, :, hst:hed, wst:wed] = inten
  133. return outten, [hst, hed, wst, wed]
  134. def eval_chip(self, crop):
  135. with torch.no_grad():
  136. out = self.net(crop)[0]
  137. prob = F.softmax(out, 1)
  138. if self.flip:
  139. crop = torch.flip(crop, dims=(3,))
  140. out = self.net(crop)[0]
  141. out = torch.flip(out, dims=(3,))
  142. prob += F.softmax(out, 1)
  143. prob = torch.exp(prob)
  144. return prob
  145. def crop_eval(self, im):
  146. cropsize = self.cropsize
  147. stride_rate = 5/6.
  148. N, C, H, W = im.size()
  149. long_size, short_size = (H,W) if H>W else (W,H)
  150. if long_size < cropsize:
  151. im, indices = self.pad_tensor(im, (cropsize, cropsize))
  152. prob = self.eval_chip(im)
  153. prob = prob[:, :, indices[0]:indices[1], indices[2]:indices[3]]
  154. else:
  155. stride = math.ceil(cropsize*stride_rate)
  156. if short_size < cropsize:
  157. if H < W:
  158. im, indices = self.pad_tensor(im, (cropsize, W))
  159. else:
  160. im, indices = self.pad_tensor(im, (H, cropsize))
  161. N, C, H, W = im.size()
  162. n_x = math.ceil((W-cropsize)/stride)+1
  163. n_y = math.ceil((H-cropsize)/stride)+1
  164. prob = torch.zeros(N, self.n_classes, H, W).cuda()
  165. prob.requires_grad = False
  166. for iy in range(n_y):
  167. for ix in range(n_x):
  168. hed, wed = min(H, stride*iy+cropsize), min(W, stride*ix+cropsize)
  169. hst, wst = hed-cropsize, wed-cropsize
  170. chip = im[:, :, hst:hed, wst:wed]
  171. prob_chip = self.eval_chip(chip)
  172. prob[:, :, hst:hed, wst:wed] += prob_chip
  173. if short_size < cropsize:
  174. prob = prob[:, :, indices[0]:indices[1], indices[2]:indices[3]]
  175. return prob
  176. def scale_crop_eval(self, im, scale):
  177. N, C, H, W = im.size()
  178. new_hw = [int(H*scale), int(W*scale)]
  179. im = F.interpolate(im, new_hw, mode='bilinear', align_corners=True)
  180. prob = self.crop_eval(im)
  181. prob = F.interpolate(prob, (H, W), mode='bilinear', align_corners=True)
  182. return prob
  183. def compute_hist(self, pred, lb):
  184. n_classes = self.n_classes
  185. ignore_idx = self.lb_ignore
  186. keep = np.logical_not(lb==ignore_idx)
  187. merge = pred[keep] * n_classes + lb[keep]
  188. hist = np.bincount(merge, minlength=n_classes**2)
  189. hist = hist.reshape((n_classes, n_classes))
  190. return hist
  191. def evaluate(self):
  192. ## evaluate
  193. n_classes = self.n_classes
  194. hist = np.zeros((n_classes, n_classes), dtype=np.float32)
  195. dloader = tqdm(self.dl)
  196. if dist.is_initialized() and not dist.get_rank()==0:
  197. dloader = self.dl
  198. for i, (imgs, label) in enumerate(dloader):
  199. N, _, H, W = label.shape
  200. probs = torch.zeros((N, self.n_classes, H, W))
  201. probs.requires_grad = False
  202. imgs = imgs.cuda()
  203. for sc in self.scales:
  204. # prob = self.scale_crop_eval(imgs, sc)
  205. prob = self.eval_chip(imgs)
  206. probs += prob.detach().cpu()
  207. probs = probs.data.numpy()
  208. preds = np.argmax(probs, axis=1)
  209. hist_once = self.compute_hist(preds, label.data.numpy().squeeze(1))
  210. hist = hist + hist_once
  211. IOUs = np.diag(hist) / (np.sum(hist, axis=0)+np.sum(hist, axis=1)-np.diag(hist))
  212. mIOU = np.mean(IOUs)
  213. return mIOU
  214. def evaluate(respth='./resv1_catnet/pths/', dspth='./data'):
  215. ## logger
  216. logger = logging.getLogger()
  217. ## model
  218. logger.info('\n')
  219. logger.info('===='*20)
  220. logger.info('evaluating the model ...\n')
  221. logger.info('setup and restore model')
  222. n_classes = 19
  223. net = BiSeNet(n_classes=n_classes)
  224. net.load_state_dict(torch.load(respth))
  225. net.cuda()
  226. net.eval()
  227. ## dataset
  228. batchsize = 5
  229. n_workers = 2
  230. # dsval = CityScapes(dspth, mode='val') # 原始
  231. dsval = Heliushuju(dspth, mode='test') # 改动
  232. dl = DataLoader(dsval,
  233. batch_size = batchsize,
  234. shuffle = False,
  235. num_workers = n_workers,
  236. drop_last = False)
  237. ## evaluator
  238. logger.info('compute the mIOU')
  239. evaluator = MscEval(net, dl, scales=[1], flip = False)
  240. ## eval
  241. mIOU = evaluator.evaluate()
  242. logger.info('mIOU is: {:.6f}'.format(mIOU))
  243. if __name__ == "__main__":
  244. log_dir = 'evaluation_logs/'
  245. if not os.path.exists(log_dir):
  246. os.makedirs(log_dir)
  247. setup_logger(log_dir)
  248. #STDC1-Seg50 mIoU 0.7222
  249. # evaluatev0('./checkpoints/STDC1-Seg/model_maxmIOU50.pth', dspth='./data', backbone='STDCNet813', scale=0.5,
  250. # use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)
  251. #STDC1-Seg75 mIoU 0.7450
  252. # evaluatev0('./checkpoints/STDC1-Seg/model_maxmIOU75.pth', dspth='./data', backbone='STDCNet813', scale=0.75,
  253. # use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)
  254. #STDC2-Seg50 mIoU 0.7424
  255. # evaluatev0('./checkpoints/STDC2-Seg/model_maxmIOU50.pth', dspth='./data', backbone='STDCNet1446', scale=0.5,
  256. # use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)
  257. #STDC2-Seg75 mIoU 0.7704
  258. evaluatev0('./checkpoints_1720/wurenji_train_STDC1-Seg/pths/model_maxmIOU50.pth', dspth='./data/segmentation/shuiyufenge_1720/', backbone='STDCNet813', scale=0.75,
  259. use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)