|
|
@@ -4,6 +4,7 @@ import math |
|
|
|
import os |
|
|
|
import random |
|
|
|
import time |
|
|
|
from copy import deepcopy |
|
|
|
from pathlib import Path |
|
|
|
from threading import Thread |
|
|
|
|
|
|
@@ -381,8 +382,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): |
|
|
|
ckpt = {'epoch': epoch, |
|
|
|
'best_fitness': best_fitness, |
|
|
|
'training_results': results_file.read_text(), |
|
|
|
'model': (model.module if is_parallel(model) else model).half(), |
|
|
|
'ema': (ema.ema.half(), ema.updates), |
|
|
|
'model': deepcopy(model.module if is_parallel(model) else model).half(), |
|
|
|
'ema': (deepcopy(ema.ema).half(), ema.updates), |
|
|
|
'optimizer': optimizer.state_dict(), |
|
|
|
'wandb_id': wandb_run.id if wandb else None} |
|
|
|
|
|
|
@@ -392,8 +393,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): |
|
|
|
torch.save(ckpt, best) |
|
|
|
del ckpt |
|
|
|
|
|
|
|
model.float(), ema.ema.float() |
|
|
|
|
|
|
|
# end epoch ---------------------------------------------------------------------------------------------------- |
|
|
|
# end training |
|
|
|
|