You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

538 lines
27KB

  1. import argparse
  2. import glob
  3. import logging
  4. import math
  5. import os
  6. import random
  7. import shutil
  8. import time
  9. from pathlib import Path
  10. import numpy as np
  11. import torch.distributed as dist
  12. import torch.nn.functional as F
  13. import torch.optim as optim
  14. import torch.optim.lr_scheduler as lr_scheduler
  15. import torch.utils.data
  16. import yaml
  17. from torch.cuda import amp
  18. from torch.nn.parallel import DistributedDataParallel as DDP
  19. from torch.utils.tensorboard import SummaryWriter
  20. from tqdm import tqdm
  21. import test # import test.py to get mAP after each epoch
  22. from models.yolo import Model
  23. from utils.datasets import create_dataloader
  24. from utils.general import (
  25. torch_distributed_zero_first, labels_to_class_weights, plot_labels, check_anchors, labels_to_image_weights,
  26. compute_loss, plot_images, fitness, strip_optimizer, plot_results, get_latest_run, check_dataset, check_file,
  27. check_git_status, check_img_size, increment_dir, print_mutation, plot_evolution, set_logging)
  28. from utils.google_utils import attempt_download
  29. from utils.torch_utils import init_seeds, ModelEMA, select_device, intersect_dicts
  30. logger = logging.getLogger(__name__)
  31. def train(hyp, opt, device, tb_writer=None):
  32. logger.info(f'Hyperparameters {hyp}')
  33. log_dir = Path(tb_writer.log_dir) if tb_writer else Path(opt.logdir) / 'evolve' # logging directory
  34. wdir = log_dir / 'weights' # weights directory
  35. os.makedirs(wdir, exist_ok=True)
  36. last = wdir / 'last.pt'
  37. best = wdir / 'best.pt'
  38. results_file = str(log_dir / 'results.txt')
  39. epochs, batch_size, total_batch_size, weights, rank = \
  40. opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
  41. # Save run settings
  42. with open(log_dir / 'hyp.yaml', 'w') as f:
  43. yaml.dump(hyp, f, sort_keys=False)
  44. with open(log_dir / 'opt.yaml', 'w') as f:
  45. yaml.dump(vars(opt), f, sort_keys=False)
  46. # Configure
  47. cuda = device.type != 'cpu'
  48. init_seeds(2 + rank)
  49. with open(opt.data) as f:
  50. data_dict = yaml.load(f, Loader=yaml.FullLoader) # data dict
  51. with torch_distributed_zero_first(rank):
  52. check_dataset(data_dict) # check
  53. train_path = data_dict['train']
  54. test_path = data_dict['val']
  55. nc, names = (1, ['item']) if opt.single_cls else (int(data_dict['nc']), data_dict['names']) # number classes, names
  56. assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
  57. # Model
  58. pretrained = weights.endswith('.pt')
  59. if pretrained:
  60. with torch_distributed_zero_first(rank):
  61. attempt_download(weights) # download if not found locally
  62. ckpt = torch.load(weights, map_location=device) # load checkpoint
  63. # if hyp['anchors']:
  64. # ckpt['model'].yaml['anchors'] = round(hyp['anchors']) # force autoanchor
  65. model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device) # create
  66. exclude = ['anchor'] if opt.cfg else [] # exclude keys
  67. state_dict = ckpt['model'].float().state_dict() # to FP32
  68. state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
  69. model.load_state_dict(state_dict, strict=False) # load
  70. logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
  71. else:
  72. model = Model(opt.cfg, ch=3, nc=nc).to(device) # create
  73. # Freeze
  74. freeze = ['', ] # parameter names to freeze (full or partial)
  75. if any(freeze):
  76. for k, v in model.named_parameters():
  77. if any(x in k for x in freeze):
  78. print('freezing %s' % k)
  79. v.requires_grad = False
  80. # Optimizer
  81. nbs = 64 # nominal batch size
  82. accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
  83. hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay
  84. pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
  85. for k, v in model.named_parameters():
  86. v.requires_grad = True
  87. if '.bias' in k:
  88. pg2.append(v) # biases
  89. elif '.weight' in k and '.bn' not in k:
  90. pg1.append(v) # apply weight decay
  91. else:
  92. pg0.append(v) # all else
  93. if opt.adam:
  94. optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  95. else:
  96. optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  97. optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
  98. optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
  99. logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
  100. del pg0, pg1, pg2
  101. # Scheduler https://arxiv.org/pdf/1812.01187.pdf
  102. # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
  103. lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - hyp['lrf']) + hyp['lrf'] # cosine
  104. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
  105. # plot_lr_scheduler(optimizer, scheduler, epochs)
  106. # Resume
  107. start_epoch, best_fitness = 0, 0.0
  108. if pretrained:
  109. # Optimizer
  110. if ckpt['optimizer'] is not None:
  111. optimizer.load_state_dict(ckpt['optimizer'])
  112. best_fitness = ckpt['best_fitness']
  113. # Results
  114. if ckpt.get('training_results') is not None:
  115. with open(results_file, 'w') as file:
  116. file.write(ckpt['training_results']) # write results.txt
  117. # Epochs
  118. start_epoch = ckpt['epoch'] + 1
  119. if opt.resume:
  120. assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
  121. shutil.copytree(wdir, wdir.parent / f'weights_backup_epoch{start_epoch - 1}') # save previous weights
  122. if epochs < start_epoch:
  123. logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
  124. (weights, ckpt['epoch'], epochs))
  125. epochs += ckpt['epoch'] # finetune additional epochs
  126. del ckpt, state_dict
  127. # Image sizes
  128. gs = int(max(model.stride)) # grid size (max stride)
  129. imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
  130. # DP mode
  131. if cuda and rank == -1 and torch.cuda.device_count() > 1:
  132. model = torch.nn.DataParallel(model)
  133. # SyncBatchNorm
  134. if opt.sync_bn and cuda and rank != -1:
  135. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  136. logger.info('Using SyncBatchNorm()')
  137. # Exponential moving average
  138. ema = ModelEMA(model) if rank in [-1, 0] else None
  139. # DDP mode
  140. if cuda and rank != -1:
  141. model = DDP(model, device_ids=[opt.local_rank], output_device=(opt.local_rank))
  142. # Trainloader
  143. dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
  144. hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
  145. world_size=opt.world_size, workers=opt.workers)
  146. mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
  147. nb = len(dataloader) # number of batches
  148. assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
  149. # Testloader
  150. if rank in [-1, 0]:
  151. ema.updates = start_epoch * nb // accumulate # set EMA updates
  152. testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt,
  153. hyp=hyp, augment=False, cache=opt.cache_images, rect=True, rank=-1,
  154. world_size=opt.world_size, workers=opt.workers)[0] # only runs on process 0
  155. # Model parameters
  156. hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
  157. model.nc = nc # attach number of classes to model
  158. model.hyp = hyp # attach hyperparameters to model
  159. model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
  160. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
  161. model.names = names
  162. # Classes and Anchors
  163. if rank in [-1, 0] and not opt.resume:
  164. labels = np.concatenate(dataset.labels, 0)
  165. c = torch.tensor(labels[:, 0]) # classes
  166. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  167. # model._initialize_biases(cf.to(device))
  168. plot_labels(labels, save_dir=log_dir)
  169. if tb_writer:
  170. # tb_writer.add_hparams(hyp, {}) # causes duplicate https://github.com/ultralytics/yolov5/pull/384
  171. tb_writer.add_histogram('classes', c, 0)
  172. # Anchors
  173. if not opt.noautoanchor:
  174. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  175. # Start training
  176. t0 = time.time()
  177. nw = max(3 * nb, 1e3) # number of warmup iterations, max(3 epochs, 1k iterations)
  178. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  179. maps = np.zeros(nc) # mAP per class
  180. results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
  181. scheduler.last_epoch = start_epoch - 1 # do not move
  182. scaler = amp.GradScaler(enabled=cuda)
  183. logger.info('Image sizes %g train, %g test' % (imgsz, imgsz_test))
  184. logger.info('Using %g dataloader workers' % dataloader.num_workers)
  185. logger.info('Starting training for %g epochs...' % epochs)
  186. # torch.autograd.set_detect_anomaly(True)
  187. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  188. model.train()
  189. # Update image weights (optional)
  190. if opt.image_weights:
  191. # Generate indices
  192. if rank in [-1, 0]:
  193. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
  194. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  195. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  196. # Broadcast if DDP
  197. if rank != -1:
  198. indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
  199. dist.broadcast(indices, 0)
  200. if rank != 0:
  201. dataset.indices = indices.cpu().numpy()
  202. # Update mosaic border
  203. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  204. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  205. mloss = torch.zeros(4, device=device) # mean losses
  206. if rank != -1:
  207. dataloader.sampler.set_epoch(epoch)
  208. pbar = enumerate(dataloader)
  209. logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
  210. if rank in [-1, 0]:
  211. pbar = tqdm(pbar, total=nb) # progress bar
  212. optimizer.zero_grad()
  213. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  214. ni = i + nb * epoch # number integrated batches (since train start)
  215. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  216. # Warmup
  217. if ni <= nw:
  218. xi = [0, nw] # x interp
  219. # model.gr = np.interp(ni, xi, [0.0, 1.0]) # giou loss ratio (obj_loss = 1.0 or giou)
  220. accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
  221. for j, x in enumerate(optimizer.param_groups):
  222. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  223. x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  224. if 'momentum' in x:
  225. x['momentum'] = np.interp(ni, xi, [0.9, hyp['momentum']])
  226. # Multi-scale
  227. if opt.multi_scale:
  228. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  229. sf = sz / max(imgs.shape[2:]) # scale factor
  230. if sf != 1:
  231. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  232. imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  233. # Forward
  234. with amp.autocast(enabled=cuda):
  235. pred = model(imgs) # forward
  236. loss, loss_items = compute_loss(pred, targets.to(device), model) # loss scaled by batch_size
  237. if rank != -1:
  238. loss *= opt.world_size # gradient averaged between devices in DDP mode
  239. # Backward
  240. scaler.scale(loss).backward()
  241. # Optimize
  242. if ni % accumulate == 0:
  243. scaler.step(optimizer) # optimizer.step
  244. scaler.update()
  245. optimizer.zero_grad()
  246. if ema:
  247. ema.update(model)
  248. # Print
  249. if rank in [-1, 0]:
  250. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  251. mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
  252. s = ('%10s' * 2 + '%10.4g' * 6) % (
  253. '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
  254. pbar.set_description(s)
  255. # Plot
  256. if ni < 3:
  257. f = str(log_dir / ('train_batch%g.jpg' % ni)) # filename
  258. result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
  259. if tb_writer and result is not None:
  260. tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
  261. # tb_writer.add_graph(model, imgs) # add model to tensorboard
  262. # end batch ------------------------------------------------------------------------------------------------
  263. # Scheduler
  264. lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard
  265. scheduler.step()
  266. # DDP process 0 or single-GPU
  267. if rank in [-1, 0]:
  268. # mAP
  269. if ema:
  270. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
  271. final_epoch = epoch + 1 == epochs
  272. if not opt.notest or final_epoch: # Calculate mAP
  273. if final_epoch: # replot predictions
  274. [os.remove(x) for x in glob.glob(str(log_dir / 'test_batch*_pred.jpg')) if os.path.exists(x)]
  275. results, maps, times = test.test(opt.data,
  276. batch_size=total_batch_size,
  277. imgsz=imgsz_test,
  278. model=ema.ema,
  279. single_cls=opt.single_cls,
  280. dataloader=testloader,
  281. save_dir=log_dir)
  282. # Write
  283. with open(results_file, 'a') as f:
  284. f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
  285. if len(opt.name) and opt.bucket:
  286. os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
  287. # Tensorboard
  288. if tb_writer:
  289. tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss', # train loss
  290. 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
  291. 'val/giou_loss', 'val/obj_loss', 'val/cls_loss', # val loss
  292. 'x/lr0', 'x/lr1', 'x/lr2'] # params
  293. for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
  294. tb_writer.add_scalar(tag, x, epoch)
  295. # Update best mAP
  296. fi = fitness(np.array(results).reshape(1, -1)) # fitness_i = weighted combination of [P, R, mAP, F1]
  297. if fi > best_fitness:
  298. best_fitness = fi
  299. # Save model
  300. save = (not opt.nosave) or (final_epoch and not opt.evolve)
  301. if save:
  302. with open(results_file, 'r') as f: # create checkpoint
  303. ckpt = {'epoch': epoch,
  304. 'best_fitness': best_fitness,
  305. 'training_results': f.read(),
  306. 'model': ema.ema,
  307. 'optimizer': None if final_epoch else optimizer.state_dict()}
  308. # Save last, best and delete
  309. torch.save(ckpt, last)
  310. if best_fitness == fi:
  311. torch.save(ckpt, best)
  312. del ckpt
  313. # end epoch ----------------------------------------------------------------------------------------------------
  314. # end training
  315. if rank in [-1, 0]:
  316. # Strip optimizers
  317. n = opt.name if opt.name.isnumeric() else ''
  318. fresults, flast, fbest = log_dir / f'results{n}.txt', wdir / f'last{n}.pt', wdir / f'best{n}.pt'
  319. for f1, f2 in zip([wdir / 'last.pt', wdir / 'best.pt', results_file], [flast, fbest, fresults]):
  320. if os.path.exists(f1):
  321. os.rename(f1, f2) # rename
  322. if str(f2).endswith('.pt'): # is *.pt
  323. strip_optimizer(f2) # strip optimizer
  324. os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket else None # upload
  325. # Finish
  326. if not opt.evolve:
  327. plot_results(save_dir=log_dir) # save as results.png
  328. logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
  329. dist.destroy_process_group() if rank not in [-1, 0] else None
  330. torch.cuda.empty_cache()
  331. return results
  332. if __name__ == '__main__':
  333. parser = argparse.ArgumentParser()
  334. parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
  335. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  336. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
  337. parser.add_argument('--hyp', type=str, default='data/hyp.scratch.yaml', help='hyperparameters path')
  338. parser.add_argument('--epochs', type=int, default=300)
  339. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
  340. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
  341. parser.add_argument('--rect', action='store_true', help='rectangular training')
  342. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  343. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  344. parser.add_argument('--notest', action='store_true', help='only test final epoch')
  345. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  346. parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
  347. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  348. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
  349. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  350. parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied')
  351. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  352. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  353. parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
  354. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  355. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  356. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  357. parser.add_argument('--logdir', type=str, default='runs/', help='logging directory')
  358. parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
  359. opt = parser.parse_args()
  360. # Set DDP variables
  361. opt.total_batch_size = opt.batch_size
  362. opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
  363. opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
  364. set_logging(opt.global_rank)
  365. if opt.global_rank in [-1, 0]:
  366. check_git_status()
  367. # Resume
  368. if opt.resume: # resume an interrupted run
  369. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  370. log_dir = Path(ckpt).parent.parent # runs/exp0
  371. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  372. with open(log_dir / 'opt.yaml') as f:
  373. opt = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader)) # replace
  374. opt.cfg, opt.weights, opt.resume = '', ckpt, True
  375. logger.info('Resuming training from %s' % ckpt)
  376. else:
  377. # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
  378. opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
  379. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  380. opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
  381. log_dir = increment_dir(Path(opt.logdir) / 'exp', opt.name) # runs/exp1
  382. device = select_device(opt.device, batch_size=opt.batch_size)
  383. # DDP mode
  384. if opt.local_rank != -1:
  385. assert torch.cuda.device_count() > opt.local_rank
  386. torch.cuda.set_device(opt.local_rank)
  387. device = torch.device('cuda', opt.local_rank)
  388. dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
  389. assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
  390. opt.batch_size = opt.total_batch_size // opt.world_size
  391. logger.info(opt)
  392. with open(opt.hyp) as f:
  393. hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
  394. # Train
  395. if not opt.evolve:
  396. tb_writer = None
  397. if opt.global_rank in [-1, 0]:
  398. logger.info('Start Tensorboard with "tensorboard --logdir %s", view at http://localhost:6006/' % opt.logdir)
  399. tb_writer = SummaryWriter(log_dir=log_dir) # runs/exp0
  400. train(hyp, opt, device, tb_writer)
  401. # Evolve hyperparameters (optional)
  402. else:
  403. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  404. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  405. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  406. 'momentum': (0.1, 0.6, 0.98), # SGD momentum/Adam beta1
  407. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  408. 'giou': (1, 0.02, 0.2), # GIoU loss gain
  409. 'cls': (1, 0.2, 4.0), # cls loss gain
  410. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  411. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  412. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  413. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  414. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  415. # 'anchors': (1, 2.0, 10.0), # anchors per output grid (0 to ignore)
  416. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  417. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  418. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  419. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  420. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  421. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  422. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  423. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  424. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  425. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  426. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  427. 'mixup': (1, 0.0, 1.0)} # image mixup (probability)
  428. assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
  429. opt.notest, opt.nosave = True, True # only test/save final epoch
  430. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  431. yaml_file = Path('runs/evolve/hyp_evolved.yaml') # save best result here
  432. if opt.bucket:
  433. os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
  434. for _ in range(100): # generations to evolve
  435. if os.path.exists('evolve.txt'): # if evolve.txt exists: select best hyps and mutate
  436. # Select parent(s)
  437. parent = 'single' # parent selection method: 'single' or 'weighted'
  438. x = np.loadtxt('evolve.txt', ndmin=2)
  439. n = min(5, len(x)) # number of previous results to consider
  440. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  441. w = fitness(x) - fitness(x).min() # weights
  442. if parent == 'single' or len(x) == 1:
  443. # x = x[random.randint(0, n - 1)] # random selection
  444. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  445. elif parent == 'weighted':
  446. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  447. # Mutate
  448. mp, s = 0.9, 0.2 # mutation probability, sigma
  449. npr = np.random
  450. npr.seed(int(time.time()))
  451. g = np.array([x[0] for x in meta.values()]) # gains 0-1
  452. ng = len(meta)
  453. v = np.ones(ng)
  454. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  455. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  456. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  457. hyp[k] = float(x[i + 7] * v[i]) # mutate
  458. # Constrain to limits
  459. for k, v in meta.items():
  460. hyp[k] = max(hyp[k], v[1]) # lower limit
  461. hyp[k] = min(hyp[k], v[2]) # upper limit
  462. hyp[k] = round(hyp[k], 5) # significant digits
  463. # Train mutation
  464. results = train(hyp.copy(), opt, device)
  465. # Write mutation results
  466. print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
  467. # Plot results
  468. plot_evolution(yaml_file)
  469. print('Hyperparameter evolution complete. Best results saved as: %s\nCommand to train a new model with these '
  470. 'hyperparameters: $ python train.py --hyp %s' % (yaml_file, yaml_file))