@@ -16,8 +16,7 @@ from utils.datasets import *
 from utils.utils import *
 
 # Hyperparameters
-hyp = {'optimizer': 'SGD',  # ['Adam', 'SGD', ...] from torch.optim
-       'lr0': 0.01,  # initial learning rate (SGD=1E-2, Adam=1E-3)
+hyp = {'lr0': 0.01,  # initial learning rate (SGD=1E-2, Adam=1E-3)
        'momentum': 0.937,  # SGD momentum/Adam beta1
        'weight_decay': 5e-4,  # optimizer weight decay
        'giou': 0.05,  # GIoU loss gain
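This hunk drops the `'optimizer'` key from the defaults, leaving `hyp` as a flat mapping of numeric values. Those defaults can still be overridden per run through the `--hyp` flag registered further down; a minimal sketch of such an override, assuming the YAML file holds a flat `key: value` mapping (`load_hyp_overrides` is a hypothetical helper for illustration; the actual loading code sits outside this diff):

```python
import yaml

def load_hyp_overrides(hyp, path):
    """Merge user-supplied values from a YAML file into the default hyp dict."""
    if path:  # --hyp defaults to '', which keeps the values above
        with open(path) as f:
            hyp.update(yaml.safe_load(f))  # e.g. {'lr0': 0.005, 'mixup': 0.2}
    return hyp
```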
@@ -41,7 +40,7 @@ hyp = {'optimizer': 'SGD',  # ['Adam', 'SGD', ...] from torch.optim
        'mixup': 0.0}  # image mixup (probability)
 
 
-def train(hyp, tb_writer, opt, device):
+def train(hyp, opt, device, tb_writer=None):
     print(f'Hyperparameters {hyp}')
     log_dir = tb_writer.log_dir if tb_writer else 'runs/evolution'  # run directory
     wdir = str(Path(log_dir) / 'weights') + os.sep  # weights directory
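Moving `tb_writer` to a trailing keyword argument with a `None` default lets callers omit it entirely, which the hyperparameter-evolution path below relies on. A toy stand-in showing the fallback behavior of the first lines of the new body:

```python
from pathlib import Path

def train(hyp, opt, device, tb_writer=None):  # toy stand-in, same signature
    # Without a writer (the --evolve path), logging falls back to a fixed dir.
    log_dir = tb_writer.log_dir if tb_writer else 'runs/evolution'
    return str(Path(log_dir) / 'weights')

print(train({}, None, None))  # runs/evolution/weights (platform separator)
```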
@@ -102,7 +101,7 @@ def train(hyp, tb_writer, opt, device):
         else:
             pg0.append(v)  # all else
 
-    if hyp['optimizer'] == 'Adam':
+    if opt.adam:
         optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
     else:
         optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
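The optimizer choice now hangs off the boolean `--adam` flag instead of a string key in `hyp`. A self-contained sketch of the same branch, with a toy model in place of the `pg0`/`pg1`/`pg2` parameter grouping the full script builds; note that Adam reuses the momentum value as beta1:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)  # toy stand-in for the detection model
hyp = {'lr0': 0.01, 'momentum': 0.937}
adam = False  # mirrors opt.adam from argparse

if adam:
    optimizer = optim.Adam(model.parameters(), lr=hyp['lr0'],
                           betas=(hyp['momentum'], 0.999))  # beta1 = momentum
else:
    optimizer = optim.SGD(model.parameters(), lr=hyp['lr0'],
                          momentum=hyp['momentum'], nesterov=True)
```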
@@ -279,7 +278,7 @@ def train(hyp, tb_writer, opt, device):
                 imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
 
             # Autocast
-            with amp.autocast():
+            with amp.autocast(enabled=cuda):
                 # Forward
                 pred = model(imgs)
 
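Passing `enabled=cuda` makes the autocast context a no-op when CUDA is absent, so the same forward pass runs unchanged on CPU. A minimal mixed-precision step built around that pattern; the `GradScaler` half, which this hunk does not show, comes from the same `torch.cuda.amp` module and accepts the same `enabled` switch:

```python
import torch
from torch.cuda import amp

cuda = torch.cuda.is_available()
device = torch.device('cuda' if cuda else 'cpu')
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = amp.GradScaler(enabled=cuda)  # also degrades to a no-op without CUDA

imgs = torch.randn(4, 8, device=device)
targets = torch.randn(4, 1, device=device)

with amp.autocast(enabled=cuda):  # fp16 ops on GPU, plain fp32 on CPU
    loss = torch.nn.functional.mse_loss(model(imgs), targets)

scaler.scale(loss).backward()  # scale to avoid fp16 gradient underflow
scaler.step(optimizer)
scaler.update()
```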
@@ -402,11 +401,11 @@ if __name__ == '__main__':
     parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
     parser.add_argument('--hyp', type=str, default='', help='hyp.yaml path (optional)')
     parser.add_argument('--epochs', type=int, default=300)
-    parser.add_argument('--batch-size', type=int, default=16, help="Total batch size for all gpus.")
+    parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
     parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
     parser.add_argument('--rect', action='store_true', help='rectangular training')
     parser.add_argument('--resume', nargs='?', const='get_last', default=False,
-                        help='resume from given path/to/last.pt, or most recent run if blank.')
+                        help='resume from given path/last.pt, or most recent run if blank')
     parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
     parser.add_argument('--notest', action='store_true', help='only test final epoch')
     parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
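The `--resume` flag exploits argparse's three-way `nargs='?'` behavior, which the reworded help string describes: omit the flag and you get the default (`False`), pass it bare and you get the `const` sentinel (`'get_last'`), pass a value and you get that path. A standalone demonstration:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--resume', nargs='?', const='get_last', default=False,
                    help='resume from given path/last.pt, or most recent run if blank')

print(parser.parse_args([]).resume)                       # False (flag absent)
print(parser.parse_args(['--resume']).resume)             # get_last (bare flag)
print(parser.parse_args(['--resume', 'last.pt']).resume)  # last.pt (explicit path)
```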
@@ -418,6 +417,7 @@ if __name__ == '__main__':
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
     parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
+    parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
     parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
     parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
     opt = parser.parse_args()
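A side note on the `--multi-scale` line this hunk touches: argparse runs help strings through %-style formatting (for placeholders like `%(default)s`), so the literal percent sign must stay escaped as `%%`:

```python
import argparse

p = argparse.ArgumentParser()
p.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
print(p.format_help())  # the help line renders as: vary img-size +/- 50%
```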
@@ -445,30 +445,52 @@ if __name__ == '__main__':
     if opt.local_rank != -1:
         assert torch.cuda.device_count() > opt.local_rank
         torch.cuda.set_device(opt.local_rank)
-        device = torch.device("cuda", opt.local_rank)
+        device = torch.device('cuda', opt.local_rank)
         dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend
         opt.world_size = dist.get_world_size()
-        assert opt.batch_size % opt.world_size == 0, "Batch size is not a multiple of the number of devices given!"
+        assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
         opt.batch_size = opt.total_batch_size // opt.world_size
     print(opt)
 
     # Train
     if not opt.evolve:
+        tb_writer = None
         if opt.local_rank in [-1, 0]:
             print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
             tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name))
-        else:
-            tb_writer = None
 
-        train(hyp, tb_writer, opt, device)
+        train(hyp, opt, device, tb_writer)
 
     # Evolve hyperparameters (optional)
     else:
+        assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
+        # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
+        meta = {'lr0': (1, 1e-5, 1e-2),  # initial learning rate (SGD=1E-2, Adam=1E-3)
+                'momentum': (0.1, 0.6, 0.98),  # SGD momentum/Adam beta1
+                'weight_decay': (1, 0.0, 0.001),  # optimizer weight decay
+                'giou': (1, 0.02, 0.2),  # GIoU loss gain
+                'cls': (1, 0.2, 4.0),  # cls loss gain
+                'cls_pw': (1, 0.5, 2.0),  # cls BCELoss positive_weight
+                'obj': (1, 0.2, 4.0),  # obj loss gain (scale with pixels)
+                'obj_pw': (1, 0.5, 2.0),  # obj BCELoss positive_weight
+                'iou_t': (0, 0.1, 0.7),  # IoU training threshold
+                'anchor_t': (1, 2.0, 8.0),  # anchor-multiple threshold
+                'fl_gamma': (0, 0.0, 2.0),  # focal loss gamma (efficientDet default gamma=1.5)
+                'hsv_h': (1, 0.0, 0.1),  # image HSV-Hue augmentation (fraction)
+                'hsv_s': (1, 0.0, 0.8),  # image HSV-Saturation augmentation (fraction)
+                'hsv_v': (1, 0.0, 0.8),  # image HSV-Value augmentation (fraction)
+                'degrees': (1, 0.0, 45.0),  # image rotation (+/- deg)
+                'translate': (1, 0.0, 0.9),  # image translation (+/- fraction)
+                'scale': (1, 0.0, 0.9),  # image scale (+/- gain)
+                'shear': (1, 0.0, 10.0),  # image shear (+/- deg)
+                'perspective': (1, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
+                'flipud': (0, 0.0, 1.0),  # image flip up-down (probability)
+                'fliplr': (1, 0.0, 1.0),  # image flip left-right (probability)
+                'mixup': (1, 0.0, 1.0)}  # image mixup (probability)
+
         tb_writer = None
-        assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
         opt.notest, opt.nosave = True, True  # only test/save final epoch
         # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
         if opt.bucket:
             os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket)  # download evolve.txt if exists
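Each `meta` entry bundles (mutation gain, lower limit, upper limit) into one tuple, so the mutation gains and the clip bounds used later in the evolve loop both come from a single table instead of the parallel `keys`/`limits` lists this diff deletes. A sketch of reading the fields back out, using a two-key subset of the table above:

```python
import numpy as np

# Two entries copied from the meta table: (mutation gain, lower, upper).
meta = {'lr0': (1, 1e-5, 1e-2),
        'iou_t': (0, 0.1, 0.7)}

g = np.array([x[0] for x in meta.values()])  # gains; a 0 freezes that key
print(g)  # [1 0] -> 'iou_t' will never mutate

hyp = {'lr0': 0.5, 'iou_t': 0.05}  # pretend these came out of a mutation
for k, v in meta.items():
    hyp[k] = float(np.clip(hyp[k], v[1], v[2]))  # clip to (lower, upper)
print(hyp)  # {'lr0': 0.01, 'iou_t': 0.1}
```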
@@ -490,8 +512,8 @@ if __name__ == '__main__':
             mp, s = 0.9, 0.2  # mutation probability, sigma
             npr = np.random
             npr.seed(int(time.time()))
-            g = np.array([1, 1, 1, 1, 1, 1, 1, 0, .1, 1, 0, 1, 1, 1, 1, 1, 1, 1])  # gains
-            ng = len(g)
+            g = np.array([x[0] for x in meta.values()])  # gains 0-1
+            ng = len(meta)
             v = np.ones(ng)
             while all(v == 1):  # mutate until a change occurs (prevent duplicates)
                 v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
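The mutation line draws one multiplicative factor per hyperparameter: a key mutates with probability `mp`, the step size is Gaussian, scaled by sigma and the key's gain, and the result is clipped to [0.3, 3.0], so a single generation never moves a value by more than about 3x in either direction. The same expression run standalone, with a fixed seed in place of the time-based one:

```python
import numpy as np

mp, s = 0.9, 0.2            # mutation probability, sigma (as above)
npr = np.random
npr.seed(0)                 # fixed seed so the demo is reproducible
g = np.array([1, 1, 0, 1])  # gains for four keys; the third is frozen
ng = len(g)

v = np.ones(ng)
while all(v == 1):  # retry until at least one factor actually changes
    v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)

print(v)  # factors near 1.0; index 2 stays exactly 1.0 because its gain is 0
```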
@@ -499,13 +521,11 @@ if __name__ == '__main__':
                 hyp[k] = x[i + 7] * v[i]  # mutate
 
             # Clip to limits
-            keys = ['lr0', 'iou_t', 'momentum', 'weight_decay', 'hsv_s', 'hsv_v', 'translate', 'scale', 'fl_gamma']
-            limits = [(1e-5, 1e-2), (0.00, 0.70), (0.60, 0.98), (0, 0.001), (0, .9), (0, .9), (0, .9), (0, .9), (0, 3)]
-            for k, v in zip(keys, limits):
-                hyp[k] = np.clip(hyp[k], v[0], v[1])
+            for k, v in meta.items():
+                hyp[k] = np.clip(hyp[k], v[1], v[2])
 
             # Train mutation
-            results = train(hyp.copy(), tb_writer, opt, device)
+            results = train(hyp.copy(), opt, device)
 
             # Write mutation results
             print_mutation(hyp, results, opt.bucket)
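One indexing detail worth noting in the mutation above: `hyp[k] = x[i + 7] * v[i]` skips the first seven columns of the selected `evolve.txt` row, which suggests each row stores seven result/fitness columns before the hyperparameter values. The exact column layout is not shown in this diff, so the sketch below simply assumes it:

```python
import numpy as np

n_results = 7  # assumed number of leading result columns per evolve.txt row
hyp_keys = ['lr0', 'momentum', 'weight_decay']

x = np.concatenate([np.zeros(n_results),      # placeholder results block
                    [0.01, 0.937, 5e-4]])     # parent hyperparameter values
v = np.array([1.2, 0.9, 1.0])                 # mutation factors from above

hyp = {}
for i, k in enumerate(hyp_keys):
    hyp[k] = float(x[i + n_results] * v[i])   # mutate the parent value
print(hyp)  # ≈ {'lr0': 0.012, 'momentum': 0.8433, 'weight_decay': 0.0005}
```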