Enable AdamW optimizer (#6152)
This commit is contained in:
parent
d95978a562
commit
e1dc894364
8
train.py
8
train.py
|
|
@ -22,7 +22,7 @@ import torch.nn as nn
|
||||||
import yaml
|
import yaml
|
||||||
from torch.cuda import amp
|
from torch.cuda import amp
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.optim import SGD, Adam, lr_scheduler
|
from torch.optim import SGD, Adam, AdamW, lr_scheduler
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
FILE = Path(__file__).resolve()
|
FILE = Path(__file__).resolve()
|
||||||
|
|
@ -155,8 +155,10 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
||||||
elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
|
elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
|
||||||
g1.append(v.weight)
|
g1.append(v.weight)
|
||||||
|
|
||||||
if opt.adam:
|
if opt.optimizer == 'Adam':
|
||||||
optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
|
optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
|
||||||
|
elif opt.optimizer == 'AdamW':
|
||||||
|
optimizer = AdamW(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
|
||||||
else:
|
else:
|
||||||
optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
|
optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
|
||||||
|
|
||||||
|
|
@ -460,7 +462,7 @@ def parse_opt(known=False):
|
||||||
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
||||||
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
|
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
|
||||||
parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
|
parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
|
||||||
parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
|
parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer')
|
||||||
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
|
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
|
||||||
parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
|
parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
|
||||||
parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
|
parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue