|
|
@@ -20,8 +20,9 @@ except: |
|
|
|
|
|
|
|
|
|
|
|
# Hyperparameters |
|
|
|
hyp = {'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3) |
|
|
|
'momentum': 0.937, # SGD momentum |
|
|
|
hyp = {'optimizer': 'adam' #if none, default is SGD |
|
|
|
'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3) |
|
|
|
'momentum': 0.937, # SGD momentum/Adam beta1 |
|
|
|
'weight_decay': 5e-4, # optimizer weight decay |
|
|
|
'giou': 0.05, # giou loss gain |
|
|
|
'cls': 0.58, # cls loss gain |
|
|
@@ -90,8 +91,11 @@ def train(hyp): |
|
|
|
else: |
|
|
|
pg0.append(v) # all else |
|
|
|
|
|
|
|
optimizer = optim.Adam(pg0, lr=hyp['lr0']) if opt.adam else \ |
|
|
|
optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True) |
|
|
|
if hyp.optimizer =='adam': |
|
|
|
optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) #use default beta2, adjust beta1 for Adam momentum per momentum adjustments in https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR |
|
|
|
else: |
|
|
|
optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True) |
|
|
|
|
|
|
|
optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay |
|
|
|
optimizer.add_param_group({'params': pg2}) # add pg2 (biases) |
|
|
|
print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0))) |
|
|
@@ -380,7 +384,6 @@ if __name__ == '__main__': |
|
|
|
parser.add_argument('--weights', type=str, default='', help='initial weights path') |
|
|
|
parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied') |
|
|
|
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') |
|
|
|
parser.add_argument('--adam', action='store_true', help='use adam optimizer') |
|
|
|
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%') |
|
|
|
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset') |
|
|
|
parser.add_argument('--hyp', type=str, default='', help ='hyp cfg path [*.yaml].') |