# STDC-th/train.py
#
# NOTE(review): the lines that followed here ("349 lines / 13 KiB / Python",
# "Raw Blame History", and the ambiguous-Unicode-characters warning banner)
# were web-viewer chrome captured when this file was copied from a
# code-hosting page — they are not Python source and have been commented out.
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import os
#os.environ['CUDA_VISIBLE_DEVICES'] = '1'
from logger import setup_logger
from models.model_stages import BiSeNet
# from heliushuju import Heliushuju
from heliushuju_process import Heliushuju
from loss.loss import OhemCELoss
from loss.detail_loss import DetailAggregateLoss
# from evaluation import MscEvalV0
from evaluation_process import MscEvalV0
from optimizer_loss import Optimizer
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.distributed as dist
import os.path as osp
import logging
import time
import datetime
import argparse
import json
logger = logging.getLogger()
def str2bool(v):
    """Convert a command-line string to a bool for argparse.

    Accepts the usual truthy/falsy spellings (case-insensitive);
    any other value raises ``argparse.ArgumentTypeError`` so argparse
    reports a clean usage error instead of a traceback.
    """
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')
def parse_args():
    """Build and parse the command-line options for training.

    Returns an ``argparse.Namespace`` carrying the run-config json path,
    GPU selection, dataloader worker counts, iteration schedule, backbone
    choice, checkpoint paths, and the per-scale boundary-head switches.
    ``dest`` is derived automatically from each option string, so the
    attribute names match the long-option names exactly.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--parJson', type=str, default='./data/Car')
    parser.add_argument('--gpuId', type=str, default='0')
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--n_workers_train', type=int, default=8)
    parser.add_argument('--n_workers_val', type=int, default=2)
    parser.add_argument('--n_img_per_gpu', type=int, default=8)
    parser.add_argument('--max_iter', type=int, default=43000)
    parser.add_argument('--save_iter_sep', type=int, default=1000)
    parser.add_argument('--warmup_steps', type=int, default=1000)
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument('--ckpt', type=str, default=None)
    parser.add_argument('--respath', type=str, default='./model_save')
    parser.add_argument('--backbone', type=str, default='STDCNet813')
    parser.add_argument('--pretrain_path', type=str, default='./checkpoints2/STDCNet813M_73.91.tar')
    parser.add_argument('--use_conv_last', type=str2bool, default=False)
    parser.add_argument('--use_boundary_2', type=str2bool, default=False)
    parser.add_argument('--use_boundary_4', type=str2bool, default=False)
    parser.add_argument('--use_boundary_8', type=str2bool, default=True)
    parser.add_argument('--use_boundary_16', type=str2bool, default=False)
    return parser.parse_args()
def train():
    """Run the full training loop for the STDC/BiSeNet segmentation model.

    Reads the run configuration from the json file given by ``--parJson``
    (keys read below: dspth, n_classes, cropsize, labelJson, ignore_idx),
    builds train/val dataloaders over the Heliushuju dataset, sets up the
    OHEM cross-entropy losses plus the detail-aggregate boundary loss and a
    warmup/poly Optimizer, then iterates for ``--max_iter`` steps. Every
    ``--save_iter_sep`` iterations it evaluates mIOU at scales 1.0 and 0.75
    on the val split and checkpoints the model (including best-so-far
    snapshots); the final weights are dumped as model_final.pth.
    """
    args = parse_args()
    # Run configuration comes from a json file; keys are read piecemeal below.
    with open(args.parJson.strip(),'r') as fp:
        par=json.load(fp)
    #print('args.gpuId.strip :',args.gpuId.strip())
    #os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpuId.strip())
    # Pin all CUDA work to the single GPU selected on the command line.
    torch.cuda.set_device(int(args.gpuId.strip()))
    #print('line72:',torch.cuda.current_device())
    save_pth_path = os.path.join(args.respath, 'pths')
    #dspth = '../../data/CarRoadLane/'
    dspth = par['dspth']
    if not osp.exists(save_pth_path):
        os.makedirs(save_pth_path)
    #torch.cuda.set_device(args.local_rank)
    ######################################################################## distributed
    setup_logger(args.respath)
    # n_classes = 4  # (old hard-coded value; now taken from the json config)
    n_classes = par['n_classes']
    n_img_per_gpu = args.n_img_per_gpu
    n_workers_train = args.n_workers_train
    n_workers_val = args.n_workers_val
    use_boundary_16 = args.use_boundary_16
    use_boundary_8 = args.use_boundary_8
    use_boundary_4 = args.use_boundary_4
    use_boundary_2 = args.use_boundary_2
    mode = args.mode # train
    #cropsize = [640, 360]
    # NOTE(review): cropsize is stored as a string in the json and eval'd
    # here — eval on config input is unsafe if the file is untrusted;
    # consider storing it as a json list and dropping the eval.
    cropsize = eval(par['cropsize'])
    #labelJson='./heliushuju_info.json'
    labelJson = par['labelJson']
    # Scales used by the dataset's random-resize augmentation.
    randomscale = (0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.125, 1.25, 1.375, 1.5)
    #print('line102:',torch.cuda.current_device())
    ds = Heliushuju(dspth, cropsize=cropsize, mode=mode,labelJson=labelJson, randomscale=randomscale)
    sampler = None
    # ################################################################################################# distributed
    # sampler = torch.utils.data.distributed.DistributedSampler(ds)
    dl = DataLoader(ds,
                    batch_size = n_img_per_gpu,
                    shuffle = False,
                    sampler = sampler,
                    num_workers = n_workers_train,
                    pin_memory = False,
                    drop_last = True)
    dsval = Heliushuju(dspth, mode='val',labelJson=labelJson, randomscale=randomscale)
    sampler_val = None
    dlval = DataLoader(dsval,
                       batch_size = 1,
                       shuffle = False,
                       sampler = sampler_val,
                       num_workers = n_workers_val,
                       drop_last = False)
    #print('line124:',torch.cuda.current_device())
    ## model
    #ignore_idx = 255
    ignore_idx = par['ignore_idx']
    net = BiSeNet(backbone=args.backbone, n_classes=n_classes, pretrain_model=args.pretrain_path,
                  use_boundary_2=use_boundary_2, use_boundary_4=use_boundary_4, use_boundary_8=use_boundary_8,
                  use_boundary_16=use_boundary_16, use_conv_last=args.use_conv_last)
    # Optionally resume from a previously saved state_dict.
    if not args.ckpt is None:
        net.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
        print('load checkpoints model:',args.ckpt)
    #print('line136:',torch.cuda.current_device())
    net.cuda()
    net.train()
    #print('line142:',net.device)
    #net = nn.DataParallel(net, device_ids=[0])###########################################################################
    #print('line144:',net.device)
    # OHEM losses for the main head and the two auxiliary heads; n_min is
    # the minimum number of hard pixels kept per batch (1/16 of all pixels).
    score_thres = 0.7
    n_min = n_img_per_gpu*cropsize[0]*cropsize[1]//16
    criteria_p = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    criteria_16 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    criteria_32 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    boundary_loss_func = DetailAggregateLoss()
    ## optimizer
    maxmIOU50 = 0.
    maxmIOU75 = 0.
    momentum = 0.9
    weight_decay = 5e-4
    lr_start = 1e-2
    max_iter = args.max_iter
    save_iter_sep = args.save_iter_sep
    power = 0.9
    warmup_steps = args.warmup_steps
    warmup_start_lr = 1e-5
    print('max_iter: ', max_iter)
    print('save_iter_sep: ', save_iter_sep)
    print('warmup_steps: ', warmup_steps)
    # Warmup + poly-decay SGD wrapper (see optimizer_loss.Optimizer).
    optim = Optimizer(
        #model = net.module,
        model = net,
        loss = boundary_loss_func,
        lr0 = lr_start,
        momentum = momentum,
        wd = weight_decay,
        warmup_steps = warmup_steps,
        warmup_start_lr = warmup_start_lr,
        max_iter = max_iter,
        power = power)
    ## train loop
    msg_iter = 50  # log a training message every 50 iterations
    loss_avg = []
    loss_boundery_bce = []
    loss_boundery_dice = []
    st = glob_st = time.time()
    diter = iter(dl)
    # diter = enumerate(dl)
    epoch = 0
    for it in range(max_iter):
        try:
            im, lb = diter.__next__()
            # print(im.size()[0])
            # im, lb = next(diter)
            # A short final batch means the loader is exhausted: restart it.
            if not im.size()[0]==n_img_per_gpu: raise StopIteration
        except StopIteration:
            epoch += 1
            # sampler.set_epoch(epoch)
            diter = iter(dl)
            im, lb = next(diter)
        im = im.cuda()
        lb = lb.cuda()
        H, W = im.size()[2:]
        lb = torch.squeeze(lb, 1) # lb.shape : torch.Size([8, 360, 640])
        # lb = torch.argmax(lb, dim=3) # added (enable this line when training the highway data; comment it out for water segmentation)
        optim.zero_grad()
        # BiSeNet returns a different tuple depending on which boundary
        # (detail) heads are enabled, so exactly one branch below matches.
        if use_boundary_2 and use_boundary_4 and use_boundary_8:
            out, out16, out32, detail2, detail4, detail8 = net(im)
        if (not use_boundary_2) and use_boundary_4 and use_boundary_8:
            out, out16, out32, detail4, detail8 = net(im)
        if (not use_boundary_2) and (not use_boundary_4) and use_boundary_8:####### True (default configuration)
            out, out16, out32, detail8 = net(im)
        if (not use_boundary_2) and (not use_boundary_4) and (not use_boundary_8):
            out, out16, out32 = net(im)
        lossp = criteria_p(out, lb)
        loss2 = criteria_16(out16, lb)
        loss3 = criteria_32(out32, lb)
        #print(' line220 : img.size:', H,W,'label.size',lb.size(),' pred size:',out.size() )
        # Accumulate detail (boundary) losses from the enabled heads.
        # NOTE(review): if every use_boundary_* flag were False these stay
        # plain floats and the .item() calls below would raise — the default
        # configuration keeps use_boundary_8 True, so this path is unused.
        boundery_bce_loss = 0.
        boundery_dice_loss = 0.
        if use_boundary_2:
            boundery_bce_loss2, boundery_dice_loss2 = boundary_loss_func(detail2, lb)
            boundery_bce_loss += boundery_bce_loss2
            boundery_dice_loss += boundery_dice_loss2
        if use_boundary_4:
            boundery_bce_loss4, boundery_dice_loss4 = boundary_loss_func(detail4, lb)
            boundery_bce_loss += boundery_bce_loss4
            boundery_dice_loss += boundery_dice_loss4
        if use_boundary_8:###### default-enabled head
            boundery_bce_loss8, boundery_dice_loss8 = boundary_loss_func(detail8, lb)
            boundery_bce_loss += boundery_bce_loss8
            boundery_dice_loss += boundery_dice_loss8
        loss = lossp + loss2 + loss3 + boundery_bce_loss + boundery_dice_loss
        loss.backward()
        optim.step()
        loss_avg.append(loss.item())
        loss_boundery_bce.append(boundery_bce_loss.item())
        loss_boundery_dice.append(boundery_dice_loss.item())
        ## print training log message
        if (it+1)%msg_iter==0:
            # Average the accumulated losses over the last msg_iter steps.
            loss_avg = sum(loss_avg) / len(loss_avg)
            lr = optim.lr
            ed = time.time()
            t_intv, glob_t_intv = ed - st, ed - glob_st
            # ETA from the global average time per iteration (it > 0 here,
            # since the first log happens at it == msg_iter - 1).
            eta = int((max_iter - it) * (glob_t_intv / it))
            eta = str(datetime.timedelta(seconds=eta))
            loss_boundery_bce_avg = sum(loss_boundery_bce) / len(loss_boundery_bce)
            loss_boundery_dice_avg = sum(loss_boundery_dice) / len(loss_boundery_dice)
            msg = ', '.join([
                'it: {it}/{max_it}',
                'lr: {lr:4f}',
                'loss: {loss:.4f}',
                'boundery_bce_loss: {boundery_bce_loss:.4f}',
                'boundery_dice_loss: {boundery_dice_loss:.4f}',
                'eta: {eta}',
                'time: {time:.4f}',
            ]).format(
                it = it+1,
                max_it = max_iter,
                lr = lr,
                loss = loss_avg,
                boundery_bce_loss = loss_boundery_bce_avg,
                boundery_dice_loss = loss_boundery_dice_avg,
                time = t_intv,
                eta = eta
            )
            logger.info(msg)
            # Reset the running-loss windows for the next logging interval.
            loss_avg = []
            loss_boundery_bce = []
            loss_boundery_dice = []
            st = ed
            # print(boundary_loss_func.get_params())
        # Periodic evaluation + checkpointing.
        if (it+1)%save_iter_sep==0:# and it != 0:
        #if (it+1)%50==0:# and it != 0:
            ## model
            logger.info('evaluating the model ...')
            logger.info('setup and restore model')
            net.eval()
            # ## evaluator
            logger.info('compute the mIOU')
            with torch.no_grad():
                # mIOU at full scale and at 0.75 scale on the val loader.
                single_scale1 = MscEvalV0()
                mIOU50 = single_scale1(net, dlval, n_classes)
                single_scale2 = MscEvalV0(scale=0.75)
                mIOU75 = single_scale2(net, dlval, n_classes)
            save_pth = osp.join(save_pth_path, 'model_iter{}_mIOU50_{}_mIOU75_{}.pth'
                                .format(it+1, str(round(mIOU50, 4)), str(round(mIOU75, 4))))
            # hasattr(net, 'module') keeps this working if DataParallel is re-enabled.
            state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
            # if dist.get_rank()==0:
            torch.save(state, save_pth)
            logger.info('training iteration {}, model saved to: {}'.format(it+1, save_pth))
            # Track and snapshot the best model at each evaluation scale.
            if mIOU50 > maxmIOU50:
                maxmIOU50 = mIOU50
                save_pth = osp.join(save_pth_path, 'model_maxmIOU50.pth'.format(it+1))
                state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
                # if dist.get_rank()==0:
                torch.save(state, save_pth)
                logger.info('max mIOU model saved to: {}'.format(save_pth))
            if mIOU75 > maxmIOU75:
                maxmIOU75 = mIOU75
                save_pth = osp.join(save_pth_path, 'model_maxmIOU75.pth'.format(it+1))
                state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
                # if dist.get_rank()==0: torch.save(state, save_pth)
                torch.save(state, save_pth)
                logger.info('max mIOU model saved to: {}'.format(save_pth))
            logger.info('mIOU50 is: {}, mIOU75 is: {}'.format(mIOU50, mIOU75))
            logger.info('maxmIOU50 is: {}, maxmIOU75 is: {}.'.format(maxmIOU50, maxmIOU75))
            # Back to training mode after evaluation.
            net.train()
    ## dump the final model
    save_pth = osp.join(save_pth_path, 'model_final.pth')
    net.cpu()
    state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
    # if dist.get_rank()==0: torch.save(state, save_pth)
    torch.save(state, save_pth)
    logger.info('training done, model saved to: {}'.format(save_pth))
    print('epoch: ', epoch)
# Script entry point: run training when invoked directly.
if __name__ == "__main__":
    train()