Selaa lähdekoodia

update train.py gsutil bucket fix (#463)

5.0
Glenn Jocher 4 vuotta sitten
vanhempi
commit
776555771f
1 muutettua tiedostoa jossa 40 lisäystä ja 52 poistoa
  1. +40
    -52
      train.py

+ 40
- 52
train.py Näytä tiedosto

@@ -47,11 +47,13 @@ def train(hyp, tb_writer, opt, device):
print(f'Hyperparameters {hyp}')
log_dir = tb_writer.log_dir if tb_writer else 'runs/evolution' # run directory
wdir = str(Path(log_dir) / 'weights') + os.sep # weights directory

os.makedirs(wdir, exist_ok=True)
last = wdir + 'last.pt'
best = wdir + 'best.pt'
results_file = log_dir + os.sep + 'results.txt'
epochs, batch_size, total_batch_size, weights, rank = opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.local_rank
# TODO: Init DDP logging. Only the first process is allowed to log.
# Since I see lots of print here, the logging configuration is skipped here. We may see repeated outputs.

# Save run settings
with open(Path(log_dir) / 'hyp.yaml', 'w') as f:
@@ -59,17 +61,8 @@ def train(hyp, tb_writer, opt, device):
with open(Path(log_dir) / 'opt.yaml', 'w') as f:
yaml.dump(vars(opt), f, sort_keys=False)

epochs = opt.epochs # 300
batch_size = opt.batch_size # batch size per process.
total_batch_size = opt.total_batch_size
weights = opt.weights # initial training weights
local_rank = opt.local_rank

# TODO: Init DDP logging. Only the first process is allowed to log.
# Since I see lots of print here, the logging configuration is skipped here. We may see repeated outputs.

# Configure
init_seeds(2 + local_rank)
init_seeds(2 + rank)
with open(opt.data) as f:
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
train_path = data_dict['train']
@@ -78,7 +71,7 @@ def train(hyp, tb_writer, opt, device):
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check

# Remove previous results
if local_rank in [-1, 0]:
if rank in [-1, 0]:
for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
os.remove(f)

@@ -91,7 +84,7 @@ def train(hyp, tb_writer, opt, device):

# Optimizer
nbs = 64 # nominal batch size
# the default DDP implementation is slow for accumulation according to: https://pytorch.org/docs/stable/notes/ddp.html
# default DDP implementation is slow for accumulation according to: https://pytorch.org/docs/stable/notes/ddp.html
# all-reduce operation is carried out during loss.backward().
# Thus, there would be redundant all-reduce communications in a accumulation procedure,
# which means, the result is still right but the training speed gets slower.
@@ -121,8 +114,7 @@ def train(hyp, tb_writer, opt, device):
del pg0, pg1, pg2

# Load Model
# Avoid multiple downloads.
with torch_distributed_zero_first(local_rank):
with torch_distributed_zero_first(rank):
google_utils.attempt_download(weights)
start_epoch, best_fitness = 0, 0.0
if weights.endswith('.pt'): # pytorch format
@@ -169,32 +161,31 @@ def train(hyp, tb_writer, opt, device):
# plot_lr_scheduler(optimizer, scheduler, epochs)

# DP mode
if device.type != 'cpu' and local_rank == -1 and torch.cuda.device_count() > 1:
if device.type != 'cpu' and rank == -1 and torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(model)

# Exponential moving average
# From https://github.com/rwightman/pytorch-image-models/blob/master/train.py:
# "Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper"
# chenyzsjtu: ema should be placed before after SyncBN. As SyncBN introduces new modules.
if opt.sync_bn and device.type != 'cpu' and local_rank != -1:
print("SyncBN activated!")
# SyncBatchNorm
if opt.sync_bn and device.type != 'cpu' and rank != -1:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
ema = torch_utils.ModelEMA(model) if local_rank in [-1, 0] else None
print('Using SyncBatchNorm()')

# Exponential moving average
ema = torch_utils.ModelEMA(model) if rank in [-1, 0] else None

# DDP mode
if device.type != 'cpu' and local_rank != -1:
model = DDP(model, device_ids=[local_rank], output_device=local_rank)
if device.type != 'cpu' and rank != -1:
model = DDP(model, device_ids=[rank], output_device=rank)

# Trainloader
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
cache=opt.cache_images, rect=opt.rect, local_rank=local_rank,
cache=opt.cache_images, rect=opt.rect, local_rank=rank,
world_size=opt.world_size)
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
nb = len(dataloader) # number of batches
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)

# Testloader
if local_rank in [-1, 0]:
if rank in [-1, 0]:
# local_rank is set to -1. Because only the first process is expected to do evaluation.
testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False,
cache=opt.cache_images, rect=True, local_rank=-1, world_size=opt.world_size)[0]
@@ -208,8 +199,7 @@ def train(hyp, tb_writer, opt, device):
model.names = names

# Class frequency
# Only one check and log is needed.
if local_rank in [-1, 0]:
if rank in [-1, 0]:
labels = np.concatenate(dataset.labels, 0)
c = torch.tensor(labels[:, 0]) # classes
# cf = torch.bincount(c.long(), minlength=nc) + 1.
@@ -222,13 +212,14 @@ def train(hyp, tb_writer, opt, device):
# Check anchors
if not opt.noautoanchor:
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)

