* Add PyTorch AMP check * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Cleanup * Cleanup * Cleanup * Robust for DDP * Fixes * Add amp enabled boolean to check_train_batch_size * Simplify * space to prefix Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>modifyDataloader
@@ -524,9 +524,10 @@ class AutoShape(nn.Module): | |||
max_det = 1000 # maximum number of detections per image | |||
amp = False # Automatic Mixed Precision (AMP) inference | |||
def __init__(self, model): | |||
def __init__(self, model, verbose=True): | |||
super().__init__() | |||
LOGGER.info('Adding AutoShape... ') | |||
if verbose: | |||
LOGGER.info('Adding AutoShape... ') | |||
copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes | |||
self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance | |||
self.pt = not self.dmb or model.pt # PyTorch model |
@@ -27,7 +27,6 @@ import torch | |||
import torch.distributed as dist | |||
import torch.nn as nn | |||
import yaml | |||
from torch.cuda import amp | |||
from torch.nn.parallel import DistributedDataParallel as DDP | |||
from torch.optim import SGD, Adam, AdamW, lr_scheduler | |||
from tqdm import tqdm | |||
@@ -46,10 +45,10 @@ from utils.autobatch import check_train_batch_size | |||
from utils.callbacks import Callbacks | |||
from utils.dataloaders import create_dataloader | |||
from utils.downloads import attempt_download | |||
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements, | |||
check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path, | |||
init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods, | |||
one_cycle, print_args, print_mutation, strip_optimizer) | |||
from utils.general import (LOGGER, check_amp, check_dataset, check_file, check_git_status, check_img_size, | |||
check_requirements, check_suffix, check_version, check_yaml, colorstr, get_latest_run, | |||
increment_path, init_seeds, intersect_dicts, labels_to_class_weights, | |||
labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer) | |||
from utils.loggers import Loggers | |||
from utils.loggers.wandb.wandb_utils import check_wandb_resume | |||
from utils.loss import ComputeLoss | |||
@@ -126,6 +125,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio | |||
LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report | |||
else: | |||
model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create | |||
amp = check_amp(model) # check AMP | |||
# Freeze | |||
freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze | |||
@@ -141,7 +141,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio | |||
# Batch size | |||
if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size | |||
batch_size = check_train_batch_size(model, imgsz) | |||
batch_size = check_train_batch_size(model, imgsz, amp) | |||
loggers.on_params_update({"batch_size": batch_size}) | |||
# Optimizer | |||
@@ -293,7 +293,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio | |||
maps = np.zeros(nc) # mAP per class | |||
results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls) | |||
scheduler.last_epoch = start_epoch - 1 # do not move | |||
scaler = amp.GradScaler(enabled=cuda) | |||
scaler = torch.cuda.amp.GradScaler(enabled=amp) | |||
stopper = EarlyStopping(patience=opt.patience) | |||
compute_loss = ComputeLoss(model) # init loss class | |||
callbacks.run('on_train_start') | |||
@@ -348,7 +348,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio | |||
imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False) | |||
# Forward | |||
with amp.autocast(enabled=cuda): | |||
with torch.cuda.amp.autocast(amp): | |||
pred = model(imgs) # forward | |||
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size | |||
if RANK != -1: |
@@ -7,15 +7,14 @@ from copy import deepcopy | |||
import numpy as np | |||
import torch | |||
from torch.cuda import amp | |||
from utils.general import LOGGER, colorstr | |||
from utils.torch_utils import profile | |||
def check_train_batch_size(model, imgsz=640): | |||
def check_train_batch_size(model, imgsz=640, amp=True): | |||
# Check YOLOv5 training batch size | |||
with amp.autocast(): | |||
with torch.cuda.amp.autocast(amp): | |||
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size | |||
@@ -36,9 +36,11 @@ import yaml | |||
from utils.downloads import gsutil_getsize | |||
from utils.metrics import box_iou, fitness | |||
# Settings | |||
FILE = Path(__file__).resolve() | |||
ROOT = FILE.parents[1] # YOLOv5 root directory | |||
RANK = int(os.getenv('RANK', -1)) | |||
# Settings | |||
DATASETS_DIR = ROOT.parent / 'datasets' # YOLOv5 datasets directory | |||
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads | |||
AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode | |||
@@ -505,6 +507,27 @@ def check_dataset(data, autodownload=True): | |||
return data # dictionary | |||
def check_amp(model): | |||
# Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation | |||
from models.common import AutoShape | |||
if next(model.parameters()).device.type == 'cpu': # get model device | |||
return False | |||
prefix = colorstr('AMP: ') | |||
im = cv2.imread(ROOT / 'data' / 'images' / 'bus.jpg')[..., ::-1] # OpenCV image (BGR to RGB) | |||
m = AutoShape(model, verbose=False) # model | |||
a = m(im).xyxy[0] # FP32 inference | |||
m.amp = True | |||
b = m(im).xyxy[0] # AMP inference | |||
if (a.shape == b.shape) and torch.allclose(a, b, atol=1.0): # close to 1.0 pixel bounding box | |||
LOGGER.info(emojis(f'{prefix}checks passed ✅')) | |||
return True | |||
else: | |||
help_url = 'https://github.com/ultralytics/yolov5/issues/7908' | |||
LOGGER.warning(emojis(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}')) | |||
return False | |||
def url2file(url): | |||
# Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt | |||
url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/ |