Update optimizer param group strategy (#7376)

* Update optimizer param group strategy

Avoid empty lists on missing BatchNorm2d models as in https://github.com/ultralytics/yolov5/issues/7375

* fix init
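
Not part of the commit, a minimal sketch of the failure mode behind the linked issue: torch.optim optimizers raise ValueError when constructed from an empty parameter list, and the previous grouping constructed the optimizer from the BatchNorm2d-weight list, which is empty for models without BatchNorm2d layers.

# Illustration only: the old g0 (BatchNorm2d weights) is empty for a model
# with no BatchNorm2d layers, and torch.optim rejects an empty parameter list.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))  # no BatchNorm2d
g0 = [v.weight for v in model.modules() if isinstance(v, nn.BatchNorm2d)]  # -> []

try:
    torch.optim.SGD(g0, lr=0.01, momentum=0.937, nesterov=True)
except ValueError as e:
    print(e)  # "optimizer got an empty parameter list"
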
Glenn Jocher (GitHub) · 2 years ago
commit bd2dda8e64
1 changed file with 11 additions and 11 deletions

train.py  +11 -11

@@ -150,27 +150,27 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
     hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
     LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
 
-    g0, g1, g2 = [], [], []  # optimizer parameter groups
+    g = [], [], []  # optimizer parameter groups
     for v in model.modules():
         if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias
-            g2.append(v.bias)
+            g[2].append(v.bias)
         if isinstance(v, nn.BatchNorm2d):  # weight (no decay)
-            g0.append(v.weight)
+            g[1].append(v.weight)
         elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
-            g1.append(v.weight)
+            g[0].append(v.weight)
 
     if opt.optimizer == 'Adam':
-        optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
+        optimizer = Adam(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
     elif opt.optimizer == 'AdamW':
-        optimizer = AdamW(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
+        optimizer = AdamW(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
     else:
-        optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
+        optimizer = SGD(g[2], lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
 
-    optimizer.add_param_group({'params': g1, 'weight_decay': hyp['weight_decay']})  # add g1 with weight_decay
-    optimizer.add_param_group({'params': g2})  # add g2 (biases)
+    optimizer.add_param_group({'params': g[0], 'weight_decay': hyp['weight_decay']})  # add g0 with weight_decay
+    optimizer.add_param_group({'params': g[1]})  # add g1 (BatchNorm2d weights)
     LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
-                f"{len(g0)} weight (no decay), {len(g1)} weight, {len(g2)} bias")
-    del g0, g1, g2
+                f"{len(g[1])} weight (no decay), {len(g[0])} weight, {len(g[2])} bias")
+    del g
 
     # Scheduler
     if opt.cos_lr:
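
For reference, a self-contained sketch of the new strategy (the build_optimizer helper and hyperparameter values are illustrative, not from train.py): the optimizer is seeded with the bias group g[2], which is non-empty for any model with biased layers, and the possibly-empty weight groups are attached afterwards via add_param_group, which accepts an empty list.

# Sketch of the committed grouping logic, extracted into a standalone helper
# (build_optimizer and the hyperparameter values are illustrative assumptions).
import torch.nn as nn
from torch.optim import SGD

def build_optimizer(model, lr=0.01, momentum=0.937, weight_decay=0.0005):
    g = [], [], []  # g[0] weights with decay, g[1] BatchNorm2d weights (no decay), g[2] biases
    for v in model.modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias
            g[2].append(v.bias)
        if isinstance(v, nn.BatchNorm2d):  # BatchNorm2d weight (no decay)
            g[1].append(v.weight)
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
            g[0].append(v.weight)

    optimizer = SGD(g[2], lr=lr, momentum=momentum, nesterov=True)  # biases: non-empty here
    optimizer.add_param_group({'params': g[0], 'weight_decay': weight_decay})  # decayed weights
    optimizer.add_param_group({'params': g[1]})  # BatchNorm2d weights; an empty list is accepted
    return optimizer

# Group sizes with and without BatchNorm2d layers in the model
with_bn = build_optimizer(nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16)))
without_bn = build_optimizer(nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1)))
print([len(pg['params']) for pg in with_bn.param_groups])     # [2, 1, 1]
print([len(pg['params']) for pg in without_bn.param_groups])  # [2, 2, 0]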