# Start training
t0 = time.time()
nw = max(3 * nb, 1e3) # number of warmup iterations, max(3 epochs, 1k iterations)
maps = np.zeros(nc) # mAP per class
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
scheduler.last_epoch = start_epoch - 1 # do not move
if local_rank in [0, -1]:
if rank in [0, -1]:
print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
print('Using %g dataloader workers' % dataloader.num_workers)
print('Starting training for %g epochs...' % epochs)
@@ -240,18 +231,18 @@ def train(hyp, tb_writer, opt, device):
# When in DDP mode, the generated indices will be broadcasted to synchronize dataset.
if dataset.image_weights:
# Generate indices.
if local_rank in [-1, 0]:
if rank in [-1, 0]:
w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
dataset.indices = random.choices(range(dataset.n), weights=image_weights,
k=dataset.n) # rand weighted idx
# Broadcast.
if local_rank != -1:
if rank != -1:
indices = torch.zeros([dataset.n], dtype=torch.int)
if local_rank == 0:
if rank == 0:
indices[:] = torch.from_tensor(dataset.indices, dtype=torch.int)
dist.broadcast(indices, 0)
if local_rank != 0:
if rank != 0:
dataset.indices = indices.cpu().numpy()

# Update mosaic border
@@ -259,10 +250,10 @@ def train(hyp, tb_writer, opt, device):
# dataset.mosaic_border = [b - imgsz, -b] # height, width borders

mloss = torch.zeros(4, device=device) # mean losses
if local_rank != -1:
if rank != -1:
dataloader.sampler.set_epoch(epoch)
pbar = enumerate(dataloader)
if local_rank in [-1, 0]:
if rank in [-1, 0]:
print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
pbar = tqdm(pbar, total=nb) # progress bar
optimizer.zero_grad()
@@ -293,10 +284,9 @@ def train(hyp, tb_writer, opt, device):
pred = model(imgs)

# Loss
loss, loss_items = compute_loss(pred, targets.to(device), model)
# loss is scaled with batch size in func compute_loss. But in DDP mode, gradient is averaged between devices.
if local_rank != -1:
loss *= opt.world_size
loss, loss_items = compute_loss(pred, targets.to(device), model) # scaled by batch_size
if rank != -1:
loss *= opt.world_size # gradient averaged between devices in DDP mode
if not torch.isfinite(loss):
print('WARNING: non-finite loss, ending training ', loss_items)
return results
@@ -316,7 +306,7 @@ def train(hyp, tb_writer, opt, device):
ema.update(model)

# Print
if local_rank in [-1, 0]:
if rank in [-1, 0]:
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0) # (GB)
s = ('%10s' * 2 + '%10.4g' * 6) % (
@@ -337,7 +327,7 @@ def train(hyp, tb_writer, opt, device):
scheduler.step()

# Only the first process in DDP mode is allowed to log or save checkpoints.
if local_rank in [-1, 0]:
if rank in [-1, 0]:
# mAP
if ema is not None:
ema.update_attr(model, include=['md', 'nc', 'hyp', 'gr', 'names', 'stride'])
@@ -351,17 +341,17 @@ def train(hyp, tb_writer, opt, device):
single_cls=opt.single_cls,
dataloader=testloader,
save_dir=log_dir)
# Explicitly keep the shape.
# Write
with open(results_file, 'a') as f:
f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
if len(opt.name) and opt.bucket:
os.system('gsutil cp results.txt gs://%s/results/results%s.txt' % (opt.bucket, opt.name))
os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))

# Tensorboard
if tb_writer:
tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/F1',
'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
for x, tag in zip(list(mloss[:-1]) + list(results), tags):
tb_writer.add_scalar(tag, x, epoch)
@@ -389,7 +379,7 @@ def train(hyp, tb_writer, opt, device):
# end epoch ----------------------------------------------------------------------------------------------------
# end training

if local_rank in [-1, 0]:
if rank in [-1, 0]:
# Strip optimizers
n = ('_' if len(opt.name) and not opt.name.isnumeric() else '') + opt.name
fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
@@ -401,10 +391,10 @@ def train(hyp, tb_writer, opt, device):
os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None # upload
# Finish
if not opt.evolve:
plot_results() # save as results.png
plot_results(save_dir=log_dir) # save as results.png
print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))

dist.destroy_process_group() if local_rank not in [-1, 0] else None
dist.destroy_process_group() if rank not in [-1, 0] else None
torch.cuda.empty_cache()
return results

@@ -431,10 +421,8 @@ if __name__ == '__main__':
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
parser.add_argument("--sync-bn", action="store_true", help="Use sync-bn, only avaible in DDP mode.")
# Parameter For DDP.
parser.add_argument('--local_rank', type=int, default=-1,
help="Extra parameter for DDP implementation. Don't use it manually.")
parser.add_argument('--sync-bn', action="store_true", help='use SyncBatchNorm, only available in DDP mode')
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
opt = parser.parse_args()

last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run

Loading…
Peruuta
Tallenna