
PyTorch 1.6.0 update with native AMP (#573)

* PyTorch now has native Automatic Mixed Precision (AMP) training.

* Fixed inconsistent code indentation

* Mixed precision training is turned on by default
Liu Changyu committed 4 years ago
commit c020875b17
2 changed files with 38 additions and 46 deletions:
  1. train.py  +36 -44
  2. utils/torch_utils.py  +2 -2
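
For context, the native torch.cuda.amp pattern this commit adopts wraps the forward pass and loss in an autocast() context and drives the backward pass and optimizer step through a GradScaler, replacing apex.amp.initialize() and amp.scale_loss(). Below is a minimal sketch of that pattern; the model, data, and loss are hypothetical stand-ins, not code from this repository.

# Minimal native-AMP loop (PyTorch >= 1.6); net, x, y are hypothetical stand-ins
import torch
from torch.cuda import amp

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = torch.nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
scaler = amp.GradScaler(enabled=device.type != 'cpu')  # no-op on CPU

for _ in range(3):
    x, y = torch.randn(8, 10, device=device), torch.randn(8, 1, device=device)
    with amp.autocast(enabled=device.type != 'cpu'):  # fp16 forward where safe
        loss = torch.nn.functional.mse_loss(net(x), y)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales grads; skips the step on inf/NaN
    scaler.update()                # adjusts the loss-scale factor
    optimizer.zero_grad()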

train.py  +36 -44

  import torch.optim as optim
  import torch.optim.lr_scheduler as lr_scheduler
  import torch.utils.data
+ from torch.cuda import amp
  from torch.nn.parallel import DistributedDataParallel as DDP
  from torch.utils.tensorboard import SummaryWriter

  from utils.datasets import *
  from utils.utils import *

- mixed_precision = True
- try:  # Mixed precision training https://github.com/NVIDIA/apex
-     from apex import amp
- except:
-     print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex')
-     mixed_precision = False  # not installed

  # Hyperparameters
  hyp = {'optimizer': 'SGD',  # ['adam', 'SGD', None] if none, default is SGD
         'lr0': 0.01,  # initial learning rate (SGD=1E-2, Adam=1E-3)

      yaml.dump(vars(opt), f, sort_keys=False)

  # Configure
+ cuda = device.type != 'cpu'
  init_seeds(2 + rank)
  with open(opt.data) as f:
      data_dict = yaml.load(f, Loader=yaml.FullLoader)  # model dict

  optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
  print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
  del pg0, pg1, pg2

  # Scheduler https://arxiv.org/pdf/1812.01187.pdf
  lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2  # cosine
  scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

  del ckpt


- # Mixed precision training https://github.com/NVIDIA/apex
- if mixed_precision:
-     model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)

  # DP mode
- if device.type != 'cpu' and rank == -1 and torch.cuda.device_count() > 1:
+ if cuda and rank == -1 and torch.cuda.device_count() > 1:
      model = torch.nn.DataParallel(model)

  # SyncBatchNorm
- if opt.sync_bn and device.type != 'cpu' and rank != -1:
+ if opt.sync_bn and cuda and rank != -1:
      model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
      print('Using SyncBatchNorm()')

  ema = torch_utils.ModelEMA(model) if rank in [-1, 0] else None

  # DDP mode
- if device.type != 'cpu' and rank != -1:
+ if cuda and rank != -1:
      model = DDP(model, device_ids=[rank], output_device=rank)

  # Trainloader

  maps = np.zeros(nc)  # mAP per class
  results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
  scheduler.last_epoch = start_epoch - 1  # do not move
+ scaler = amp.GradScaler(enabled=cuda)
  if rank in [0, -1]:
      print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
      print('Using %g dataloader workers' % dataloader.num_workers)

  model.train()


  # Update image weights (optional)
+ # When in DDP mode, the generated indices will be broadcasted to synchronize dataset.
  if dataset.image_weights:
-     # Generate indices.
+     # Generate indices
      if rank in [-1, 0]:
          w = model.class_weights.cpu().numpy() * (1 - maps) ** 2  # class weights
          image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
          dataset.indices = random.choices(range(dataset.n), weights=image_weights,
                                           k=dataset.n)  # rand weighted idx
-     # Broadcast.
+     # Broadcast if DDP
      if rank != -1:
          indices = torch.zeros([dataset.n], dtype=torch.int)
          if rank == 0:

  optimizer.zero_grad()
  for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
      ni = i + nb * epoch  # number integrated batches (since train start)
-     imgs = imgs.to(device, non_blocking=True).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
+     imgs = imgs.to(device, non_blocking=True).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0

      # Warmup
      if ni <= nw:

          ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
          imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)


-     # Forward
-     pred = model(imgs)
+     # Autocast
+     with amp.autocast():
+         # Forward
+         pred = model(imgs)

-     # Loss
-     loss, loss_items = compute_loss(pred, targets.to(device), model)  # scaled by batch_size
-     if rank != -1:
-         loss *= opt.world_size  # gradient averaged between devices in DDP mode
-     if not torch.isfinite(loss):
-         print('WARNING: non-finite loss, ending training ', loss_items)
-         return results
+         # Loss
+         loss, loss_items = compute_loss(pred, targets.to(device), model)  # scaled by batch_size
+         if rank != -1:
+             loss *= opt.world_size  # gradient averaged between devices in DDP mode
+         # if not torch.isfinite(loss):
+         #     print('WARNING: non-finite loss, ending training ', loss_items)
+         #     return results


      # Backward
-     if mixed_precision:
-         with amp.scale_loss(loss, optimizer) as scaled_loss:
-             scaled_loss.backward()
-     else:
-         loss.backward()
+     scaler.scale(loss).backward()

      # Optimize
      if ni % accumulate == 0:
-         optimizer.step()
+         scaler.step(optimizer)  # optimizer.step
+         scaler.update()
          optimizer.zero_grad()
          if ema is not None:
              ema.update(model)

      # Print
      if rank in [-1, 0]:
          mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
-         mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
+         mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
          s = ('%10s' * 2 + '%10.4g' * 6) % (
              '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
          pbar.set_description(s)

  # Scheduler
  scheduler.step()


- # Only the first process in DDP mode is allowed to log or save checkpoints.
+ # DDP process 0 or single-GPU
  if rank in [-1, 0]:
      # mAP
      if ema is not None:

          # Save last, best and delete
          torch.save(ckpt, last)
-         if best_fitness == fi:
+         if best_fitness == fi:
              torch.save(ckpt, best)
          del ckpt
  # end epoch ----------------------------------------------------------------------------------------------------

  parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  opt = parser.parse_args()


+ # Resume
  last = get_latest_run() if opt.resume == 'get_last' else opt.resume  # resume from most recent run
  if last and not opt.weights:
      print(f'Resuming training from {last}')
  opt.weights = last if opt.resume and not opt.weights else opt.weights

  if opt.local_rank in [-1, 0]:
      check_git_status()
  opt.cfg = check_file(opt.cfg)  # check file

  with open(opt.hyp) as f:
      hyp.update(yaml.load(f, Loader=yaml.FullLoader))  # update hyps
  opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
- device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
+ device = torch_utils.select_device(opt.device, batch_size=opt.batch_size)
  opt.total_batch_size = opt.batch_size
  opt.world_size = 1
- if device.type == 'cpu':
-     mixed_precision = False
- elif opt.local_rank != -1:
-     # DDP mode
+ # DDP mode
+ if opt.local_rank != -1:
      assert torch.cuda.device_count() > opt.local_rank
      torch.cuda.set_device(opt.local_rank)
      device = torch.device("cuda", opt.local_rank)
      dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend

      opt.world_size = dist.get_world_size()
      assert opt.batch_size % opt.world_size == 0, "Batch size is not a multiple of the number of devices given!"
      opt.batch_size = opt.total_batch_size // opt.world_size

  print(opt)

      # Train
      tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name))
  else:
      tb_writer = None

  train(hyp, tb_writer, opt, device)

  # Evolve hyperparameters (optional)
  else:
-     assert opt.local_rank == -1, "DDP mode currently not implemented for Evolve!"
+     assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'

      tb_writer = None
      opt.notest, opt.nosave = True, True  # only test/save final epoch
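
One subtlety worth noting in the Backward/Optimize hunks above: with gradient accumulation, scaler.scale(loss).backward() runs on every batch, while scaler.step() and scaler.update() run only when ni % accumulate == 0. A minimal sketch of that interaction (model, loader, and function names are hypothetical, not from this repository):

# Sketch: GradScaler with gradient accumulation; model/loader/device are hypothetical
import torch
from torch.cuda import amp

def run_epoch(model, loader, optimizer, device, accumulate=4):
    scaler = amp.GradScaler(enabled=device.type != 'cpu')
    optimizer.zero_grad()
    for i, (imgs, targets) in enumerate(loader):
        imgs, targets = imgs.to(device), targets.to(device)
        with amp.autocast(enabled=device.type != 'cpu'):
            loss = torch.nn.functional.mse_loss(model(imgs), targets)
        scaler.scale(loss).backward()      # scaled gradients accumulate across batches
        if (i + 1) % accumulate == 0:      # step only on accumulation boundaries
            scaler.step(optimizer)         # unscales first; skips step on inf/NaN grads
            scaler.update()
            optimizer.zero_grad()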

utils/torch_utils.py  +2 -2

  cudnn.benchmark = True


- def select_device(device='', apex=False, batch_size=None):
+ def select_device(device='', batch_size=None):
      # device = 'cpu' or '0' or '0,1,2,3'
      cpu_request = device.lower() == 'cpu'
      if device and not cpu_request:  # if device requested other than 'cpu'

      if ng > 1 and batch_size:  # check that batch_size is compatible with device_count
          assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
      x = [torch.cuda.get_device_properties(i) for i in range(ng)]
-     s = 'Using CUDA ' + ('Apex ' if apex else '')  # apex for mixed precision https://github.com/NVIDIA/apex
+     s = 'Using CUDA '
      for i in range(0, ng):
          if i == 1:
              s = ' ' * len(s)
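
After this change, callers drop the apex argument and pass only the device string and batch size; for example (hypothetical device string and batch size, assuming utils.torch_utils is importable):

from utils import torch_utils

device = torch_utils.select_device('0,1', batch_size=16)  # batch size must be a multiple of GPU count
device = torch_utils.select_device('cpu')                 # explicit CPU request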
