|
|
@@ -37,15 +37,17 @@ from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_di
 from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
 
 logger = logging.getLogger(__name__)
+LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
+RANK = int(os.getenv('RANK', -1))
+WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
 
 
 def train(hyp,  # path/to/hyp.yaml or hyp dictionary
           opt,
           device,
           ):
-    save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
-        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
-        opt.single_cls
+    save_dir, epochs, batch_size, total_batch_size, weights, single_cls = \
+        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.single_cls
 
     # Directories
     wdir = save_dir / 'weights'
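The three module-level constants in this hunk replace the old per-run attributes opt.local_rank / opt.global_rank / opt.world_size. They are read from environment variables that the torch.distributed.run launcher exports to every worker process; in a plain single-process run they are unset and fall back to -1 / -1 / 1. A minimal stand-alone sketch of that behaviour (the filename and print are illustrative, not part of the patch):

import os

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # GPU index of this process on its node
RANK = int(os.getenv('RANK', -1))              # global process index across all nodes
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))   # total number of processes

if __name__ == '__main__':
    # `python -m torch.distributed.run --nproc_per_node 2 demo.py` sets all three variables;
    # `python demo.py` leaves them unset, so DDP stays disabled.
    print(f'DDP active: {LOCAL_RANK != -1} (local_rank={LOCAL_RANK}, rank={RANK}, world_size={WORLD_SIZE})')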
|
|
@@ -69,13 +71,13 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # Configure
     plots = not opt.evolve  # create plots
     cuda = device.type != 'cpu'
-    init_seeds(2 + rank)
+    init_seeds(2 + RANK)
     with open(opt.data) as f:
         data_dict = yaml.safe_load(f)  # data dict
 
     # Loggers
     loggers = {'wandb': None, 'tb': None}  # loggers dict
-    if rank in [-1, 0]:
+    if RANK in [-1, 0]:
         # TensorBoard
         if not opt.evolve:
             prefix = colorstr('tensorboard: ')
|
|
@@ -99,7 +101,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # Model
     pretrained = weights.endswith('.pt')
    if pretrained:
-        with torch_distributed_zero_first(rank):
+        with torch_distributed_zero_first(RANK):
             weights = attempt_download(weights)  # download if not found locally
         ckpt = torch.load(weights, map_location=device)  # load checkpoint
         model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
|
|
@@ -110,7 +112,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
     else:
         model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
-    with torch_distributed_zero_first(rank):
+    with torch_distributed_zero_first(RANK):
         check_dataset(data_dict)  # check
     train_path = data_dict['train']
     test_path = data_dict['val']
|
|
@@ -158,7 +160,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # plot_lr_scheduler(optimizer, scheduler, epochs)
 
     # EMA
-    ema = ModelEMA(model) if rank in [-1, 0] else None
+    ema = ModelEMA(model) if RANK in [-1, 0] else None
 
     # Resume
     start_epoch, best_fitness = 0, 0.0
|
|
@@ -194,28 +196,28 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size]  # verify imgsz are gs-multiples
 
     # DP mode
-    if cuda and rank == -1 and torch.cuda.device_count() > 1:
+    if cuda and RANK == -1 and torch.cuda.device_count() > 1:
         model = torch.nn.DataParallel(model)
 
     # SyncBatchNorm
-    if opt.sync_bn and cuda and rank != -1:
+    if opt.sync_bn and cuda and RANK != -1:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
         logger.info('Using SyncBatchNorm()')
 
     # Trainloader
     dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls,
-                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
-                                            world_size=opt.world_size, workers=opt.workers,
+                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
+                                            workers=opt.workers,
                                             image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
     nb = len(dataloader)  # number of batches
     assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
 
     # Process 0
-    if rank in [-1, 0]:
+    if RANK in [-1, 0]:
         testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, single_cls,
                                        hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
-                                       world_size=opt.world_size, workers=opt.workers,
+                                       workers=opt.workers,
                                        pad=0.5, prefix=colorstr('val: '))[0]
 
     if not opt.resume:
|
|
@@ -234,8 +236,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         model.half().float()  # pre-reduce anchor precision
 
     # DDP mode
-    if cuda and rank != -1:
-        model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank,
+    if cuda and RANK != -1:
+        model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK,
                     # nn.MultiheadAttention incompatibility with DDP https://github.com/pytorch/pytorch/issues/26698
                     find_unused_parameters=any(isinstance(layer, nn.MultiheadAttention) for layer in model.modules()))
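For context, this hunk only swaps the rank source; the wrapping itself is the standard one-process-per-GPU DDP pattern, where each process places its replica on its own device and gradients are all-reduced during backward(). find_unused_parameters tells DDP to tolerate parameters that receive no gradient in a given forward pass (needed here when nn.MultiheadAttention layers are present, per the linked issue), at some extra overhead. A minimal sketch of the pattern, assuming the process group is already initialised (the helper name is illustrative):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_for_ddp(model: torch.nn.Module, local_rank: int) -> DDP:
    # One process per GPU: the replica lives on `local_rank`; DDP syncs gradients across processes.
    device = torch.device('cuda', local_rank)
    return DDP(model.to(device), device_ids=[local_rank], output_device=local_rank)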
|
|
|
|
|
|
@@ -269,15 +271,15 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         # Update image weights (optional)
         if opt.image_weights:
             # Generate indices
-            if rank in [-1, 0]:
+            if RANK in [-1, 0]:
                 cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
                 iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
                 dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx
             # Broadcast if DDP
-            if rank != -1:
-                indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
+            if RANK != -1:
+                indices = (torch.tensor(dataset.indices) if RANK == 0 else torch.zeros(dataset.n)).int()
                 dist.broadcast(indices, 0)
-                if rank != 0:
+                if RANK != 0:
                     dataset.indices = indices.cpu().numpy()
 
         # Update mosaic border
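The broadcast in this hunk is the usual rank-0-produces, everyone-receives idiom: only rank 0 samples the weighted indices, every other rank allocates a buffer of the same shape, and dist.broadcast overwrites that buffer in place so all processes iterate the same images for the epoch. A small self-contained sketch of the idiom (helper name illustrative; assumes init_process_group has already run, and note that the nccl backend would require CUDA tensors rather than CPU ones):

import torch
import torch.distributed as dist

def broadcast_from_rank0(values, n, rank):
    # Rank 0 sends its data; other ranks receive into a same-shaped placeholder.
    buf = torch.tensor(values, dtype=torch.int32) if rank == 0 else torch.zeros(n, dtype=torch.int32)
    dist.broadcast(buf, src=0)  # in-place on every rank except the source
    return buf.tolist()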
|
|
@@ -285,11 +287,11 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders
 
         mloss = torch.zeros(4, device=device)  # mean losses
-        if rank != -1:
+        if RANK != -1:
             dataloader.sampler.set_epoch(epoch)
         pbar = enumerate(dataloader)
         logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
-        if rank in [-1, 0]:
+        if RANK in [-1, 0]:
             pbar = tqdm(pbar, total=nb)  # progress bar
         optimizer.zero_grad()
         for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
|
|
@@ -319,8 +321,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             with amp.autocast(enabled=cuda):
                 pred = model(imgs)  # forward
                 loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
-                if rank != -1:
-                    loss *= opt.world_size  # gradient averaged between devices in DDP mode
+                if RANK != -1:
+                    loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
                 if opt.quad:
                     loss *= 4.
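As the inline comment notes, DDP averages gradients across processes during backward(), i.e. it divides the all-reduced gradients by the number of processes. Multiplying each local loss by WORLD_SIZE cancels that division: with four processes, for example, each local loss is scaled by 4 so the averaged gradient matches the gradient of the summed loss over the full batch. Only the constant used for this scaling changes in the hunk.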
|
|
|
|
|
|
@@ -336,7 +338,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                     ema.update(model)
 
             # Print
-            if rank in [-1, 0]:
+            if RANK in [-1, 0]:
                 mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                 mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
                 s = ('%10s' * 2 + '%10.4g' * 6) % (
|
|
@@ -362,7 +364,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         scheduler.step()
 
         # DDP process 0 or single-GPU
-        if rank in [-1, 0]:
+        if RANK in [-1, 0]:
             # mAP
             ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
|
|
@@ -424,7 +426,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
 
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training -----------------------------------------------------------------------------------------------------
-    if rank in [-1, 0]:
+    if RANK in [-1, 0]:
         logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
         if plots:
             plot_results(save_dir=save_dir)  # save as results.png
|
|
@@ -457,8 +459,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                                            name='run_' + wandb_logger.wandb_run.id + '_model',
                                            aliases=['latest', 'best', 'stripped'])
         wandb_logger.finish_run()
-    else:
-        dist.destroy_process_group()
 
     torch.cuda.empty_cache()
     return results
|
|
|
|
|
|
@@ -486,7 +487,6 @@ def parse_opt():
     parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
     parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
     parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
-    parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
     parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
     parser.add_argument('--project', default='runs/train', help='save to project/name')
     parser.add_argument('--entity', default=None, help='W&B entity')
|
|
@@ -499,18 +499,15 @@ def parse_opt():
     parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
     parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
     parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
+    parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
     opt = parser.parse_args()
-
-    # Set DDP variables
-    opt.world_size = int(getattr(os.environ, 'WORLD_SIZE', 1))
-    opt.global_rank = int(getattr(os.environ, 'RANK', -1))
     return opt
 
 
 def main(opt):
-    print(opt)
-    set_logging(opt.global_rank)
-    if opt.global_rank in [-1, 0]:
+    set_logging(RANK)
+    if RANK in [-1, 0]:
+        print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
         check_git_status()
         check_requirements(exclude=['thop'])
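One detail behind dropping the old "# Set DDP variables" block: getattr(os.environ, 'WORLD_SIZE', 1) looks up an attribute on the environ mapping rather than a key, so it always returns the default value regardless of what the launcher exported. The module-level os.getenv(...) constants introduced at the top of the file read the same variables correctly, which is presumably part of the motivation for this change.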
|
|
|
|
|
|
@@ -519,11 +516,9 @@ def main(opt):
     if opt.resume and not wandb_run:  # resume an interrupted run
         ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
         assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
-        apriori = opt.global_rank, opt.local_rank
         with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
             opt = argparse.Namespace(**yaml.safe_load(f))  # replace
-        opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = \
-            '', ckpt, True, opt.total_batch_size, *apriori  # reinstate
+        opt.cfg, opt.weights, opt.resume, opt.batch_size = '', ckpt, True, opt.total_batch_size  # reinstate
         logger.info('Resuming training from %s' % ckpt)
     else:
         # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
|
|
@@ -536,19 +531,21 @@ def main(opt):
     # DDP mode
     opt.total_batch_size = opt.batch_size
     device = select_device(opt.device, batch_size=opt.batch_size)
-    if opt.local_rank != -1:
-        assert torch.cuda.device_count() > opt.local_rank
-        torch.cuda.set_device(opt.local_rank)
-        device = torch.device('cuda', opt.local_rank)
-        dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend
-        assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
+    if LOCAL_RANK != -1:
+        from datetime import timedelta
+        assert torch.cuda.device_count() > LOCAL_RANK, 'too few GPUS for DDP command'
+        torch.cuda.set_device(LOCAL_RANK)
+        device = torch.device('cuda', LOCAL_RANK)
+        dist.init_process_group(backend="gloo", timeout=timedelta(seconds=60))
+        assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
         assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
-        opt.batch_size = opt.total_batch_size // opt.world_size
+        opt.batch_size = opt.total_batch_size // WORLD_SIZE
 
     # Train
     logger.info(opt)
     if not opt.evolve:
         train(opt.hyp, opt, device)
+        if WORLD_SIZE > 1 and RANK == 0:
+            _ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')]
 
     # Evolve hyperparameters (optional)
     else:
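Two behavioural changes in the hunk above are worth spelling out: process-group initialisation now uses the gloo backend with an explicit 60-second timeout instead of nccl with init_method='env://' (gloo is the more portable choice, though nccl is generally faster for CUDA all-reduce), and the group is now torn down in main() on rank 0 after train() returns rather than inside train(). A minimal sketch of that init/teardown pairing, assuming the rendezvous variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) are provided by torch.distributed.run; the helper names are illustrative:

from datetime import timedelta

import torch
import torch.distributed as dist

def init_ddp(local_rank: int) -> torch.device:
    # One process per GPU; the default env:// rendezvous reads the launcher's env vars.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='gloo', timeout=timedelta(seconds=60))
    return torch.device('cuda', local_rank)

def shutdown_ddp(rank: int, world_size: int) -> None:
    # Mirrors the hunk: only rank 0 destroys the group once training has finished.
    if world_size > 1 and rank == 0:
        dist.destroy_process_group()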
|
|
@@ -584,7 +581,7 @@ def main(opt):
 
         with open(opt.hyp) as f:
             hyp = yaml.safe_load(f)  # load hyps dict
-        assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
+        assert LOCAL_RANK == -1, 'DDP mode not implemented for --evolve'
         opt.notest, opt.nosave = True, True  # only test/save final epoch
         # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
         yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml'  # save best result here