You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

533 lines
27KB

  1. import argparse
  2. import logging
  3. import math
  4. import os
  5. import random
  6. import shutil
  7. import time
  8. from pathlib import Path
  9. import numpy as np
  10. import torch.distributed as dist
  11. import torch.nn.functional as F
  12. import torch.optim as optim
  13. import torch.optim.lr_scheduler as lr_scheduler
  14. import torch.utils.data
  15. import yaml
  16. from torch.cuda import amp
  17. from torch.nn.parallel import DistributedDataParallel as DDP
  18. from torch.utils.tensorboard import SummaryWriter
  19. from tqdm import tqdm
  20. import test # import test.py to get mAP after each epoch
  21. from models.yolo import Model
  22. from utils.datasets import create_dataloader
  23. from utils.general import (
  24. torch_distributed_zero_first, labels_to_class_weights, plot_labels, check_anchors, labels_to_image_weights,
  25. compute_loss, plot_images, fitness, strip_optimizer, plot_results, get_latest_run, check_dataset, check_file,
  26. check_git_status, check_img_size, increment_dir, print_mutation, plot_evolution, set_logging)
  27. from utils.google_utils import attempt_download
  28. from utils.torch_utils import init_seeds, ModelEMA, select_device, intersect_dicts
  29. logger = logging.getLogger(__name__)
  30. def train(hyp, opt, device, tb_writer=None):
  31. logger.info(f'Hyperparameters {hyp}')
  32. log_dir = Path(tb_writer.log_dir) if tb_writer else Path(opt.logdir) / 'evolve' # logging directory
  33. wdir = log_dir / 'weights' # weights directory
  34. os.makedirs(wdir, exist_ok=True)
  35. last = wdir / 'last.pt'
  36. best = wdir / 'best.pt'
  37. results_file = str(log_dir / 'results.txt')
  38. epochs, batch_size, total_batch_size, weights, rank = \
  39. opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
  40. # Save run settings
  41. with open(log_dir / 'hyp.yaml', 'w') as f:
  42. yaml.dump(hyp, f, sort_keys=False)
  43. with open(log_dir / 'opt.yaml', 'w') as f:
  44. yaml.dump(vars(opt), f, sort_keys=False)
  45. # Configure
  46. cuda = device.type != 'cpu'
  47. init_seeds(2 + rank)
  48. with open(opt.data) as f:
  49. data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
  50. with torch_distributed_zero_first(rank):
  51. check_dataset(data_dict) # check
  52. train_path = data_dict['train']
  53. test_path = data_dict['val']
  54. nc, names = (1, ['item']) if opt.single_cls else (int(data_dict['nc']), data_dict['names']) # number classes, names
  55. assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
  56. # Model
  57. pretrained = weights.endswith('.pt')
  58. if pretrained:
  59. with torch_distributed_zero_first(rank):
  60. attempt_download(weights) # download if not found locally
  61. ckpt = torch.load(weights, map_location=device) # load checkpoint
  62. model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device) # create
  63. exclude = ['anchor'] if opt.cfg else [] # exclude keys
  64. state_dict = ckpt['model'].float().state_dict() # to FP32
  65. state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
  66. model.load_state_dict(state_dict, strict=False) # load
  67. logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
  68. else:
  69. model = Model(opt.cfg, ch=3, nc=nc).to(device) # create
  70. # Freeze
  71. freeze = ['', ] # parameter names to freeze (full or partial)
  72. if any(freeze):
  73. for k, v in model.named_parameters():
  74. if any(x in k for x in freeze):
  75. print('freezing %s' % k)
  76. v.requires_grad = False
  77. # Optimizer
  78. nbs = 64 # nominal batch size
  79. accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
  80. hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay
  81. pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
  82. for k, v in model.named_parameters():
  83. v.requires_grad = True
  84. if '.bias' in k:
  85. pg2.append(v) # biases
  86. elif '.weight' in k and '.bn' not in k:
  87. pg1.append(v) # apply weight decay
  88. else:
  89. pg0.append(v) # all else
  90. if opt.adam:
  91. optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  92. else:
  93. optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  94. optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
  95. optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
  96. logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
  97. del pg0, pg1, pg2
  98. # Scheduler https://arxiv.org/pdf/1812.01187.pdf
  99. # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
  100. lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2 # cosine
  101. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
  102. # plot_lr_scheduler(optimizer, scheduler, epochs)
  103. # Resume
  104. start_epoch, best_fitness = 0, 0.0
  105. if pretrained:
  106. # Optimizer
  107. if ckpt['optimizer'] is not None:
  108. optimizer.load_state_dict(ckpt['optimizer'])
  109. best_fitness = ckpt['best_fitness']
  110. # Results
  111. if ckpt.get('training_results') is not None:
  112. with open(results_file, 'w') as file:
  113. file.write(ckpt['training_results']) # write results.txt
  114. # Epochs
  115. start_epoch = ckpt['epoch'] + 1
  116. if opt.resume:
  117. assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
  118. shutil.copytree(wdir, wdir.parent / f'weights_backup_epoch{start_epoch - 1}') # save previous weights
  119. if epochs < start_epoch:
  120. logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
  121. (weights, ckpt['epoch'], epochs))
  122. epochs += ckpt['epoch'] # finetune additional epochs
  123. del ckpt, state_dict
  124. # Image sizes
  125. gs = int(max(model.stride)) # grid size (max stride)
  126. imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
  127. # DP mode
  128. if cuda and rank == -1 and torch.cuda.device_count() > 1:
  129. model = torch.nn.DataParallel(model)
  130. # SyncBatchNorm
  131. if opt.sync_bn and cuda and rank != -1:
  132. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  133. logger.info('Using SyncBatchNorm()')
  134. # Exponential moving average
  135. ema = ModelEMA(model) if rank in [-1, 0] else None
  136. # DDP mode
  137. if cuda and rank != -1:
  138. model = DDP(model, device_ids=[opt.local_rank], output_device=(opt.local_rank))
  139. # Trainloader
  140. dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
  141. hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
  142. world_size=opt.world_size, workers=opt.workers)
  143. mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
  144. nb = len(dataloader) # number of batches
  145. assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
  146. # Testloader
  147. if rank in [-1, 0]:
  148. ema.updates = start_epoch * nb // accumulate # set EMA updates
  149. testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt,
  150. hyp=hyp, augment=False, cache=opt.cache_images, rect=True, rank=-1,
  151. world_size=opt.world_size, workers=opt.workers)[0] # only runs on process 0
  152. # Model parameters
  153. hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
  154. model.nc = nc # attach number of classes to model
  155. model.hyp = hyp # attach hyperparameters to model
  156. model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
  157. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
  158. model.names = names
  159. # Class frequency
  160. if rank in [-1, 0]:
  161. labels = np.concatenate(dataset.labels, 0)
  162. c = torch.tensor(labels[:, 0]) # classes
  163. # cf = torch.bincount(c.long(), minlength=nc) + 1.
  164. # model._initialize_biases(cf.to(device))
  165. plot_labels(labels, save_dir=log_dir)
  166. if tb_writer:
  167. # tb_writer.add_hparams(hyp, {}) # causes duplicate https://github.com/ultralytics/yolov5/pull/384
  168. tb_writer.add_histogram('classes', c, 0)
  169. # Check anchors
  170. if not opt.noautoanchor:
  171. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  172. # Start training
  173. t0 = time.time()
  174. nw = max(3 * nb, 1e3) # number of warmup iterations, max(3 epochs, 1k iterations)
  175. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  176. maps = np.zeros(nc) # mAP per class
  177. results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
  178. scheduler.last_epoch = start_epoch - 1 # do not move
  179. scaler = amp.GradScaler(enabled=cuda)
  180. logger.info('Image sizes %g train, %g test' % (imgsz, imgsz_test))
  181. logger.info('Using %g dataloader workers' % dataloader.num_workers)
  182. logger.info('Starting training for %g epochs...' % epochs)
  183. # torch.autograd.set_detect_anomaly(True)
  184. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  185. model.train()
  186. # Update image weights (optional)
  187. if dataset.image_weights:
  188. # Generate indices
  189. if rank in [-1, 0]:
  190. w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
  191. image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
  192. dataset.indices = random.choices(range(dataset.n), weights=image_weights,
  193. k=dataset.n) # rand weighted idx
  194. # Broadcast if DDP
  195. if rank != -1:
  196. indices = torch.zeros([dataset.n], dtype=torch.int)
  197. if rank == 0:
  198. indices[:] = torch.tensor(dataset.indices, dtype=torch.int)
  199. dist.broadcast(indices, 0)
  200. if rank != 0:
  201. dataset.indices = indices.cpu().numpy()
  202. # Update mosaic border
  203. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  204. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  205. mloss = torch.zeros(4, device=device) # mean losses
  206. if rank != -1:
  207. dataloader.sampler.set_epoch(epoch)
  208. pbar = enumerate(dataloader)
  209. logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
  210. if rank in [-1, 0]:
  211. pbar = tqdm(pbar, total=nb) # progress bar
  212. optimizer.zero_grad()
  213. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  214. ni = i + nb * epoch # number integrated batches (since train start)
  215. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  216. # Warmup
  217. if ni <= nw:
  218. xi = [0, nw] # x interp
  219. # model.gr = np.interp(ni, xi, [0.0, 1.0]) # giou loss ratio (obj_loss = 1.0 or giou)
  220. accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
  221. for j, x in enumerate(optimizer.param_groups):
  222. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  223. x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  224. if 'momentum' in x:
  225. x['momentum'] = np.interp(ni, xi, [0.9, hyp['momentum']])
  226. # Multi-scale
  227. if opt.multi_scale:
  228. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  229. sf = sz / max(imgs.shape[2:]) # scale factor
  230. if sf != 1:
  231. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  232. imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  233. # Forward
  234. with amp.autocast(enabled=cuda):
  235. pred = model(imgs) # forward
  236. loss, loss_items = compute_loss(pred, targets.to(device), model) # loss scaled by batch_size
  237. if rank != -1:
  238. loss *= opt.world_size # gradient averaged between devices in DDP mode
  239. # Backward
  240. scaler.scale(loss).backward()
  241. # Optimize
  242. if ni % accumulate == 0:
  243. scaler.step(optimizer) # optimizer.step
  244. scaler.update()
  245. optimizer.zero_grad()
  246. if ema:
  247. ema.update(model)
  248. # Print
  249. if rank in [-1, 0]:
  250. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  251. mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
  252. s = ('%10s' * 2 + '%10.4g' * 6) % (
  253. '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
  254. pbar.set_description(s)
  255. # Plot
  256. if ni < 3:
  257. f = str(log_dir / ('train_batch%g.jpg' % ni)) # filename
  258. result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
  259. if tb_writer and result is not None:
  260. tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
  261. # tb_writer.add_graph(model, imgs) # add model to tensorboard
  262. # end batch ------------------------------------------------------------------------------------------------
  263. # Scheduler
  264. lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard
  265. scheduler.step()
  266. # DDP process 0 or single-GPU
  267. if rank in [-1, 0]:
  268. # mAP
  269. if ema:
  270. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
  271. final_epoch = epoch + 1 == epochs
  272. if not opt.notest or final_epoch: # Calculate mAP
  273. results, maps, times = test.test(opt.data,
  274. batch_size=total_batch_size,
  275. imgsz=imgsz_test,
  276. model=ema.ema,
  277. single_cls=opt.single_cls,
  278. dataloader=testloader,
  279. save_dir=log_dir)
  280. # Write
  281. with open(results_file, 'a') as f:
  282. f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
  283. if len(opt.name) and opt.bucket:
  284. os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
  285. # Tensorboard
  286. if tb_writer:
  287. tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss', # train loss
  288. 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
  289. 'val/giou_loss', 'val/obj_loss', 'val/cls_loss', # val loss
  290. 'x/lr0', 'x/lr1', 'x/lr2'] # params
  291. for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
  292. tb_writer.add_scalar(tag, x, epoch)
  293. # Update best mAP
  294. fi = fitness(np.array(results).reshape(1, -1)) # fitness_i = weighted combination of [P, R, mAP, F1]
  295. if fi > best_fitness:
  296. best_fitness = fi
  297. # Save model
  298. save = (not opt.nosave) or (final_epoch and not opt.evolve)
  299. if save:
  300. with open(results_file, 'r') as f: # create checkpoint
  301. ckpt = {'epoch': epoch,
  302. 'best_fitness': best_fitness,
  303. 'training_results': f.read(),
  304. 'model': ema.ema,
  305. 'optimizer': None if final_epoch else optimizer.state_dict()}
  306. # Save last, best and delete
  307. torch.save(ckpt, last)
  308. if best_fitness == fi:
  309. torch.save(ckpt, best)
  310. del ckpt
  311. # end epoch ----------------------------------------------------------------------------------------------------
  312. # end training
  313. if rank in [-1, 0]:
  314. # Strip optimizers
  315. n = ('_' if len(opt.name) and not opt.name.isnumeric() else '') + opt.name
  316. fresults, flast, fbest = 'results%s.txt' % n, wdir / f'last{n}.pt', wdir / f'best{n}.pt'
  317. for f1, f2 in zip([wdir / 'last.pt', wdir / 'best.pt', 'results.txt'], [flast, fbest, fresults]):
  318. if os.path.exists(f1):
  319. os.rename(f1, f2) # rename
  320. if str(f2).endswith('.pt'): # is *.pt
  321. strip_optimizer(f2) # strip optimizer
  322. os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket else None # upload
  323. # Finish
  324. if not opt.evolve:
  325. plot_results(save_dir=log_dir) # save as results.png
  326. logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
  327. dist.destroy_process_group() if rank not in [-1, 0] else None
  328. torch.cuda.empty_cache()
  329. return results
  330. if __name__ == '__main__':
  331. parser = argparse.ArgumentParser()
  332. parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
  333. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  334. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
  335. parser.add_argument('--hyp', type=str, default='', help='hyperparameters path, i.e. data/hyp.scratch.yaml')
  336. parser.add_argument('--epochs', type=int, default=300)
  337. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
  338. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
  339. parser.add_argument('--rect', action='store_true', help='rectangular training')
  340. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  341. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  342. parser.add_argument('--notest', action='store_true', help='only test final epoch')
  343. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  344. parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
  345. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  346. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
  347. parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied')
  348. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  349. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  350. parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
  351. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  352. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  353. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  354. parser.add_argument('--logdir', type=str, default='runs/', help='logging directory')
  355. parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
  356. opt = parser.parse_args()
  357. # Set DDP variables
  358. opt.total_batch_size = opt.batch_size
  359. opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
  360. opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
  361. set_logging(opt.global_rank)
  362. if opt.global_rank in [-1, 0]:
  363. check_git_status()
  364. # Resume
  365. if opt.resume: # resume an interrupted run
  366. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  367. log_dir = Path(ckpt).parent.parent # runs/exp0
  368. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  369. with open(log_dir / 'opt.yaml') as f:
  370. opt = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader)) # replace
  371. opt.cfg, opt.weights, opt.resume = '', ckpt, True
  372. logger.info('Resuming training from %s' % ckpt)
  373. else:
  374. opt.hyp = opt.hyp or ('data/hyp.finetune.yaml' if opt.weights else 'data/hyp.scratch.yaml')
  375. opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
  376. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  377. opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
  378. log_dir = increment_dir(Path(opt.logdir) / 'exp', opt.name) # runs/exp1
  379. device = select_device(opt.device, batch_size=opt.batch_size)
  380. # DDP mode
  381. if opt.local_rank != -1:
  382. assert torch.cuda.device_count() > opt.local_rank
  383. torch.cuda.set_device(opt.local_rank)
  384. device = torch.device('cuda', opt.local_rank)
  385. dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
  386. assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
  387. opt.batch_size = opt.total_batch_size // opt.world_size
  388. logger.info(opt)
  389. with open(opt.hyp) as f:
  390. hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
  391. # Train
  392. if not opt.evolve:
  393. tb_writer = None
  394. if opt.global_rank in [-1, 0]:
  395. logger.info('Start Tensorboard with "tensorboard --logdir %s", view at http://localhost:6006/' % opt.logdir)
  396. tb_writer = SummaryWriter(log_dir=log_dir) # runs/exp0
  397. train(hyp, opt, device, tb_writer)
  398. # Evolve hyperparameters (optional)
  399. else:
  400. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  401. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  402. 'momentum': (0.1, 0.6, 0.98), # SGD momentum/Adam beta1
  403. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  404. 'giou': (1, 0.02, 0.2), # GIoU loss gain
  405. 'cls': (1, 0.2, 4.0), # cls loss gain
  406. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  407. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  408. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  409. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  410. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  411. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  412. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  413. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  414. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  415. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  416. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  417. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  418. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  419. 'perspective': (1, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  420. 'flipud': (0, 0.0, 1.0), # image flip up-down (probability)
  421. 'fliplr': (1, 0.0, 1.0), # image flip left-right (probability)
  422. 'mixup': (1, 0.0, 1.0)} # image mixup (probability)
  423. assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
  424. opt.notest, opt.nosave = True, True # only test/save final epoch
  425. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  426. yaml_file = Path('runs/evolve/hyp_evolved.yaml') # save best result here
  427. if opt.bucket:
  428. os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
  429. for _ in range(100): # generations to evolve
  430. if os.path.exists('evolve.txt'): # if evolve.txt exists: select best hyps and mutate
  431. # Select parent(s)
  432. parent = 'single' # parent selection method: 'single' or 'weighted'
  433. x = np.loadtxt('evolve.txt', ndmin=2)
  434. n = min(5, len(x)) # number of previous results to consider
  435. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  436. w = fitness(x) - fitness(x).min() # weights
  437. if parent == 'single' or len(x) == 1:
  438. # x = x[random.randint(0, n - 1)] # random selection
  439. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  440. elif parent == 'weighted':
  441. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  442. # Mutate
  443. mp, s = 0.9, 0.2 # mutation probability, sigma
  444. npr = np.random
  445. npr.seed(int(time.time()))
  446. g = np.array([x[0] for x in meta.values()]) # gains 0-1
  447. ng = len(meta)
  448. v = np.ones(ng)
  449. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  450. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  451. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  452. hyp[k] = float(x[i + 7] * v[i]) # mutate
  453. # Constrain to limits
  454. for k, v in meta.items():
  455. hyp[k] = max(hyp[k], v[1]) # lower limit
  456. hyp[k] = min(hyp[k], v[2]) # upper limit
  457. hyp[k] = round(hyp[k], 5) # significant digits
  458. # Train mutation
  459. results = train(hyp.copy(), opt, device)
  460. # Write mutation results
  461. print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
  462. # Plot results
  463. plot_evolution(yaml_file)
  464. print('Hyperparameter evolution complete. Best results saved as: %s\nCommand to train a new model with these '
  465. 'hyperparameters: $ python train.py --hyp %s' % (yaml_file, yaml_file))