Parcourir la source

single command --resume (#756)

* single command --resume

* else check files, remove TODO

* argparse.Namespace()

* tensorboard lr

* bug fix in get_latest_run()
5.0
Glenn Jocher GitHub il y a 4 ans
Parent
révision
ebafd1ead5
Aucune clé connue n'a été trouvée dans la base pour cette signature ID de la clé GPG: 4AEE18F83AFDEB23
2 fichiers modifiés avec 31 ajouts et 26 suppressions
  1. +30
    -25
      train.py
  2. +1
    -1
      utils/general.py

+ 30
- 25
train.py Voir le fichier

@@ -42,7 +42,6 @@ def train(hyp, opt, device, tb_writer=None):
epochs, batch_size, total_batch_size, weights, rank = \
opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank

# TODO: Use DDP logging. Only the first process is allowed to log.
# Save run settings
with open(log_dir / 'hyp.yaml', 'w') as f:
yaml.dump(hyp, f, sort_keys=False)
@@ -130,6 +129,8 @@ def train(hyp, opt, device, tb_writer=None):

# Epochs
start_epoch = ckpt['epoch'] + 1
if opt.resume:
assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
if epochs < start_epoch:
logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
(weights, ckpt['epoch'], epochs))
@@ -158,19 +159,19 @@ def train(hyp, opt, device, tb_writer=None):
model = DDP(model, device_ids=[opt.local_rank], output_device=(opt.local_rank))

# Trainloader
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
cache=opt.cache_images, rect=opt.rect, rank=rank,
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
world_size=opt.world_size, workers=opt.workers)
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
nb = len(dataloader) # number of batches
ema.updates = start_epoch * nb // accumulate # set EMA updates
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)

# Testloader
if rank in [-1, 0]:
# local_rank is set to -1. Because only the first process is expected to do evaluation.
testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False,
cache=opt.cache_images, rect=True, rank=-1, world_size=opt.world_size,
workers=opt.workers)[0]
testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt,
hyp=hyp, augment=False, cache=opt.cache_images, rect=True, rank=-1,
world_size=opt.world_size, workers=opt.workers)[0] # only runs on process 0

# Model parameters
hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
@@ -283,7 +284,7 @@ def train(hyp, opt, device, tb_writer=None):
scaler.step(optimizer) # optimizer.step
scaler.update()
optimizer.zero_grad()
if ema is not None:
if ema:
ema.update(model)

# Print
@@ -305,12 +306,13 @@ def train(hyp, opt, device, tb_writer=None):
# end batch ------------------------------------------------------------------------------------------------

# Scheduler
lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard
scheduler.step()

# DDP process 0 or single-GPU
if rank in [-1, 0]:
# mAP
if ema is not None:
if ema:
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
final_epoch = epoch + 1 == epochs
if not opt.notest or final_epoch: # Calculate mAP
@@ -330,10 +332,11 @@ def train(hyp, opt, device, tb_writer=None):

# Tensorboard
if tb_writer:
tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss', # train loss
'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
for x, tag in zip(list(mloss[:-1]) + list(results), tags):
'val/giou_loss', 'val/obj_loss', 'val/cls_loss', # val loss
'x/lr0', 'x/lr1', 'x/lr2'] # params
for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
tb_writer.add_scalar(tag, x, epoch)

# Update best mAP
@@ -389,8 +392,7 @@ if __name__ == '__main__':
parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
parser.add_argument('--rect', action='store_true', help='rectangular training')
parser.add_argument('--resume', nargs='?', const='get_last', default=False,
help='resume from given path/last.pt, or most recent run if blank')
parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
parser.add_argument('--notest', action='store_true', help='only test final epoch')
parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
@@ -413,21 +415,24 @@ if __name__ == '__main__':
opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
set_logging(opt.global_rank)

# Resume
if opt.resume:
last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run
if last and not opt.weights:
logger.info(f'Resuming training from {last}')
opt.weights = last if opt.resume and not opt.weights else opt.weights
if opt.global_rank in [-1, 0]:
check_git_status()

opt.hyp = opt.hyp or ('data/hyp.finetune.yaml' if opt.weights else 'data/hyp.scratch.yaml')
opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
# Resume
if opt.resume: # resume an interrupted run
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
opt = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader)) # replace
opt.cfg, opt.weights, opt.resume = '', ckpt, True
logger.info('Resuming training from %s' % ckpt)

else:
opt.hyp = opt.hyp or ('data/hyp.finetune.yaml' if opt.weights else 'data/hyp.scratch.yaml')
opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)

opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
device = select_device(opt.device, batch_size=opt.batch_size)

# DDP mode

+ 1
- 1
utils/general.py Voir le fichier

@@ -61,7 +61,7 @@ def init_seeds(seed=0):
def get_latest_run(search_dir='./runs'):
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
return max(last_list, key=os.path.getctime)
return max(last_list, key=os.path.getctime) if last_list else ''


def check_git_status():

Chargement…
Annuler
Enregistrer