Update train.py (#4136)

* Refactor train.py

* Update imports

* Update imports

* Update optimizer

* cleanup
Glenn Jocher, 2021-07-24 16:11:39 +02:00, committed by GitHub
parent 264be1a616
commit 63dd65e7ed
3 changed files with 49 additions and 59 deletions

train.py

@@ -17,15 +17,13 @@ from threading import Thread
 import math
 import numpy as np
+import torch
 import torch.distributed as dist
 import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-import torch.optim.lr_scheduler as lr_scheduler
-import torch.utils.data
 import yaml
 from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Adam, SGD, lr_scheduler
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
@@ -58,16 +56,13 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
           device,
           ):
     save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, = \
-        opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
+        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
         opt.resume, opt.noval, opt.nosave, opt.workers
 
     # Directories
-    save_dir = Path(save_dir)
-    wdir = save_dir / 'weights'
-    wdir.mkdir(parents=True, exist_ok=True)  # make dir
-    last = wdir / 'last.pt'
-    best = wdir / 'best.pt'
-    results_file = save_dir / 'results.txt'
+    w = save_dir / 'weights'  # weights dir
+    w.mkdir(parents=True, exist_ok=True)  # make dir
+    last, best, results_file = w / 'last.pt', w / 'best.pt', save_dir / 'results.txt'
 
     # Hyperparameters
     if isinstance(hyp, str):
@@ -92,7 +87,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     loggers = {'wandb': None, 'tb': None}  # loggers dict
     if RANK in [-1, 0]:
         # TensorBoard
-        if not evolve:
+        if plots:
             prefix = colorstr('tensorboard: ')
             LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
             loggers['tb'] = SummaryWriter(str(save_dir))
@@ -105,11 +100,11 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         loggers['wandb'] = wandb_logger.wandb
         if loggers['wandb']:
             data_dict = wandb_logger.data_dict
-            weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp  # may update weights, epochs if resuming
+            weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp  # may update values if resuming
 
     nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
     names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
-    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data)  # check
+    assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}'  # check
     is_coco = data.endswith('coco.yaml') and nc == 80  # COCO dataset
 
     # Model
@@ -120,23 +115,22 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         ckpt = torch.load(weights, map_location=device)  # load checkpoint
         model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
         exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
-        state_dict = ckpt['model'].float().state_dict()  # to FP32
-        state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
-        model.load_state_dict(state_dict, strict=False)  # load
-        LOGGER.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
+        csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
+        csd = intersect_dicts(csd, model.state_dict(), exclude=exclude)  # intersect
+        model.load_state_dict(csd, strict=False)  # load
+        LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}')  # report
     else:
         model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
     with torch_distributed_zero_first(RANK):
         check_dataset(data_dict)  # check
-    train_path = data_dict['train']
-    val_path = data_dict['val']
+    train_path, val_path = data_dict['train'], data_dict['val']
 
     # Freeze
     freeze = []  # parameter names to freeze (full or partial)
     for k, v in model.named_parameters():
         v.requires_grad = True  # train all layers
         if any(x in k for x in freeze):
-            print('freezing %s' % k)
+            print(f'freezing {k}')
             v.requires_grad = False
 
     # Optimizer
@@ -145,33 +139,32 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
     LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
 
-    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
-    for k, v in model.named_modules():
-        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
-            pg2.append(v.bias)  # biases
-        if isinstance(v, nn.BatchNorm2d):
-            pg0.append(v.weight)  # no decay
-        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
-            pg1.append(v.weight)  # apply decay
+    g0, g1, g2 = [], [], []  # optimizer parameter groups
+    for v in model.modules():
+        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias
+            g2.append(v.bias)
+        if isinstance(v, nn.BatchNorm2d):  # weight with decay
+            g0.append(v.weight)
+        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight without decay
+            g1.append(v.weight)
 
     if opt.adam:
-        optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
+        optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
     else:
-        optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
+        optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
 
-    optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
-    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
-    LOGGER.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
-    del pg0, pg1, pg2
+    optimizer.add_param_group({'params': g1, 'weight_decay': hyp['weight_decay']})  # add g1 with weight_decay
+    optimizer.add_param_group({'params': g2})  # add g2 (biases)
+    LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
+                f"{len(g0)} weight, {len(g1)} weight (no decay), {len(g2)} bias")
+    del g0, g1, g2
 
-    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
-    # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
+    # Scheduler
     if opt.linear_lr:
         lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf']  # linear
     else:
         lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
-    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
-    # plot_lr_scheduler(optimizer, scheduler, epochs)
+    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)  # plot_lr_scheduler(optimizer, scheduler, epochs)
 
     # EMA
     ema = ModelEMA(model) if RANK in [-1, 0] else None
@@ -196,13 +189,12 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         # Epochs
         start_epoch = ckpt['epoch'] + 1
         if resume:
-            assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
+            assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
         if epochs < start_epoch:
-            LOGGER.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
-                        (weights, ckpt['epoch'], epochs))
+            LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
             epochs += ckpt['epoch']  # finetune additional epochs
 
-        del ckpt, state_dict
+        del ckpt, csd
 
     # Image sizes
     gs = max(int(model.stride.max()), 32)  # grid size (max stride)
@@ -217,7 +209,6 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # SyncBatchNorm
     if opt.sync_bn and cuda and RANK != -1:
-        raise Exception('can not train with --sync-bn, known issue https://github.com/ultralytics/yolov5/issues/3998')
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
         LOGGER.info('Using SyncBatchNorm()')
@@ -228,7 +219,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                                               prefix=colorstr('train: '))
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
     nb = len(train_loader)  # number of batches
-    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, data, nc - 1)
+    assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
 
     # Process 0
     if RANK in [-1, 0]:
@@ -261,7 +252,6 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     hyp['label_smoothing'] = opt.label_smoothing
     model.nc = nc  # attach number of classes to model
     model.hyp = hyp  # attach hyperparameters to model
-    model.gr = 1.0  # iou loss ratio (obj_loss = 1.0 or iou)
     model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
     model.names = names
@@ -315,7 +305,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             # Warmup
             if ni <= nw:
                 xi = [0, nw]  # x interp
-                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
+                # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                 accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
                 for j, x in enumerate(optimizer.param_groups):
                     # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
@@ -329,7 +319,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                 sf = sz / max(imgs.shape[2:])  # scale factor
                 if sf != 1:
                     ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
-                    imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
+                    imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
 
             # Forward
             with amp.autocast(enabled=cuda):
@@ -355,7 +345,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             # Print
             if RANK in [-1, 0]:
                 mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
-                mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
+                mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                 s = ('%10s' * 2 + '%10.4g' * 6) % (
                     f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])
                 pbar.set_description(s)
@@ -381,7 +371,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         # DDP process 0 or single-GPU
         if RANK in [-1, 0]:
             # mAP
-            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
+            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
             if not noval or final_epoch:  # Calculate mAP
                 wandb_logger.current_epoch = epoch + 1
@@ -457,6 +447,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                                             batch_size=batch_size // WORLD_SIZE * 2,
                                             imgsz=imgsz,
                                             model=attempt_load(m, device).half(),
+                                            iou_thres=0.7,  # NMS IoU threshold for best pycocotools results
                                             single_cls=single_cls,
                                             dataloader=val_loader,
                                             save_dir=save_dir,
@@ -525,8 +516,7 @@ def main(opt):
         check_requirements(exclude=['thop'])
 
     # Resume
-    wandb_run = check_wandb_resume(opt)
-    if opt.resume and not wandb_run:  # resume an interrupted run
+    if opt.resume and not check_wandb_resume(opt):  # resume an interrupted run
         ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
         assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
         with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
@@ -534,7 +524,6 @@ def main(opt):
         opt.cfg, opt.weights, opt.resume = '', ckpt, True  # reinstate
         LOGGER.info(f'Resuming training from {ckpt}')
     else:
-        # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
         opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp)  # check files
         assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
         opt.name = 'evolve' if opt.evolve else opt.name
@@ -545,11 +534,13 @@ def main(opt):
     if LOCAL_RANK != -1:
         from datetime import timedelta
         assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
+        assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
+        assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
+        assert not opt.evolve, '--evolve argument is not compatible with DDP training'
+        assert not opt.sync_bn, '--sync-bn known training issue, see https://github.com/ultralytics/yolov5/issues/3998'
         torch.cuda.set_device(LOCAL_RANK)
         device = torch.device('cuda', LOCAL_RANK)
         dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo", timeout=timedelta(seconds=60))
-        assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
-        assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
 
     # Train
     if not opt.evolve:
@@ -594,7 +585,6 @@ def main(opt):
             hyp = yaml.safe_load(f)  # load hyps dict
             if 'anchors' not in hyp:  # anchors commented in hyp.yaml
                 hyp['anchors'] = 3
-        assert LOCAL_RANK == -1, 'DDP mode not implemented for --evolve'
         opt.noval, opt.nosave = True, True  # only val/save final epoch
         # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
         yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml'  # save best result here
@@ -646,7 +636,7 @@ def main(opt):
 def run(**kwargs):
-    # Usage: import train; train.run(imgsz=320, weights='yolov5m.pt')
+    # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
     opt = parse_opt(True)
     for k, v in kwargs.items():
         setattr(opt, k, v)
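
Note: the optimizer hunk above replaces the pg0/pg1/pg2 lists with g0/g1/g2 and walks model.modules() directly, building the SGD/Adam base group from BatchNorm weights and adding the decayed weights and biases as extra param groups. A minimal stand-alone sketch of that grouping (the toy model and the lr/momentum/weight_decay values below are illustrative assumptions, not part of the commit):

    import torch.nn as nn
    from torch.optim import SGD

    model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.Linear(16, 10))

    g0, g1, g2 = [], [], []  # BatchNorm weights, other weights, biases
    for v in model.modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
            g2.append(v.bias)
        if isinstance(v, nn.BatchNorm2d):
            g0.append(v.weight)
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
            g1.append(v.weight)

    optimizer = SGD(g0, lr=0.01, momentum=0.937, nesterov=True)         # g0: created without weight decay
    optimizer.add_param_group({'params': g1, 'weight_decay': 0.0005})   # g1: weight decay applied
    optimizer.add_param_group({'params': g2})                           # g2: biases, no weight decay
    print(len(g0), len(g1), len(g2))  # -> 1 2 3 for the toy model above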

utils/general.py

@@ -301,7 +301,7 @@ def clean_str(s):
 def one_cycle(y1=0.0, y2=1.0, steps=100):
-    # lambda function for sinusoidal ramp from y1 to y2
+    # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
     return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
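
Note: train.py feeds this lambda into lr_scheduler.LambdaLR as the per-epoch learning-rate multiplier. A small hedged check of its shape (the lrf=0.2 and epochs=300 values are assumed here purely for illustration):

    import math

    def one_cycle(y1=0.0, y2=1.0, steps=100):
        return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1

    lf = one_cycle(1, 0.2, steps=300)  # multiplier decays from 1.0 at epoch 0 to 0.2 at epoch 300
    print(lf(0), lf(150), lf(300))     # -> approx. 1.0, 0.6, 0.2, cosine-shaped in between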

utils/loss.py

@@ -108,7 +108,7 @@ class ComputeLoss:
         det = model.module.model[-1] if is_parallel(model) else model.model[-1]  # Detect() module
         self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02])  # P3-P7
         self.ssi = list(det.stride).index(16) if autobalance else 0  # stride 16 index
-        self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
+        self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
         for k in 'na', 'nc', 'nl', 'anchors':
             setattr(self, k, getattr(det, k))
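
Note: with model.gr removed from train.py, ComputeLoss now fixes gr at 1.0 internally. In the YOLOv5 loss, gr blends the objectness target between a constant 1.0 and the predicted box IoU; a hedged sketch of that blend (the IoU values below are toy assumptions, not from the commit):

    import torch

    gr = 1.0                        # ratio now hard-coded by this commit
    iou = torch.tensor([0.3, 0.9])  # example IoUs of matched predictions
    tobj = (1.0 - gr) + gr * iou    # objectness targets -> tensor([0.3000, 0.9000])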