|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470 |
- #!/usr/bin/python
- # -*- encoding: utf-8 -*-
- import os
- os.environ['CUDA_VISIBLE_DEVICES'] = '0'
- from logger import setup_logger
- from models.model_stages import BiSeNet
- # from heliushuju import Heliushuju
- from heliushuju_process import Heliushuju
- from loss.loss import OhemCELoss
- from loss.detail_loss import DetailAggregateLoss
- # from evaluation import MscEvalV0
- from evaluation_process import MscEvalV0
- from optimizer_loss import Optimizer
- import sys
- import torch
- import torch.nn as nn
- from torch.utils.data import DataLoader
- import torch.nn.functional as F
- import torch.distributed as dist
-
-
- import os.path as osp
- import logging
- import time
- import datetime
- import argparse
-
- logger = logging.getLogger()
-
def str2bool(v):
    """Convert a command-line flag value to a ``bool``.

    Intended for use as an ``argparse`` ``type=`` callable.

    Args:
        v: The raw value. Strings are matched case-insensitively, ignoring
            surrounding whitespace, against ``yes/true/t/y/1`` and
            ``no/false/f/n/0``. A value that is already a ``bool`` is
            returned unchanged, so the function is safe to call on defaults.

    Returns:
        The boolean the value stands for.

    Raises:
        argparse.ArgumentTypeError: if the value matches neither set.
    """
    # Real booleans have no .lower(); pass them through instead of crashing.
    if isinstance(v, bool):
        return v

    normalized = str(v).strip().lower()
    if normalized in ('yes', 'true', 't', 'y', '1'):
        return True
    if normalized in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')
