交通事故检测代码
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

282 lines
9.4KB

  1. #!/usr/bin/python
  2. # -*- encoding: utf-8 -*-
  3. from logger import setup_logger
  4. from models.model_stages import BiSeNet
  5. from cityscapes import CityScapes
  6. import torch
  7. import torch.nn as nn
  8. from torch.utils.data import DataLoader
  9. import torch.nn.functional as F
  10. import torch.distributed as dist
  11. import os
  12. import os.path as osp
  13. import logging
  14. import time
  15. import numpy as np
  16. from tqdm import tqdm
  17. import math
  18. class MscEvalV0(object):
  19. def __init__(self, scale=0.5, ignore_label=255):
  20. self.ignore_label = ignore_label
  21. self.scale = scale
  22. def __call__(self, net, dl, n_classes):
  23. ## evaluate
  24. hist = torch.zeros(n_classes, n_classes).cuda().detach()
  25. if dist.is_initialized() and dist.get_rank() != 0:
  26. diter = enumerate(dl)
  27. else:
  28. diter = enumerate(tqdm(dl))
  29. for i, (imgs, label) in diter:
  30. N, _, H, W = label.shape
  31. label = label.squeeze(1).cuda()
  32. size = label.size()[-2:]
  33. imgs = imgs.cuda()
  34. N, C, H, W = imgs.size()
  35. new_hw = [int(H*self.scale), int(W*self.scale)]
  36. imgs = F.interpolate(imgs, new_hw, mode='bilinear', align_corners=True)
  37. logits = net(imgs)[0]
  38. logits = F.interpolate(logits, size=size,
  39. mode='bilinear', align_corners=True)
  40. probs = torch.softmax(logits, dim=1)
  41. preds = torch.argmax(probs, dim=1)
  42. keep = label != self.ignore_label
  43. hist += torch.bincount(
  44. label[keep] * n_classes + preds[keep],
  45. minlength=n_classes ** 2
  46. ).view(n_classes, n_classes).float()
  47. if dist.is_initialized():
  48. dist.all_reduce(hist, dist.ReduceOp.SUM)
  49. ious = hist.diag() / (hist.sum(dim=0) + hist.sum(dim=1) - hist.diag())
  50. miou = ious.mean()
  51. return miou.item()
  52. def evaluatev0(respth='./pretrained', dspth='./data', backbone='CatNetSmall', scale=0.75, use_boundary_2=False, use_boundary_4=False, use_boundary_8=False, use_boundary_16=False, use_conv_last=False):
  53. print('scale', scale)
  54. print('use_boundary_2', use_boundary_2)
  55. print('use_boundary_4', use_boundary_4)
  56. print('use_boundary_8', use_boundary_8)
  57. print('use_boundary_16', use_boundary_16)
  58. ## dataset
  59. batchsize = 5
  60. n_workers = 2
  61. dsval = CityScapes(dspth, mode='val')
  62. dl = DataLoader(dsval,
  63. batch_size = batchsize,
  64. shuffle = False,
  65. num_workers = n_workers,
  66. drop_last = False)
  67. n_classes = 19
  68. print("backbone:", backbone)
  69. net = BiSeNet(backbone=backbone, n_classes=n_classes,
  70. use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
  71. use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
  72. use_conv_last=use_conv_last)
  73. net.load_state_dict(torch.load(respth))
  74. net.cuda()
  75. net.eval()
  76. with torch.no_grad():
  77. single_scale = MscEvalV0(scale=scale)
  78. mIOU = single_scale(net, dl, 19)
  79. logger = logging.getLogger()
  80. logger.info('mIOU is: %s\n', mIOU)
  81. class MscEval(object):
  82. def __init__(self,
  83. model,
  84. dataloader,
  85. scales = [0.5, 0.75, 1, 1.25, 1.5, 1.75],
  86. n_classes = 19,
  87. lb_ignore = 255,
  88. cropsize = 1024,
  89. flip = True,
  90. *args, **kwargs):
  91. self.scales = scales
  92. self.n_classes = n_classes
  93. self.lb_ignore = lb_ignore
  94. self.flip = flip
  95. self.cropsize = cropsize
  96. ## dataloader
  97. self.dl = dataloader
  98. self.net = model
  99. def pad_tensor(self, inten, size):
  100. N, C, H, W = inten.size()
  101. outten = torch.zeros(N, C, size[0], size[1]).cuda()
  102. outten.requires_grad = False
  103. margin_h, margin_w = size[0]-H, size[1]-W
  104. hst, hed = margin_h//2, margin_h//2+H
  105. wst, wed = margin_w//2, margin_w//2+W
  106. outten[:, :, hst:hed, wst:wed] = inten
  107. return outten, [hst, hed, wst, wed]
  108. def eval_chip(self, crop):
  109. with torch.no_grad():
  110. out = self.net(crop)[0]
  111. prob = F.softmax(out, 1)
  112. if self.flip:
  113. crop = torch.flip(crop, dims=(3,))
  114. out = self.net(crop)[0]
  115. out = torch.flip(out, dims=(3,))
  116. prob += F.softmax(out, 1)
  117. prob = torch.exp(prob)
  118. return prob
  119. def crop_eval(self, im):
  120. cropsize = self.cropsize
  121. stride_rate = 5/6.
  122. N, C, H, W = im.size()
  123. long_size, short_size = (H,W) if H>W else (W,H)
  124. if long_size < cropsize:
  125. im, indices = self.pad_tensor(im, (cropsize, cropsize))
  126. prob = self.eval_chip(im)
  127. prob = prob[:, :, indices[0]:indices[1], indices[2]:indices[3]]
  128. else:
  129. stride = math.ceil(cropsize*stride_rate)
  130. if short_size < cropsize:
  131. if H < W:
  132. im, indices = self.pad_tensor(im, (cropsize, W))
  133. else:
  134. im, indices = self.pad_tensor(im, (H, cropsize))
  135. N, C, H, W = im.size()
  136. n_x = math.ceil((W-cropsize)/stride)+1
  137. n_y = math.ceil((H-cropsize)/stride)+1
  138. prob = torch.zeros(N, self.n_classes, H, W).cuda()
  139. prob.requires_grad = False
  140. for iy in range(n_y):
  141. for ix in range(n_x):
  142. hed, wed = min(H, stride*iy+cropsize), min(W, stride*ix+cropsize)
  143. hst, wst = hed-cropsize, wed-cropsize
  144. chip = im[:, :, hst:hed, wst:wed]
  145. prob_chip = self.eval_chip(chip)
  146. prob[:, :, hst:hed, wst:wed] += prob_chip
  147. if short_size < cropsize:
  148. prob = prob[:, :, indices[0]:indices[1], indices[2]:indices[3]]
  149. return prob
  150. def scale_crop_eval(self, im, scale):
  151. N, C, H, W = im.size()
  152. new_hw = [int(H*scale), int(W*scale)]
  153. im = F.interpolate(im, new_hw, mode='bilinear', align_corners=True)
  154. prob = self.crop_eval(im)
  155. prob = F.interpolate(prob, (H, W), mode='bilinear', align_corners=True)
  156. return prob
  157. def compute_hist(self, pred, lb):
  158. n_classes = self.n_classes
  159. ignore_idx = self.lb_ignore
  160. keep = np.logical_not(lb==ignore_idx)
  161. merge = pred[keep] * n_classes + lb[keep]
  162. hist = np.bincount(merge, minlength=n_classes**2)
  163. hist = hist.reshape((n_classes, n_classes))
  164. return hist
  165. def evaluate(self):
  166. ## evaluate
  167. n_classes = self.n_classes
  168. hist = np.zeros((n_classes, n_classes), dtype=np.float32)
  169. dloader = tqdm(self.dl)
  170. if dist.is_initialized() and not dist.get_rank()==0:
  171. dloader = self.dl
  172. for i, (imgs, label) in enumerate(dloader):
  173. N, _, H, W = label.shape
  174. probs = torch.zeros((N, self.n_classes, H, W))
  175. probs.requires_grad = False
  176. imgs = imgs.cuda()
  177. for sc in self.scales:
  178. # prob = self.scale_crop_eval(imgs, sc)
  179. prob = self.eval_chip(imgs)
  180. probs += prob.detach().cpu()
  181. probs = probs.data.numpy()
  182. preds = np.argmax(probs, axis=1)
  183. hist_once = self.compute_hist(preds, label.data.numpy().squeeze(1))
  184. hist = hist + hist_once
  185. IOUs = np.diag(hist) / (np.sum(hist, axis=0)+np.sum(hist, axis=1)-np.diag(hist))
  186. mIOU = np.mean(IOUs)
  187. return mIOU
  188. def evaluate(respth='./resv1_catnet/pths/', dspth='./data'):
  189. ## logger
  190. logger = logging.getLogger()
  191. ## model
  192. logger.info('\n')
  193. logger.info('===='*20)
  194. logger.info('evaluating the model ...\n')
  195. logger.info('setup and restore model')
  196. n_classes = 19
  197. net = BiSeNet(n_classes=n_classes)
  198. net.load_state_dict(torch.load(respth))
  199. net.cuda()
  200. net.eval()
  201. ## dataset
  202. batchsize = 5
  203. n_workers = 2
  204. dsval = CityScapes(dspth, mode='val')
  205. dl = DataLoader(dsval,
  206. batch_size = batchsize,
  207. shuffle = False,
  208. num_workers = n_workers,
  209. drop_last = False)
  210. ## evaluator
  211. logger.info('compute the mIOU')
  212. evaluator = MscEval(net, dl, scales=[1], flip = False)
  213. ## eval
  214. mIOU = evaluator.evaluate()
  215. logger.info('mIOU is: {:.6f}'.format(mIOU))
  216. if __name__ == "__main__":
  217. log_dir = 'evaluation_logs/'
  218. if not os.path.exists(log_dir):
  219. os.makedirs(log_dir)
  220. setup_logger(log_dir)
  221. #STDC1-Seg50 mIoU 0.7222
  222. # evaluatev0('./checkpoints/STDC1-Seg/model_maxmIOU50.pth', dspth='./data', backbone='STDCNet813', scale=0.5,
  223. # use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)
  224. #STDC1-Seg75 mIoU 0.7450
  225. # evaluatev0('./checkpoints/STDC1-Seg/model_maxmIOU75.pth', dspth='./data', backbone='STDCNet813', scale=0.75,
  226. # use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)
  227. #STDC2-Seg50 mIoU 0.7424
  228. # evaluatev0('./checkpoints/STDC2-Seg/model_maxmIOU50.pth', dspth='./data', backbone='STDCNet1446', scale=0.5,
  229. # use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)
  230. #STDC2-Seg75 mIoU 0.7704
  231. evaluatev0('./checkpoints/STDC2-Seg/model_maxmIOU75.pth', dspth='./data', backbone='STDCNet1446', scale=0.75,
  232. use_boundary_2=False, use_boundary_4=False, use_boundary_8=True, use_boundary_16=False)