高速公路违停检测
Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

train.py 16KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. #!/usr/bin/python
  2. # -*- encoding: utf-8 -*-
  3. import os
  4. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  5. from logger import setup_logger
  6. from models.model_stages import BiSeNet
  7. # from heliushuju import Heliushuju
  8. from heliushuju_process import Heliushuju
  9. from loss.loss import OhemCELoss
  10. from loss.detail_loss import DetailAggregateLoss
  11. # from evaluation import MscEvalV0
  12. from evaluation_process import MscEvalV0
  13. from optimizer_loss import Optimizer
  14. import sys
  15. import torch
  16. import torch.nn as nn
  17. from torch.utils.data import DataLoader
  18. import torch.nn.functional as F
  19. import torch.distributed as dist
  20. import os.path as osp
  21. import logging
  22. import time
  23. import datetime
  24. import argparse
  25. logger = logging.getLogger()
  26. def str2bool(v):
  27. if v.lower() in ('yes', 'true', 't', 'y', '1'):
  28. return True
  29. elif v.lower() in ('no', 'false', 'f', 'n', '0'):
  30. return False
  31. else:
  32. raise argparse.ArgumentTypeError('Unsupported value encountered.')
  33. def parse_args():
  34. parse = argparse.ArgumentParser()
  35. parse.add_argument(
  36. '--local_rank',
  37. dest = 'local_rank',
  38. type = int,
  39. default = -1, # yuanshi
  40. # default=0, # gaidong
  41. )
  42. parse.add_argument(
  43. '--n_workers_train',
  44. dest = 'n_workers_train',
  45. type = int,
  46. default = 8,####8
  47. )
  48. parse.add_argument(
  49. '--n_workers_val',
  50. dest = 'n_workers_val',
  51. type = int,
  52. default = 2,###0
  53. )
  54. parse.add_argument(
  55. '--n_img_per_gpu',
  56. dest = 'n_img_per_gpu',
  57. type = int,
  58. default = 8,
  59. )
  60. parse.add_argument(
  61. '--max_iter',
  62. dest = 'max_iter',
  63. type = int,
  64. default = 43000, # 60000
  65. )
  66. parse.add_argument(
  67. '--save_iter_sep',
  68. dest = 'save_iter_sep',
  69. type = int,
  70. default = 1000,
  71. )
  72. parse.add_argument(
  73. '--warmup_steps',
  74. dest = 'warmup_steps',
  75. type = int,
  76. default = 1000,
  77. )
  78. parse.add_argument(
  79. '--mode',
  80. dest = 'mode',
  81. type = str,
  82. default = 'train',
  83. )
  84. parse.add_argument(
  85. '--ckpt',
  86. dest = 'ckpt',
  87. type = str,
  88. default = None,
  89. )
  90. parse.add_argument(
  91. '--respath',
  92. dest = 'respath',
  93. type = str,
  94. # default = 'checkpoints_1720/wurenji_train_STDC1-Seg', # 原始
  95. default='./model_save', # 改动
  96. )
  97. parse.add_argument(
  98. '--backbone',
  99. dest = 'backbone',
  100. type = str,
  101. default = 'STDCNet813',##'CatNetSmall'
  102. )
  103. parse.add_argument(
  104. '--pretrain_path',
  105. dest = 'pretrain_path',
  106. type = str,
  107. default='./checkpoints2/STDCNet813M_73.91.tar',
  108. )
  109. parse.add_argument(
  110. '--use_conv_last',
  111. dest = 'use_conv_last',
  112. type = str2bool,
  113. default = False,
  114. )
  115. parse.add_argument(
  116. '--use_boundary_2',
  117. dest = 'use_boundary_2',
  118. type = str2bool,
  119. default = False,
  120. )
  121. parse.add_argument(
  122. '--use_boundary_4',
  123. dest = 'use_boundary_4',
  124. type = str2bool,
  125. default = False,
  126. )
  127. parse.add_argument(
  128. '--use_boundary_8',
  129. dest = 'use_boundary_8',
  130. type = str2bool,
  131. default = True, # False
  132. )
  133. parse.add_argument(
  134. '--use_boundary_16',
  135. dest = 'use_boundary_16',
  136. type = str2bool,
  137. default = False,
  138. )
  139. return parse.parse_args()
  140. def train():
  141. args = parse_args()
  142. save_pth_path = os.path.join(args.respath, 'pths')
  143. dspth = './data/'
  144. # print(save_pth_path)
  145. # print(osp.exists(save_pth_path))
  146. # if not osp.exists(save_pth_path) and dist.get_rank()==0:
  147. if not osp.exists(save_pth_path):
  148. os.makedirs(save_pth_path)
  149. torch.cuda.set_device(args.local_rank)
  150. ########################################################################fenbushi
  151. # dist.init_process_group(
  152. # backend = 'nccl',
  153. # init_method = 'tcp://127.0.0.1:33274',
  154. # world_size = torch.cuda.device_count(),
  155. # rank=args.local_rank
  156. # )
  157. setup_logger(args.respath)
  158. ## dataset
  159. # n_classes = 2 # 原始
  160. n_classes = 3 # 改动
  161. n_img_per_gpu = args.n_img_per_gpu
  162. n_workers_train = args.n_workers_train
  163. n_workers_val = args.n_workers_val
  164. use_boundary_16 = args.use_boundary_16
  165. use_boundary_8 = args.use_boundary_8
  166. use_boundary_4 = args.use_boundary_4
  167. use_boundary_2 = args.use_boundary_2
  168. mode = args.mode # train
  169. cropsize = [1024, 512]
  170. randomscale = (0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.125, 1.25, 1.375, 1.5)
  171. ##################################################################################################fenbushi
  172. # if dist.get_rank()==0:
  173. # logger.info('n_workers_train: {}'.format(n_workers_train))
  174. # logger.info('n_workers_val: {}'.format(n_workers_val))
  175. # logger.info('use_boundary_2: {}'.format(use_boundary_2))
  176. # logger.info('use_boundary_4: {}'.format(use_boundary_4))
  177. # logger.info('use_boundary_8: {}'.format(use_boundary_8))
  178. # logger.info('use_boundary_16: {}'.format(use_boundary_16))
  179. # logger.info('mode: {}'.format(args.mode))
  180. ds = Heliushuju(dspth, cropsize=cropsize, mode=mode, randomscale=randomscale)
  181. sampler = None
  182. # #################################################################################################fenbushi
  183. # sampler = torch.utils.data.distributed.DistributedSampler(ds)
  184. dl = DataLoader(ds,
  185. batch_size = n_img_per_gpu,
  186. shuffle = False,
  187. sampler = sampler,
  188. num_workers = n_workers_train,
  189. pin_memory = False,
  190. drop_last = True)
  191. # exit(0)
  192. dsval = Heliushuju(dspth, mode='val', randomscale=randomscale)
  193. # x,y = ds[0]
  194. # x, y = dsval[0]
  195. # sys.exit(0)
  196. sampler_val = None
  197. ##################################################################################################fenbushi
  198. # sampler_val = torch.utils.data.distributed.DistributedSampler(dsval)
  199. dlval = DataLoader(dsval,
  200. batch_size = 1,
  201. shuffle = False,
  202. sampler = sampler_val,
  203. num_workers = n_workers_val,
  204. drop_last = False)
  205. ## model
  206. ignore_idx = 255
  207. net = BiSeNet(backbone=args.backbone, n_classes=n_classes, pretrain_model=args.pretrain_path,
  208. use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4, use_boundary_8=use_boundary_8,
  209. use_boundary_16=use_boundary_16, use_conv_last=args.use_conv_last)
  210. if not args.ckpt is None:
  211. net.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
  212. net.cuda()
  213. net.train()
  214. ##################################################################################################fenbushi
  215. # net = nn.parallel.DistributedDataParallel(net,
  216. # device_ids = [args.local_rank, ],
  217. # output_device = args.local_rank,
  218. # find_unused_parameters=True
  219. # )
  220. net = nn.DataParallel(net, device_ids=[0])###########################################################################
  221. score_thres = 0.7
  222. n_min = n_img_per_gpu*cropsize[0]*cropsize[1]//16
  223. criteria_p = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
  224. criteria_16 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
  225. criteria_32 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
  226. boundary_loss_func = DetailAggregateLoss()
  227. ## optimizer
  228. maxmIOU50 = 0.
  229. maxmIOU75 = 0.
  230. momentum = 0.9
  231. weight_decay = 5e-4
  232. lr_start = 1e-2
  233. max_iter = args.max_iter
  234. save_iter_sep = args.save_iter_sep
  235. power = 0.9
  236. warmup_steps = args.warmup_steps
  237. warmup_start_lr = 1e-5
  238. ##################################################################################################fenbushi
  239. # if dist.get_rank()==0:
  240. # print('max_iter: ', max_iter)
  241. # print('save_iter_sep: ', save_iter_sep)
  242. # print('warmup_steps: ', warmup_steps)
  243. print('max_iter: ', max_iter)
  244. print('save_iter_sep: ', save_iter_sep)
  245. print('warmup_steps: ', warmup_steps)
  246. optim = Optimizer(
  247. model = net.module,
  248. loss = boundary_loss_func,
  249. lr0 = lr_start,
  250. momentum = momentum,
  251. wd = weight_decay,
  252. warmup_steps = warmup_steps,
  253. warmup_start_lr = warmup_start_lr,
  254. max_iter = max_iter,
  255. power = power)
  256. ## train loop
  257. msg_iter = 50
  258. loss_avg = []
  259. loss_boundery_bce = []
  260. loss_boundery_dice = []
  261. st = glob_st = time.time()
  262. diter = iter(dl)
  263. # diter = enumerate(dl)
  264. epoch = 0
  265. for it in range(max_iter):
  266. try:
  267. im, lb = diter.__next__()
  268. # print(im.size()[0])
  269. # im, lb = next(diter)
  270. if not im.size()[0]==n_img_per_gpu: raise StopIteration
  271. except StopIteration:
  272. epoch += 1
  273. # sampler.set_epoch(epoch)
  274. diter = iter(dl)
  275. im, lb = next(diter)
  276. im = im.cuda()
  277. lb = lb.cuda()
  278. H, W = im.size()[2:]
  279. lb = torch.squeeze(lb, 1) # lb.shape : torch.Size([8, 360, 640])
  280. # print("11111111111111111111")
  281. # print(lb.shape)
  282. # print("111111111111111111")
  283. # lb = torch.argmax(lb, dim=3) # 添加(训练高速路时,需要添加这行代码,训练水域分割时,将这行代码注释掉)
  284. optim.zero_grad()
  285. if use_boundary_2 and use_boundary_4 and use_boundary_8:
  286. out, out16, out32, detail2, detail4, detail8 = net(im)
  287. if (not use_boundary_2) and use_boundary_4 and use_boundary_8:
  288. out, out16, out32, detail4, detail8 = net(im)
  289. if (not use_boundary_2) and (not use_boundary_4) and use_boundary_8:#######True
  290. out, out16, out32, detail8 = net(im)
  291. if (not use_boundary_2) and (not use_boundary_4) and (not use_boundary_8):
  292. out, out16, out32 = net(im)
  293. # lossp = criteria_p(out, lb)
  294. # loss2 = criteria_16(out16, lb)
  295. # loss3 = criteria_32(out32, lb)
  296. # out=torch.tensor(out, dtype=torch.float64)
  297. # out16=torch.tensor(out16, dtype=torch.float64)
  298. # out32=torch.tensor(out32, dtype=torch.float64)
  299. # out=out.long()
  300. # out16=out16.long()
  301. # out32=out32.long()
  302. # lb=lb.long()
  303. lossp = criteria_p(out, lb)
  304. loss2 = criteria_16(out16, lb)
  305. loss3 = criteria_32(out32, lb)
  306. boundery_bce_loss = 0.
  307. boundery_dice_loss = 0.
  308. if use_boundary_2:
  309. # if dist.get_rank()==0:
  310. # print('use_boundary_2')
  311. boundery_bce_loss2, boundery_dice_loss2 = boundary_loss_func(detail2, lb)
  312. boundery_bce_loss += boundery_bce_loss2
  313. boundery_dice_loss += boundery_dice_loss2
  314. if use_boundary_4:
  315. # if dist.get_rank()==0:
  316. # print('use_boundary_4')
  317. boundery_bce_loss4, boundery_dice_loss4 = boundary_loss_func(detail4, lb)
  318. boundery_bce_loss += boundery_bce_loss4
  319. boundery_dice_loss += boundery_dice_loss4
  320. if use_boundary_8:######
  321. # if dist.get_rank()==0:
  322. # print('use_boundary_8')
  323. boundery_bce_loss8, boundery_dice_loss8 = boundary_loss_func(detail8, lb)
  324. boundery_bce_loss += boundery_bce_loss8
  325. boundery_dice_loss += boundery_dice_loss8
  326. loss = lossp + loss2 + loss3 + boundery_bce_loss + boundery_dice_loss
  327. loss.backward()
  328. optim.step()
  329. loss_avg.append(loss.item())
  330. loss_boundery_bce.append(boundery_bce_loss.item())
  331. loss_boundery_dice.append(boundery_dice_loss.item())
  332. ## print training log message
  333. if (it+1)%msg_iter==0:
  334. loss_avg = sum(loss_avg) / len(loss_avg)
  335. lr = optim.lr
  336. ed = time.time()
  337. t_intv, glob_t_intv = ed - st, ed - glob_st
  338. eta = int((max_iter - it) * (glob_t_intv / it))
  339. eta = str(datetime.timedelta(seconds=eta))
  340. loss_boundery_bce_avg = sum(loss_boundery_bce) / len(loss_boundery_bce)
  341. loss_boundery_dice_avg = sum(loss_boundery_dice) / len(loss_boundery_dice)
  342. msg = ', '.join([
  343. 'it: {it}/{max_it}',
  344. 'lr: {lr:4f}',
  345. 'loss: {loss:.4f}',
  346. 'boundery_bce_loss: {boundery_bce_loss:.4f}',
  347. 'boundery_dice_loss: {boundery_dice_loss:.4f}',
  348. 'eta: {eta}',
  349. 'time: {time:.4f}',
  350. ]).format(
  351. it = it+1,
  352. max_it = max_iter,
  353. lr = lr,
  354. loss = loss_avg,
  355. boundery_bce_loss = loss_boundery_bce_avg,
  356. boundery_dice_loss = loss_boundery_dice_avg,
  357. time = t_intv,
  358. eta = eta
  359. )
  360. logger.info(msg)
  361. loss_avg = []
  362. loss_boundery_bce = []
  363. loss_boundery_dice = []
  364. st = ed
  365. # print(boundary_loss_func.get_params())
  366. if (it+1)%save_iter_sep==0:# and it != 0:
  367. ## model
  368. logger.info('evaluating the model ...')
  369. logger.info('setup and restore model')
  370. net.eval()
  371. # ## evaluator
  372. logger.info('compute the mIOU')
  373. with torch.no_grad():
  374. single_scale1 = MscEvalV0()
  375. mIOU50 = single_scale1(net, dlval, n_classes)
  376. single_scale2 = MscEvalV0(scale=0.75)
  377. mIOU75 = single_scale2(net, dlval, n_classes)
  378. save_pth = osp.join(save_pth_path, 'model_iter{}_mIOU50_{}_mIOU75_{}.pth'
  379. .format(it+1, str(round(mIOU50, 4)), str(round(mIOU75, 4))))
  380. state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
  381. # if dist.get_rank()==0:
  382. torch.save(state, save_pth)
  383. logger.info('training iteration {}, model saved to: {}'.format(it+1, save_pth))
  384. if mIOU50 > maxmIOU50:
  385. maxmIOU50 = mIOU50
  386. save_pth = osp.join(save_pth_path, 'model_maxmIOU50.pth'.format(it+1))
  387. state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
  388. # if dist.get_rank()==0:
  389. torch.save(state, save_pth)
  390. logger.info('max mIOU model saved to: {}'.format(save_pth))
  391. if mIOU75 > maxmIOU75:
  392. maxmIOU75 = mIOU75
  393. save_pth = osp.join(save_pth_path, 'model_maxmIOU75.pth'.format(it+1))
  394. state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
  395. # if dist.get_rank()==0: torch.save(state, save_pth)
  396. torch.save(state, save_pth)
  397. logger.info('max mIOU model saved to: {}'.format(save_pth))
  398. logger.info('mIOU50 is: {}, mIOU75 is: {}'.format(mIOU50, mIOU75))
  399. logger.info('maxmIOU50 is: {}, maxmIOU75 is: {}.'.format(maxmIOU50, maxmIOU75))
  400. net.train()
  401. ## dump the final model
  402. save_pth = osp.join(save_pth_path, 'model_final.pth')
  403. net.cpu()
  404. state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
  405. # if dist.get_rank()==0: torch.save(state, save_pth)
  406. torch.save(state, save_pth)
  407. logger.info('training done, model saved to: {}'.format(save_pth))
  408. print('epoch: ', epoch)
  409. if __name__ == "__main__":
  410. train()