Enable AdamW optimizer (#6152)
parent d95978a562
commit e1dc894364

train.py | 8 +++++---
@@ -22,7 +22,7 @@ import torch.nn as nn
 import yaml
 from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import SGD, Adam, lr_scheduler
+from torch.optim import SGD, Adam, AdamW, lr_scheduler
 from tqdm import tqdm

 FILE = Path(__file__).resolve()
@@ -155,8 +155,10 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
             g1.append(v.weight)

-    if opt.adam:
+    if opt.optimizer == 'Adam':
         optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
+    elif opt.optimizer == 'AdamW':
+        optimizer = AdamW(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
     else:
         optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)

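For context, a minimal standalone sketch of the selection logic this hunk introduces. The build_optimizer helper and its default lr0/momentum values are illustrative only (they are not part of train.py), and the single parameter group is a simplification: train.py passes the g0/g1/g2 parameter groups assembled in the surrounding code.

# Illustrative sketch only: build_optimizer and its defaults are not part of train.py.
import torch.nn as nn
from torch.optim import SGD, Adam, AdamW

def build_optimizer(model, name='SGD', lr0=0.01, momentum=0.937):
    params = model.parameters()  # train.py uses the g0/g1/g2 parameter groups instead
    if name == 'Adam':
        return Adam(params, lr=lr0, betas=(momentum, 0.999))  # beta1 adjusted to momentum
    elif name == 'AdamW':
        return AdamW(params, lr=lr0, betas=(momentum, 0.999))  # Adam with decoupled weight decay
    return SGD(params, lr=lr0, momentum=momentum, nesterov=True)  # default

optimizer = build_optimizer(nn.Linear(10, 2), name='AdamW')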
@@ -460,7 +462,7 @@ def parse_opt(known=False):
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
     parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
-    parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
+    parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer')
     parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
     parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
     parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
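The boolean --adam flag is replaced by a string-valued --optimizer flag, so an invocation that previously used --adam now uses --optimizer Adam, AdamW is requested with --optimizer AdamW, and SGD remains the default. A stripped-down sketch of how the argparse choices constraint behaves; this toy parser models only the new flag, not the full parse_opt():

# Toy parser modelling only the new --optimizer flag.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer')

print(parser.parse_args([]).optimizer)                        # 'SGD' (default)
print(parser.parse_args(['--optimizer', 'AdamW']).optimizer)  # 'AdamW'
# parser.parse_args(['--optimizer', 'adamw'])  # exits with an error: choices are case-sensitive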