交通事故检测代码
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

322 lines
11KB

  1. #!/usr/bin/python
  2. # -*- encoding: utf-8 -*-
  3. from logger import setup_logger
  4. from models.model_stages import BiSeNet
  5. from cityscapes import CityScapes
  6. import torch
  7. import torch.nn as nn
  8. from torch.utils.data import DataLoader
  9. import torch.nn.functional as F
  10. import torch.distributed as dist
  11. import os
  12. import os.path as osp
  13. import logging
  14. import time
  15. import numpy as np
  16. from tqdm import tqdm
  17. import math
  18. class MscEvalV0(object):
  19. def __init__(self, scale=0.5, ignore_label=255):
  20. self.ignore_label = ignore_label
  21. self.scale = scale
  22. def __call__(self, net, dl, n_classes):
  23. # evaluate
  24. hist = torch.zeros(n_classes, n_classes).cuda().detach()
  25. if dist.is_initialized() and dist.get_rank() != 0:
  26. diter = enumerate(dl)
  27. else:
  28. diter = enumerate(tqdm(dl))
  29. for i, (imgs, label) in diter:
  30. # label = torch.argmax(label, dim=4) # 添加
  31. # print("11111111111111111111111111111111")
  32. # print(label.shape)
  33. # print("2222222222222222222222222222222222")
  34. N, _, H, W = label.shape # 原始
  35. # N, _, H, W = label.shape[0:-1] # 改动
  36. label = label.squeeze(1).cuda() # 原始
  37. # label = label.cuda() # 改动
  38. # print("33333333333333333333333333")
  39. # print(label.shape)
  40. # print("55555555555555555555555555")
  41. size = label.size()[-2:]
  42. imgs = imgs.cuda()
  43. N, C, H, W = imgs.size()
  44. new_hw = [int(H*self.scale), int(W*self.scale)]
  45. imgs = F.interpolate(imgs, new_hw, mode='bilinear', align_corners=True)
  46. logits = net(imgs)[0]
  47. logits = F.interpolate(logits, size=size, mode='bilinear', align_corners=True)
  48. probs = torch.softmax(logits, dim=1)
  49. preds = torch.argmax(probs, dim=1)
  50. keep = label != self.ignore_label
  51. # print("333333333333333333333")
  52. # print(keep)
  53. # print("666666666666666666666666")
  54. hist += torch.bincount(label[keep] * n_classes + preds[keep], minlength=n_classes ** 2).view(n_classes, n_classes).float() # 原始
  55. if dist.is_initialized():
  56. dist.all_reduce(hist, dist.ReduceOp.SUM)
  57. # print("1111111111111111111111111111")
  58. # print(hist.sum(dim=0))
  59. # print("222222222222222222222222222")
  60. # print(hist.sum(dim=1))
  61. # print("3333333333333333333333333333333")
  62. # print(hist.diag())
  63. # print("5555555555555555555555555555555")
  64. ious = hist.diag() / (hist.sum(dim=0) + hist.sum(dim=1) - hist.diag())
  65. # print("6666666666666666666666666666666666")
  66. # print(ious)
  67. # print("7777777777777777777777777777777777")
  68. miou = ious.mean()
  69. # print("88888888888888888888888888888888888")
  70. # print(miou)
  71. # print("99999999999999999999999999999999999")
  72. # print("111111111111111111111111111111111111")
  73. # print(miou.item())
  74. # print("222222222222222222222222222222222222")
  75. return miou.item()
  76. def evaluatev0(respth='./pretrained', dspth='./data', backbone='CatNetSmall', scale=0.75, use_boundary_2=False, use_boundary_4=False, use_boundary_8=False, use_boundary_16=False, use_conv_last=False):
  77. print('scale', scale)
  78. print('use_boundary_2', use_boundary_2)
  79. print('use_boundary_4', use_boundary_4)
  80. print('use_boundary_8', use_boundary_8)
  81. print('use_boundary_16', use_boundary_16)
  82. ## dataset
  83. batchsize = 5
  84. n_workers = 2
  85. dsval = CityScapes(dspth, mode='val')
  86. dl = DataLoader(dsval,
  87. batch_size = batchsize,
  88. shuffle = False,
  89. num_workers = n_workers,
  90. drop_last = False)
  91. n_classes = 19
  92. print("backbone:", backbone)
  93. net = BiSeNet(backbone=backbone, n_classes=n_classes,
  94. use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
  95. use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
  96. use_conv_last=use_conv_last)
  97. net.load_state_dict(torch.load(respth))
  98. net.cuda()
  99. net.eval()
  100. with torch.no_grad():
  101. single_scale = MscEvalV0(scale=scale)
  102. mIOU = single_scale(net, dl, 19)
  103. logger = logging.getLogger()
  104. logger.info('mIOU is: %s\n', mIOU)
  105. class MscEval(object):
  106. def __init__(self,
  107. model,
  108. dataloader,
  109. scales = [0.5, 0.75, 1, 1.25, 1.5, 1.75],
  110. n_classes = 19,
  111. lb_ignore = 255,
  112. cropsize = 1024,
  113. flip = True,
  114. *args, **kwargs):
  115. self.scales = scales
  116. self.n_classes = n_classes
  117. self.lb_ignore = lb_ignore
  118. self.flip = flip
  119. self.cropsize = cropsize
  120. ## dataloader
  121. self.dl = dataloader
  122. self.net = model
  123. def pad_tensor(self, inten, size):
  124. N, C, H, W = inten.size()
  125. outten = torch.zeros(N, C, size[0], size[1]).cuda()
  126. outten.requires_grad = False
  127. margin_h, margin_w = size[0]-H, size[1]-W
  128. hst, hed = margin_h//2, margin_h//2+H
  129. wst, wed = margin_w//2, margin_w//2+W
  130. outten[:, :, hst:hed, wst:wed] = inten
  131. return outten, [hst, hed, wst, wed]
  132. def eval_chip(self, crop):
  133. with torch.no_grad():
  134. out = self.net(crop)[0]
  135. prob = F.softmax(out, 1)
  136. if self.flip:
  137. crop = torch.flip(crop, dims=(3,))
  138. out = self.net(crop)[0]
  139. out = torch.flip(out, dims=(3,))
  140. prob += F.softmax(out, 1)
  141. prob = torch.exp(prob)
  142. return prob
  143. def crop_eval(self, im):
  144. cropsize = self.cropsize
  145. stride_rate = 5/6.
  146. N, C, H, W = im.size()
  147. long_size, short_size = (H,W) if H>W else (W,H)
  148. if long_size < cropsize:
  149. im, indices = self.pad_tensor(im, (cropsize, cropsize))
  150. prob = self.eval_chip(im)
  151. prob = prob[:, :, indices[0]:indices[1], indices[2]:indices[3]]
  152. else:
  153. stride = math.ceil(cropsize*stride_rate)
  154. if short_size < cropsize:
  155. if H < W:
  156. im, indices = self.pad_tensor(im, (cropsize, W))
  157. else:
  158. im, indices = self.pad_tensor(im, (H, cropsize))
  159. N, C, H, W = im.size()
  160. n_x = math.ceil((W-cropsize)/stride)+1
  161. n_y = math.ceil((H-cropsize)/stride)+1
  162. prob = torch.zeros(N, self.n_classes, H, W).cuda()
  163. prob.requires_grad = False
  164. for iy in range(n_y):
  165. for ix in range(n_x):
  166. hed, wed = min(H, stride*iy+cropsize), min(W, stride*ix+cropsize)
  167. hst, wst = hed-cropsize, wed-cropsize
  168. chip = im[:, :, hst:hed, wst:wed]
  169. prob_chip = self.eval_chip(chip)
  170. prob[:, :, hst:hed, wst:wed] += prob_chip
  171. if short_size < cropsize:
  172. prob = prob[:, :, indices[0]:indices[1], indices[2]:indices[3]]
  173. return prob
  174. def scale_crop_eval(self, im, scale):
  175. N, C, H, W = im.size()
  176. new_hw = [int(H*scale), int(W*scale)]
  177. im = F.interpolate(im, new_hw, mode='bilinear', align_corners=True)
  178. prob = self.crop_eval(im)
  179. prob = F.interpolate(prob, (H, W), mode='bilinear', align_corners=True)
  180. return prob
  181. def compute_hist(self, pred, lb):
  182. n_classes = self.n_classes
  183. ignore_idx = self.lb_ignore
  184. keep = np.logical_not(lb==ignore_idx)
  185. merge = pred[keep] * n_classes + lb[keep]
  186. hist = np.bincount(merge, minlength=n_classes**2)
  187. hist = hist.reshape((n_classes, n_classes))
  188. return hist
  189. def evaluate(self):
  190. ## evaluate
  191. n_classes = self.n_classes
  192. hist = np.zeros((n_classes, n_classes), dtype=np.float32)
  193. dloader = tqdm(self.dl)
  194. if dist.is_initialized() and not dist.get_rank()==0:
  195. dloader = self.dl
  196. for i, (imgs, label) in enumerate(dloader):
  197. N, _, H, W = label.shape
  198. probs = torch.zeros((N, self.n_classes, H, W))
  199. probs.requires_grad = False
  200. imgs = imgs.cuda()
  201. for sc in self.scales:
  202. # prob = self.scale_crop_eval(imgs, sc)
  203. prob = self.eval_chip(imgs)
  204. probs += prob.detach().cpu()
  205. probs = probs.data.numpy()
  206. preds = np.argmax(probs, axis=1)
  207. hist_once = self.compute_hist(preds, label.data.numpy().squeeze(1))
  208. hist = hist + hist_once
  209. IOUs = np.diag(hist) / (np.sum(hist, axis=0)+np.sum(hist, axis=1)-np.diag(hist))
  210. mIOU = np.mean(IOUs)
  211. return mIOU
  212. def evaluate(respth='./resv1_catnet/pths/', dspth='./data'):
  213. ## logger
  214. logger = logging.getLogger()
  215. ## model
  216. logger.info('\n')
  217. logger.info('===='*20)
  218. logger.info('evaluating the model ...\n')
  219. logger.info('setup and restore model')
  220. n_classes = 19
  221. net = BiSeNet(n_classes=n_classes)
  222. net.load_state_dict(torch.load(respth))
  223. net.cuda()
  224. net.eval()
  225. ## dataset
  226. batchsize = 5
  227. n_workers = 2
  228. dsval = CityScapes(dspth, mode='val')
  229. dl = DataLoader(dsval,
  230. batch_size = batchsize,
  231. shuffle = False,
  232. num_workers = n_workers,
  233. drop_last = False)
  234. ## evaluator
  235. logger.info('compute the mIOU')
  236. evaluator = MscEval(net, dl, scales=[1], flip = False)
  237. ## eval
  238. mIOU = evaluator.evaluate()
  239. logger.info('mIOU is: {:.6f}'.format(mIOU))
  240. if __name__ == "__main__":
  241. log_dir = 'evaluation_logs/'
  242. if not os.path.exists(log_dir):
  243. os.makedirs(log_dir)
  244. setup_logger(log_dir)
  245. #STDC1-Seg50 mIoU 0.7222
  246. # evaluatev0('./checkpoints/STDC1-Seg/model_maxmIOU50.pth', dspth='./data', backbone='STDCNet813', scale=0.5,
  247. # use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)
  248. #STDC1-Seg75 mIoU 0.7450
  249. # evaluatev0('./checkpoints/STDC1-Seg/model_maxmIOU75.pth', dspth='./data', backbone='STDCNet813', scale=0.75,
  250. # use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)
  251. #STDC2-Seg50 mIoU 0.7424
  252. # evaluatev0('./checkpoints/STDC2-Seg/model_maxmIOU50.pth', dspth='./data', backbone='STDCNet1446', scale=0.5,
  253. # use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)
  254. #STDC2-Seg75 mIoU 0.7704
  255. # evaluatev0('./checkpoints/STDC2-Seg/model_maxmIOU75.pth', dspth='./data', backbone='STDCNet1446', scale=0.75,
  256. # use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)
  257. evaluatev0('./checkpoints_1720/wurenji_train_STDC1-Seg/pths/model_maxmIOU75.pth',
  258. dspth='./data/segmentation/shuiyufenge_1720/', backbone='STDCNet1446', scale=0.75,
  259. use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)