Add PyTorch AMP check (#7917)

* Add PyTorch AMP check

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cleanup

* Cleanup

* Cleanup

* Robust for DDP

* Fixes

* Add amp enabled boolean to check_train_batch_size

* Simplify

* space to prefix

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2022-05-22 13:41:18 +02:00 committed by GitHub
parent 547c89b3a0
commit eb1217f3ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 37 additions and 14 deletions

View File

@ -524,9 +524,10 @@ class AutoShape(nn.Module):
max_det = 1000 # maximum number of detections per image max_det = 1000 # maximum number of detections per image
amp = False # Automatic Mixed Precision (AMP) inference amp = False # Automatic Mixed Precision (AMP) inference
def __init__(self, model): def __init__(self, model, verbose=True):
super().__init__() super().__init__()
LOGGER.info('Adding AutoShape... ') if verbose:
LOGGER.info('Adding AutoShape... ')
copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance
self.pt = not self.dmb or model.pt # PyTorch model self.pt = not self.dmb or model.pt # PyTorch model

View File

@ -27,7 +27,6 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
import yaml import yaml
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD, Adam, AdamW, lr_scheduler from torch.optim import SGD, Adam, AdamW, lr_scheduler
from tqdm import tqdm from tqdm import tqdm
@ -46,10 +45,10 @@ from utils.autobatch import check_train_batch_size
from utils.callbacks import Callbacks from utils.callbacks import Callbacks
from utils.dataloaders import create_dataloader from utils.dataloaders import create_dataloader
from utils.downloads import attempt_download from utils.downloads import attempt_download
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements, from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size,
check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path, check_requirements, check_suffix, check_version, check_yaml, colorstr, get_latest_run,
init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
one_cycle, print_args, print_mutation, strip_optimizer) labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer)
from utils.loggers import Loggers from utils.loggers import Loggers
from utils.loggers.wandb.wandb_utils import check_wandb_resume from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss from utils.loss import ComputeLoss
@ -126,6 +125,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
else: else:
model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
amp = check_amp(model) # check AMP
# Freeze # Freeze
freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze
@ -141,7 +141,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
# Batch size # Batch size
if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
batch_size = check_train_batch_size(model, imgsz) batch_size = check_train_batch_size(model, imgsz, amp)
loggers.on_params_update({"batch_size": batch_size}) loggers.on_params_update({"batch_size": batch_size})
# Optimizer # Optimizer
@ -293,7 +293,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
maps = np.zeros(nc) # mAP per class maps = np.zeros(nc) # mAP per class
results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls) results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
scheduler.last_epoch = start_epoch - 1 # do not move scheduler.last_epoch = start_epoch - 1 # do not move
scaler = amp.GradScaler(enabled=cuda) scaler = torch.cuda.amp.GradScaler(enabled=amp)
stopper = EarlyStopping(patience=opt.patience) stopper = EarlyStopping(patience=opt.patience)
compute_loss = ComputeLoss(model) # init loss class compute_loss = ComputeLoss(model) # init loss class
callbacks.run('on_train_start') callbacks.run('on_train_start')
@ -348,7 +348,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False) imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
# Forward # Forward
with amp.autocast(enabled=cuda): with torch.cuda.amp.autocast(amp):
pred = model(imgs) # forward pred = model(imgs) # forward
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
if RANK != -1: if RANK != -1:

View File

@ -7,15 +7,14 @@ from copy import deepcopy
import numpy as np import numpy as np
import torch import torch
from torch.cuda import amp
from utils.general import LOGGER, colorstr from utils.general import LOGGER, colorstr
from utils.torch_utils import profile from utils.torch_utils import profile
def check_train_batch_size(model, imgsz=640): def check_train_batch_size(model, imgsz=640, amp=True):
# Check YOLOv5 training batch size # Check YOLOv5 training batch size
with amp.autocast(): with torch.cuda.amp.autocast(amp):
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size

View File

@ -36,9 +36,11 @@ import yaml
from utils.downloads import gsutil_getsize from utils.downloads import gsutil_getsize
from utils.metrics import box_iou, fitness from utils.metrics import box_iou, fitness
# Settings
FILE = Path(__file__).resolve() FILE = Path(__file__).resolve()
ROOT = FILE.parents[1] # YOLOv5 root directory ROOT = FILE.parents[1] # YOLOv5 root directory
RANK = int(os.getenv('RANK', -1))
# Settings
DATASETS_DIR = ROOT.parent / 'datasets' # YOLOv5 datasets directory DATASETS_DIR = ROOT.parent / 'datasets' # YOLOv5 datasets directory
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
@ -505,6 +507,27 @@ def check_dataset(data, autodownload=True):
return data # dictionary return data # dictionary
def check_amp(model):
# Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
from models.common import AutoShape
if next(model.parameters()).device.type == 'cpu': # get model device
return False
prefix = colorstr('AMP: ')
im = cv2.imread(ROOT / 'data' / 'images' / 'bus.jpg')[..., ::-1] # OpenCV image (BGR to RGB)
m = AutoShape(model, verbose=False) # model
a = m(im).xyxy[0] # FP32 inference
m.amp = True
b = m(im).xyxy[0] # AMP inference
if (a.shape == b.shape) and torch.allclose(a, b, atol=1.0): # close to 1.0 pixel bounding box
LOGGER.info(emojis(f'{prefix}checks passed ✅'))
return True
else:
help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
LOGGER.warning(emojis(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}'))
return False
def url2file(url): def url2file(url):
# Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/ url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/