Browse Source

Add EarlyStopping feature (#4576)

* Add EarlyStopping feature

* Add comment

* Cleanup

* Cleanup2

* debug

* debug2

* debug3

* debug3

* debug4

* debug5

* debug6

* debug7

* debug8

* debug9

* debug10

* debug11

* debug12

* Cleanup

* Add TODO for known DDP issue
modifyDataloader
Glenn Jocher GitHub 3 years ago
parent
commit
93cc015748
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 1 deletions
  1. +18
    -1
      train.py
  2. +17
    -0
      utils/torch_utils.py

+ 18
- 1
train.py View File

@@ -40,7 +40,8 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima
from utils.downloads import attempt_download
from utils.loss import ComputeLoss
from utils.plots import plot_labels, plot_evolve
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, intersect_dicts, select_device, \
torch_distributed_zero_first
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.metrics import fitness
from utils.loggers import Loggers
@@ -255,6 +256,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
scheduler.last_epoch = start_epoch - 1 # do not move
scaler = amp.GradScaler(enabled=cuda)
stopper = EarlyStopping(patience=opt.patience)
compute_loss = ComputeLoss(model) # init loss class
LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
f'Using {train_loader.num_workers} dataloader workers\n'
@@ -389,6 +391,20 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
del ckpt
callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi)

# Stop Single-GPU
if stopper(epoch=epoch, fitness=fi):
break

# Stop DDP TODO: known issues shttps://github.com/ultralytics/yolov5/pull/4576
# stop = stopper(epoch=epoch, fitness=fi)
# if RANK == 0:
# dist.broadcast_object_list([stop], 0) # broadcast 'stop' to all ranks

# Stop DPP
# with torch_distributed_zero_first(RANK):
# if stop:
# break # must break all DDP ranks

# end epoch ----------------------------------------------------------------------------------------------------
# end training -----------------------------------------------------------------------------------------------------
if RANK in [-1, 0]:
@@ -454,6 +470,7 @@ def parse_opt(known=False):
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
parser.add_argument('--patience', type=int, default=30, help='EarlyStopping patience (epochs)')
opt = parser.parse_known_args()[0] if known else parser.parse_args()
return opt


+ 17
- 0
utils/torch_utils.py View File

@@ -293,6 +293,23 @@ def copy_attr(a, b, include=(), exclude=()):
setattr(a, k, v)


class EarlyStopping:
# YOLOv5 simple early stopper
def __init__(self, patience=30):
self.best_fitness = 0.0 # i.e. mAP
self.best_epoch = 0
self.patience = patience # epochs to wait after fitness stops improving to stop

def __call__(self, epoch, fitness):
if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
self.best_epoch = epoch
self.best_fitness = fitness
stop = (epoch - self.best_epoch) >= self.patience # stop training if patience exceeded
if stop:
LOGGER.info(f'EarlyStopping patience {self.patience} exceeded, stopping training.')
return stop


class ModelEMA:
""" Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
Keep a moving average of everything in the model state_dict (parameters and buffers).

Loading…
Cancel
Save