|
|
@@ -22,7 +22,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
import test # import test.py to get mAP after each epoch |
|
|
|
import test # for end-of-epoch mAP |
|
|
|
from models.experimental import attempt_load |
|
|
|
from models.yolo import Model |
|
|
|
from utils.autoanchor import check_anchors |
|
|
@@ -39,7 +39,11 @@ from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
def train(hyp, opt, device, tb_writer=None): |
|
|
|
def train(hyp, |
|
|
|
opt, |
|
|
|
device, |
|
|
|
tb_writer=None |
|
|
|
): |
|
|
|
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items())) |
|
|
|
save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \ |
|
|
|
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \ |
|
|
@@ -341,7 +345,7 @@ def train(hyp, opt, device, tb_writer=None): |
|
|
|
save_dir.glob('train*.jpg') if x.exists()]}) |
|
|
|
|
|
|
|
# end batch ------------------------------------------------------------------------------------------------ |
|
|
|
|
|
|
|
|
|
|
|
# Scheduler |
|
|
|
lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard |
|
|
|
scheduler.step() |
|
|
@@ -404,12 +408,11 @@ def train(hyp, opt, device, tb_writer=None): |
|
|
|
torch.save(ckpt, best) |
|
|
|
if wandb_logger.wandb: |
|
|
|
if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1: |
|
|
|
wandb_logger.log_model( |
|
|
|
last.parent, opt, epoch, fi, best_model=best_fitness == fi) |
|
|
|
wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi) |
|
|
|
del ckpt |
|
|
|
|
|
|
|
# end epoch ---------------------------------------------------------------------------------------------------- |
|
|
|
# end training |
|
|
|
# end training ----------------------------------------------------------------------------------------------------- |
|
|
|
if rank in [-1, 0]: |
|
|
|
logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n') |
|
|
|
if plots: |