Browse Source

Fix warmup `accumulate` (#3722)

* gradient accumulation during warmup in train.py

Context:
`accumulate` is the number of batches/gradients accumulated before calling the next optimizer.step().
During warmup, it is ramped up from 1 to the final value nbs / batch_size. 
Although I have not seen this in other libraries, I like the idea. During warmup, as grads are large, too large steps are more of on issue than gradient noise due to small steps.

The bug:
The condition to perform the opt step is wrong
> if ni % accumulate == 0:
This produces irregular step sizes if `accumulate` is not constant. It becomes relevant when batch_size is small and `accumulate` changes many times during warmup.

This demo also shows the proposed solution, to use a ">=" condition instead:
https://colab.research.google.com/drive/1MA2z2eCXYB_BC5UZqgXueqL_y1Tz_XVq?usp=sharing

Further, I propose not to restrict the number of warmup iterations to >= 1000. If the user changes hyp['warmup_epochs'], this causes unexpected behavior. Also, it makes evolution unstable if this parameter was to be optimized.

* replace last_opt_step tracking by do_step(ni)

* add docstrings

* move down nw

* Update train.py

* revert math import move

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
modifyDataloader
yellowdolphin GitHub 3 years ago
parent
commit
3974d725b6
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 1 deletions
  1. +3
    -1
      train.py

+ 3
- 1
train.py View File

@@ -270,6 +270,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
t0 = time.time()
nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
# nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
last_opt_step = -1
maps = np.zeros(nc) # mAP per class
results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
scheduler.last_epoch = start_epoch - 1 # do not move
@@ -344,12 +345,13 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
scaler.scale(loss).backward()

# Optimize
if ni % accumulate == 0:
if ni - last_opt_step >= accumulate:
scaler.step(optimizer) # optimizer.step
scaler.update()
optimizer.zero_grad()
if ema:
ema.update(model)
last_opt_step = ni

# Print
if RANK in [-1, 0]:

Loading…
Cancel
Save