-
-
def parse_args():
    """Parse the training options from the command line.

    Every option is registered with an explicit ``dest`` equal to the flag
    name without its leading dashes, so the returned namespace exposes
    ``args.local_rank``, ``args.max_iter``, ``args.use_boundary_8`` and so on.

    Returns:
        The ``argparse.Namespace`` produced by parsing ``sys.argv``.
    """
    # One row per option, in registration order: (flag, type, default).
    option_table = (
        ('--local_rank', int, -1),  # original default; 0 was tried as an alternative
        ('--n_workers_train', int, 8),
        ('--n_workers_val', int, 2),
        ('--n_img_per_gpu', int, 8),
        ('--max_iter', int, 43000),  # previously 60000
        ('--save_iter_sep', int, 1000),
        ('--warmup_steps', int, 1000),
        ('--mode', str, 'train'),
        ('--ckpt', str, None),
        # Changed from 'checkpoints_1720/wurenji_train_STDC1-Seg'.
        ('--respath', str, './model_save'),
        ('--backbone', str, 'STDCNet813'),  # alternative: 'CatNetSmall'
        ('--pretrain_path', str, './checkpoints2/STDCNet813M_73.91.tar'),
        ('--use_conv_last', str2bool, False),
        ('--use_boundary_2', str2bool, False),
        ('--use_boundary_4', str2bool, False),
        ('--use_boundary_8', str2bool, True),  # previously False
        ('--use_boundary_16', str2bool, False),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in option_table:
        parser.add_argument(
            flag,
            dest=flag[2:],  # strip the leading '--'
            type=value_type,
            default=default_value,
        )
    return parser.parse_args()
-
-
def _save_model_weights(net, save_pth):
    """Save the state dict of ``net`` to ``save_pth``.

    If ``net`` exposes a ``module`` attribute (as the ``nn.DataParallel``
    wrapper used in ``train`` does), the wrapped module's state dict is saved.
    """
    state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
    torch.save(state, save_pth)


def train():
    """Train a BiSeNet segmentation model on a single GPU.

    Configuration comes from the command line (see ``parse_args``). The
    function builds training and validation loaders over ``./data/``, then
    runs ``max_iter`` optimisation steps. The loss of each step is the sum of
    three OHEM cross-entropy terms (on ``out``, ``out16`` and ``out32``) plus
    the BCE and dice terms of ``DetailAggregateLoss`` for every enabled
    detail head.

    Every ``save_iter_sep`` iterations the model is evaluated twice with
    ``MscEvalV0`` (once with its default arguments, reported as ``mIOU50``,
    and once with ``scale=0.75``, reported as ``mIOU75``). A checkpoint is
    written under ``<respath>/pths`` each time, and the best model for each
    metric is kept as ``model_maxmIOU50.pth`` / ``model_maxmIOU75.pth``.
    The final weights are written to ``model_final.pth``.

    Distributed training (process group, ``DistributedSampler``,
    ``DistributedDataParallel``) is disabled in this script; a single-device
    ``nn.DataParallel`` wrapper is used instead.

    Raises:
        ValueError: if the ``use_boundary_2/4/8`` flags form a combination
            the training loop cannot unpack model outputs for.
    """
    args = parse_args()

    save_pth_path = os.path.join(args.respath, 'pths')
    dspth = './data/'

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_pth_path, exist_ok=True)

    # NOTE(review): with the default local_rank of -1 this is expected to be
    # a no-op according to the torch docs; confirm for the installed version.
    torch.cuda.set_device(args.local_rank)

    setup_logger(args.respath)

    ## dataset
    n_classes = 3  # changed from the original value of 2

    n_img_per_gpu = args.n_img_per_gpu
    n_workers_train = args.n_workers_train
    n_workers_val = args.n_workers_val
    use_boundary_16 = args.use_boundary_16
    use_boundary_8 = args.use_boundary_8
    use_boundary_4 = args.use_boundary_4
    use_boundary_2 = args.use_boundary_2

    # The loop below only knows how to unpack the model outputs for these
    # (use_boundary_2, use_boundary_4, use_boundary_8) combinations. Reject
    # anything else now instead of failing on an unbound name mid-training.
    supported_boundary_combos = (
        (True, True, True),
        (False, True, True),
        (False, False, True),
        (False, False, False),
    )
    boundary_combo = (bool(use_boundary_2), bool(use_boundary_4), bool(use_boundary_8))
    if boundary_combo not in supported_boundary_combos:
        raise ValueError(
            'Unsupported combination of use_boundary_2/4/8: {}; supported: {}'.format(
                boundary_combo, supported_boundary_combos))

    mode = args.mode  # 'train' by default
    cropsize = [1024, 512]
    randomscale = (0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.125, 1.25, 1.375, 1.5)

    ds = Heliushuju(dspth, cropsize=cropsize, mode=mode, randomscale=randomscale)
    # No DistributedSampler in the single-GPU setup.
    sampler = None
    dl = DataLoader(ds,
                    batch_size=n_img_per_gpu,
                    # Without a sampler nothing else shuffles the training
                    # data, so let the loader do it. A sampler and
                    # shuffle=True are mutually exclusive.
                    shuffle=sampler is None,
                    sampler=sampler,
                    num_workers=n_workers_train,
                    pin_memory=False,
                    drop_last=True)

    dsval = Heliushuju(dspth, mode='val', randomscale=randomscale)
    sampler_val = None
    dlval = DataLoader(dsval,
                       batch_size=1,
                       shuffle=False,
                       sampler=sampler_val,
                       num_workers=n_workers_val,
                       drop_last=False)

    ## model
    ignore_idx = 255
    net = BiSeNet(backbone=args.backbone, n_classes=n_classes, pretrain_model=args.pretrain_path,
                  use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4,
                  use_boundary_8=use_boundary_8, use_boundary_16=use_boundary_16,
                  use_conv_last=args.use_conv_last)

    if args.ckpt is not None:
        net.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
    net.cuda()
    net.train()
    # Single-device wrapper used in place of DistributedDataParallel.
    net = nn.DataParallel(net, device_ids=[0])

    score_thres = 0.7
    n_min = n_img_per_gpu * cropsize[0] * cropsize[1] // 16
    criteria_p = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    criteria_16 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    criteria_32 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    boundary_loss_func = DetailAggregateLoss()

    ## optimizer
    maxmIOU50 = 0.
    maxmIOU75 = 0.
    momentum = 0.9
    weight_decay = 5e-4
    lr_start = 1e-2
    max_iter = args.max_iter
    save_iter_sep = args.save_iter_sep
    power = 0.9
    warmup_steps = args.warmup_steps
    warmup_start_lr = 1e-5

    print('max_iter: ', max_iter)
    print('save_iter_sep: ', save_iter_sep)
    print('warmup_steps: ', warmup_steps)
    optim = Optimizer(
        model=net.module,
        loss=boundary_loss_func,
        lr0=lr_start,
        momentum=momentum,
        wd=weight_decay,
        warmup_steps=warmup_steps,
        warmup_start_lr=warmup_start_lr,
        max_iter=max_iter,
        power=power)

    ## train loop
    msg_iter = 50
    loss_avg = []
    loss_boundery_bce = []
    loss_boundery_dice = []
    st = glob_st = time.time()
    diter = iter(dl)
    epoch = 0
    for it in range(max_iter):
        try:
            im, lb = next(diter)
            # Treat a short batch like the end of the epoch.
            if im.size()[0] != n_img_per_gpu:
                raise StopIteration
        except StopIteration:
            # Loader exhausted: start the next epoch with a fresh iterator.
            epoch += 1
            diter = iter(dl)
            im, lb = next(diter)
        im = im.cuda()
        lb = lb.cuda()
        lb = torch.squeeze(lb, 1)  # drop the size-1 dim at index 1 of the label batch

        # NOTE: when training on the highway dataset, enable the next line;
        # keep it commented out for water-area segmentation.
        # lb = torch.argmax(lb, dim=3)

        optim.zero_grad()

        # Unpack the model outputs; the combination was validated above, so
        # exactly one of these branches matches.
        if use_boundary_2 and use_boundary_4 and use_boundary_8:
            out, out16, out32, detail2, detail4, detail8 = net(im)
            details = [detail2, detail4, detail8]
        elif use_boundary_4 and use_boundary_8:
            out, out16, out32, detail4, detail8 = net(im)
            details = [detail4, detail8]
        elif use_boundary_8:
            out, out16, out32, detail8 = net(im)
            details = [detail8]
        else:
            out, out16, out32 = net(im)
            details = []

        lossp = criteria_p(out, lb)
        loss2 = criteria_16(out16, lb)
        loss3 = criteria_32(out32, lb)

        # Accumulate the boundary losses of every enabled detail head.
        boundery_bce_loss = 0.
        boundery_dice_loss = 0.
        for detail in details:
            bce_term, dice_term = boundary_loss_func(detail, lb)
            boundery_bce_loss += bce_term
            boundery_dice_loss += dice_term

        loss = lossp + loss2 + loss3 + boundery_bce_loss + boundery_dice_loss

        loss.backward()
        optim.step()

        loss_avg.append(loss.item())

        # float() handles both a scalar loss tensor and the plain 0. that
        # remains when no detail head is enabled (0. has no .item()).
        loss_boundery_bce.append(float(boundery_bce_loss))
        loss_boundery_dice.append(float(boundery_dice_loss))

        ## print training log message
        if (it + 1) % msg_iter == 0:
            mean_loss = sum(loss_avg) / len(loss_avg)
            lr = optim.lr
            ed = time.time()
            t_intv, glob_t_intv = ed - st, ed - glob_st
            # it is zero-based, so it + 1 iterations are done at this point.
            done_iters = it + 1
            eta = int((max_iter - done_iters) * (glob_t_intv / done_iters))
            eta = str(datetime.timedelta(seconds=eta))

            loss_boundery_bce_avg = sum(loss_boundery_bce) / len(loss_boundery_bce)
            loss_boundery_dice_avg = sum(loss_boundery_dice) / len(loss_boundery_dice)
            msg = ', '.join([
                'it: {it}/{max_it}',
                'lr: {lr:4f}',
                'loss: {loss:.4f}',
                'boundery_bce_loss: {boundery_bce_loss:.4f}',
                'boundery_dice_loss: {boundery_dice_loss:.4f}',
                'eta: {eta}',
                'time: {time:.4f}',
            ]).format(
                it=it + 1,
                max_it=max_iter,
                lr=lr,
                loss=mean_loss,
                boundery_bce_loss=loss_boundery_bce_avg,
                boundery_dice_loss=loss_boundery_dice_avg,
                time=t_intv,
                eta=eta
            )

            logger.info(msg)
            # Start a fresh averaging window.
            loss_avg = []
            loss_boundery_bce = []
            loss_boundery_dice = []
            st = ed

        if (it + 1) % save_iter_sep == 0:
            ## model
            logger.info('evaluating the model ...')
            logger.info('setup and restore model')

            net.eval()

            ## evaluator
            logger.info('compute the mIOU')
            with torch.no_grad():
                single_scale1 = MscEvalV0()
                mIOU50 = single_scale1(net, dlval, n_classes)

                single_scale2 = MscEvalV0(scale=0.75)
                mIOU75 = single_scale2(net, dlval, n_classes)

            save_pth = osp.join(
                save_pth_path,
                'model_iter{}_mIOU50_{}_mIOU75_{}.pth'.format(
                    it + 1, str(round(mIOU50, 4)), str(round(mIOU75, 4))))
            _save_model_weights(net, save_pth)
            logger.info('training iteration {}, model saved to: {}'.format(it + 1, save_pth))

            if mIOU50 > maxmIOU50:
                maxmIOU50 = mIOU50
                save_pth = osp.join(save_pth_path, 'model_maxmIOU50.pth')
                _save_model_weights(net, save_pth)
                logger.info('max mIOU model saved to: {}'.format(save_pth))

            if mIOU75 > maxmIOU75:
                maxmIOU75 = mIOU75
                save_pth = osp.join(save_pth_path, 'model_maxmIOU75.pth')
                _save_model_weights(net, save_pth)
                logger.info('max mIOU model saved to: {}'.format(save_pth))

            logger.info('mIOU50 is: {}, mIOU75 is: {}'.format(mIOU50, mIOU75))
            logger.info('maxmIOU50 is: {}, maxmIOU75 is: {}.'.format(maxmIOU50, maxmIOU75))

            # Back to training mode after evaluation.
            net.train()

    ## dump the final model
    save_pth = osp.join(save_pth_path, 'model_final.pth')
    net.cpu()
    _save_model_weights(net, save_pth)
    logger.info('training done, model saved to: {}'.format(save_pth))
    print('epoch: ', epoch)
-
-
# Script entry point: run training only when this file is executed directly,
# not when it is imported.
if __name__ == '__main__':
    train()
|