|
|
@@ -31,7 +31,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima
 from utils.google_utils import attempt_download
 from utils.loss import ComputeLoss
 from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
-from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first
+from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
 
 logger = logging.getLogger(__name__)
 
@@ -136,6 +136,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                                id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
     loggers = {'wandb': wandb}  # loggers dict
 
+    # EMA
+    ema = ModelEMA(model) if rank in [-1, 0] else None
+
     # Resume
     start_epoch, best_fitness = 0, 0.0
     if pretrained:
@@ -144,6 +147,11 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
             optimizer.load_state_dict(ckpt['optimizer'])
             best_fitness = ckpt['best_fitness']
 
+        # EMA
+        if ema and ckpt.get('ema'):
+            ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())
+            ema.updates = ckpt['ema'][1]
+
         # Results
         if ckpt.get('training_results') is not None:
             results_file.write_text(ckpt['training_results'])  # write results.txt
@@ -173,9 +181,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
         logger.info('Using SyncBatchNorm()')
 
-    # EMA
-    ema = ModelEMA(model) if rank in [-1, 0] else None
-
     # DDP mode
     if cuda and rank != -1:
         model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
@@ -191,7 +196,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
 
     # Process 0
     if rank in [-1, 0]:
-        ema.updates = start_epoch * nb // accumulate  # set EMA updates
         testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt,  # testloader
                                        hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
                                        world_size=opt.world_size, workers=opt.workers,
@@ -335,8 +339,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         # DDP process 0 or single-GPU
         if rank in [-1, 0]:
             # mAP
-            if ema:
-                ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
+            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
             if not opt.notest or final_epoch:  # Calculate mAP
                 results, maps, times = test.test(opt.data,
@@ -378,8 +381,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 ckpt = {'epoch': epoch,
                         'best_fitness': best_fitness,
                         'training_results': results_file.read_text(),
-                        'model': ema.ema,
-                        'optimizer': None if final_epoch else optimizer.state_dict(),
+                        'model': (model.module if is_parallel(model) else model).half(),
+                        'ema': (ema.ema.half(), ema.updates),
+                        'optimizer': optimizer.state_dict(),
                         'wandb_id': wandb_run.id if wandb else None}
 
                 # Save last, best and delete
@@ -387,6 +391,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 if best_fitness == fi:
                     torch.save(ckpt, best)
                 del ckpt
+
+            model.float(), ema.ema.float()
+
 
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training
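
Note: a minimal sketch (not part of this diff) of how the new checkpoint layout could be read back for resume, mirroring the Resume block above. The 'last.pt' path is a hypothetical checkpoint written by the updated train.py, ModelEMA comes from the repo's utils.torch_utils, and optimizer construction is omitted.

import torch

from utils.torch_utils import ModelEMA

ckpt = torch.load('last.pt', map_location='cpu')  # hypothetical path to a saved checkpoint

model = ckpt['model'].float()  # 'model' is stored de-parallelized in FP16; restore FP32 for training
ema = ModelEMA(model)
if ckpt.get('ema'):
    ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())  # EMA weights
    ema.updates = ckpt['ema'][1]                                  # EMA update counter

# 'optimizer' is now saved every epoch rather than only before the final one,
# so optimizer.load_state_dict(ckpt['optimizer']) can be called unconditionally
# once an optimizer has been built for model.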
|
|
|
|