Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

517 lines
26KB

  1. import argparse
  2. import math
  3. import os
  4. import random
  5. import time
  6. from pathlib import Path
  7. import numpy as np
  8. import torch.distributed as dist
  9. import torch.nn.functional as F
  10. import torch.optim as optim
  11. import torch.optim.lr_scheduler as lr_scheduler
  12. import torch.utils.data
  13. import yaml
  14. from torch.cuda import amp
  15. from torch.nn.parallel import DistributedDataParallel as DDP
  16. from torch.utils.tensorboard import SummaryWriter
  17. from tqdm import tqdm
  18. import test # import test.py to get mAP after each epoch
  19. from models.yolo import Model
  20. from utils.datasets import create_dataloader
  21. from utils.general import (
  22. torch_distributed_zero_first, labels_to_class_weights, plot_labels, check_anchors, labels_to_image_weights,
  23. compute_loss, plot_images, fitness, strip_optimizer, plot_results, get_latest_run, check_dataset, check_file,
  24. check_git_status, check_img_size, increment_dir, print_mutation, plot_evolution)
  25. from utils.google_utils import attempt_download
  26. from utils.torch_utils import init_seeds, ModelEMA, select_device, intersect_dicts
  27. def train(hyp, opt, device, tb_writer=None):
  28. print(f'Hyperparameters {hyp}')
  29. log_dir = Path(tb_writer.log_dir) if tb_writer else Path(opt.logdir) / 'evolve' # logging directory
  30. wdir = str(log_dir / 'weights') + os.sep # weights directory
  31. os.makedirs(wdir, exist_ok=True)
  32. last = wdir + 'last.pt'
  33. best = wdir + 'best.pt'
  34. results_file = str(log_dir / 'results.txt')
  35. epochs, batch_size, total_batch_size, weights, rank = \
  36. opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
  37. # TODO: Use DDP logging. Only the first process is allowed to log.
  38. # Save run settings
  39. with open(log_dir / 'hyp.yaml', 'w') as f:
  40. yaml.dump(hyp, f, sort_keys=False)
  41. with open(log_dir / 'opt.yaml', 'w') as f:
  42. yaml.dump(vars(opt), f, sort_keys=False)
  43. # Configure
  44. cuda = device.type != 'cpu'
  45. init_seeds(2 + rank)
  46. with open(opt.data) as f:
  47. data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
  48. with torch_distributed_zero_first(rank):
  49. check_dataset(data_dict) # check
  50. train_path = data_dict['train']
  51. test_path = data_dict['val']
  52. nc, names = (1, ['item']) if opt.single_cls else (int(data_dict['nc']), data_dict['names']) # number classes, names
  53. assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
  54. # Model
  55. pretrained = weights.endswith('.pt')
  56. if pretrained:
  57. with torch_distributed_zero_first(rank):
  58. attempt_download(weights) # download if not found locally
  59. ckpt = torch.load(weights, map_location=device) # load checkpoint
  60. model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device) # create
  61. exclude = ['anchor'] if opt.cfg else [] # exclude keys
  62. state_dict = ckpt['model'].float().state_dict() # to FP32
  63. state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
  64. model.load_state_dict(state_dict, strict=False) # load
  65. print('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
  66. else:
  67. model = Model(opt.cfg, ch=3, nc=nc).to(device) # create
  68. # Optimizer
  69. nbs = 64 # nominal batch size
  70. accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
  71. hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay
  72. pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
  73. for k, v in model.named_parameters():
  74. v.requires_grad = True
  75. if '.bias' in k:
  76. pg2.append(v) # biases
  77. elif '.weight' in k and '.bn' not in k:
  78. pg1.append(v) # apply weight decay
  79. else:
  80. pg0.append(v) # all else
  81. if opt.adam:
  82. optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  83. else:
  84. optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  85. optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
  86. optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
  87. print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
  88. del pg0, pg1, pg2
  89. # Scheduler https://arxiv.org/pdf/1812.01187.pdf
  90. # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
  91. lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2 # cosine
  92. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
  93. # plot_lr_scheduler(optimizer, scheduler, epochs)
  94. # Resume
  95. start_epoch, best_fitness = 0, 0.0
  96. if pretrained:
  97. # Optimizer
  98. if ckpt['optimizer'] is not None:
  99. optimizer.load_state_dict(ckpt['optimizer'])
  100. best_fitness = ckpt['best_fitness']
  101. # Results
  102. if ckpt.get('training_results') is not None:
  103. with open(results_file, 'w') as file:
  104. file.write(ckpt['training_results']) # write results.txt
  105. # Epochs
  106. start_epoch = ckpt['epoch'] + 1
  107. if epochs < start_epoch:
  108. print('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
  109. (weights, ckpt['epoch'], epochs))
  110. epochs += ckpt['epoch'] # finetune additional epochs
  111. del ckpt, state_dict
  112. # Image sizes
  113. gs = int(max(model.stride)) # grid size (max stride)
  114. imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
  115. # DP mode
  116. if cuda and rank == -1 and torch.cuda.device_count() > 1:
  117. model = torch.nn.DataParallel(model)
  118. # SyncBatchNorm
  119. if opt.sync_bn and cuda and rank != -1:
  120. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  121. print('Using SyncBatchNorm()')
  122. # Exponential moving average
  123. ema = ModelEMA(model) if rank in [-1, 0] else None
  124. # DDP mode
  125. if cuda and rank != -1:
  126. model = DDP(model, device_ids=[opt.local_rank], output_device=(opt.local_rank))
  127. # Trainloader
  128. dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
  129. cache=opt.cache_images, rect=opt.rect, local_rank=rank,
  130. world_size=opt.world_size)
  131. mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
  132. nb = len(dataloader) # number of batches
  133. assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
  134. # Testloader
  135. if rank in [-1, 0]:
  136. # local_rank is set to -1. Because only the first process is expected to do evaluation.
  137. testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False,
  138. cache=opt.cache_images, rect=True, local_rank=-1, world_size=opt.world_size)[0]
  139. # Model parameters
  140. hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
  141. model.nc = nc # attach number of classes to model
  142. model.hyp = hyp # attach hyperparameters to model
  143. model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
  144. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
  145. model.names = names
  146. # Class frequency
  147. if rank in [-1, 0]:
  148. labels = np.concatenate(dataset.labels, 0)
  149. c = torch.tensor(labels[:, 0]) # classes
  150. # cf = torch.bincount(c.long(), minlength=nc) + 1.
  151. # model._initialize_biases(cf.to(device))
  152. plot_labels(labels, save_dir=log_dir)
  153. if tb_writer:
  154. # tb_writer.add_hparams(hyp, {}) # causes duplicate https://github.com/ultralytics/yolov5/pull/384
  155. tb_writer.add_histogram('classes', c, 0)
  156. # Check anchors
  157. if not opt.noautoanchor:
  158. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  159. # Start training
  160. t0 = time.time()
  161. nw = max(3 * nb, 1e3) # number of warmup iterations, max(3 epochs, 1k iterations)
  162. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  163. maps = np.zeros(nc) # mAP per class
  164. results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
  165. scheduler.last_epoch = start_epoch - 1 # do not move
  166. scaler = amp.GradScaler(enabled=cuda)
  167. if rank in [0, -1]:
  168. print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
  169. print('Using %g dataloader workers' % dataloader.num_workers)
  170. print('Starting training for %g epochs...' % epochs)
  171. # torch.autograd.set_detect_anomaly(True)
  172. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  173. model.train()
  174. # Update image weights (optional)
  175. if dataset.image_weights:
  176. # Generate indices
  177. if rank in [-1, 0]:
  178. w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
  179. image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
  180. dataset.indices = random.choices(range(dataset.n), weights=image_weights,
  181. k=dataset.n) # rand weighted idx
  182. # Broadcast if DDP
  183. if rank != -1:
  184. indices = torch.zeros([dataset.n], dtype=torch.int)
  185. if rank == 0:
  186. indices[:] = torch.from_tensor(dataset.indices, dtype=torch.int)
  187. dist.broadcast(indices, 0)
  188. if rank != 0:
  189. dataset.indices = indices.cpu().numpy()
  190. # Update mosaic border
  191. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  192. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  193. mloss = torch.zeros(4, device=device) # mean losses
  194. if rank != -1:
  195. dataloader.sampler.set_epoch(epoch)
  196. pbar = enumerate(dataloader)
  197. if rank in [-1, 0]:
  198. print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
  199. pbar = tqdm(pbar, total=nb) # progress bar
  200. optimizer.zero_grad()
  201. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  202. ni = i + nb * epoch # number integrated batches (since train start)
  203. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  204. # Warmup
  205. if ni <= nw:
  206. xi = [0, nw] # x interp
  207. # model.gr = np.interp(ni, xi, [0.0, 1.0]) # giou loss ratio (obj_loss = 1.0 or giou)
  208. accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
  209. for j, x in enumerate(optimizer.param_groups):
  210. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  211. x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  212. if 'momentum' in x:
  213. x['momentum'] = np.interp(ni, xi, [0.9, hyp['momentum']])
  214. # Multi-scale
  215. if opt.multi_scale:
  216. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  217. sf = sz / max(imgs.shape[2:]) # scale factor
  218. if sf != 1:
  219. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  220. imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  221. # Autocast
  222. with amp.autocast(enabled=cuda):
  223. # Forward
  224. pred = model(imgs)
  225. # Loss
  226. loss, loss_items = compute_loss(pred, targets.to(device), model) # scaled by batch_size
  227. if rank != -1:
  228. loss *= opt.world_size # gradient averaged between devices in DDP mode
  229. # if not torch.isfinite(loss):
  230. # print('WARNING: non-finite loss, ending training ', loss_items)
  231. # return results
  232. # Backward
  233. scaler.scale(loss).backward()
  234. # Optimize
  235. if ni % accumulate == 0:
  236. scaler.step(optimizer) # optimizer.step
  237. scaler.update()
  238. optimizer.zero_grad()
  239. if ema is not None:
  240. ema.update(model)
  241. # Print
  242. if rank in [-1, 0]:
  243. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  244. mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
  245. s = ('%10s' * 2 + '%10.4g' * 6) % (
  246. '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
  247. pbar.set_description(s)
  248. # Plot
  249. if ni < 3:
  250. f = str(log_dir / ('train_batch%g.jpg' % ni)) # filename
  251. result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
  252. if tb_writer and result is not None:
  253. tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
  254. # tb_writer.add_graph(model, imgs) # add model to tensorboard
  255. # end batch ------------------------------------------------------------------------------------------------
  256. # Scheduler
  257. scheduler.step()
  258. # DDP process 0 or single-GPU
  259. if rank in [-1, 0]:
  260. # mAP
  261. if ema is not None:
  262. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
  263. final_epoch = epoch + 1 == epochs
  264. if not opt.notest or final_epoch: # Calculate mAP
  265. results, maps, times = test.test(opt.data,
  266. batch_size=total_batch_size,
  267. imgsz=imgsz_test,
  268. model=ema.ema.module if hasattr(ema.ema, 'module') else ema.ema,
  269. single_cls=opt.single_cls,
  270. dataloader=testloader,
  271. save_dir=log_dir)
  272. # Write
  273. with open(results_file, 'a') as f:
  274. f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
  275. if len(opt.name) and opt.bucket:
  276. os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
  277. # Tensorboard
  278. if tb_writer:
  279. tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
  280. 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
  281. 'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
  282. for x, tag in zip(list(mloss[:-1]) + list(results), tags):
  283. tb_writer.add_scalar(tag, x, epoch)
  284. # Update best mAP
  285. fi = fitness(np.array(results).reshape(1, -1)) # fitness_i = weighted combination of [P, R, mAP, F1]
  286. if fi > best_fitness:
  287. best_fitness = fi
  288. # Save model
  289. save = (not opt.nosave) or (final_epoch and not opt.evolve)
  290. if save:
  291. with open(results_file, 'r') as f: # create checkpoint
  292. ckpt = {'epoch': epoch,
  293. 'best_fitness': best_fitness,
  294. 'training_results': f.read(),
  295. 'model': ema.ema.module if hasattr(ema, 'module') else ema.ema,
  296. 'optimizer': None if final_epoch else optimizer.state_dict()}
  297. # Save last, best and delete
  298. torch.save(ckpt, last)
  299. if best_fitness == fi:
  300. torch.save(ckpt, best)
  301. del ckpt
  302. # end epoch ----------------------------------------------------------------------------------------------------
  303. # end training
  304. if rank in [-1, 0]:
  305. # Strip optimizers
  306. n = ('_' if len(opt.name) and not opt.name.isnumeric() else '') + opt.name
  307. fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
  308. for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'], [flast, fbest, fresults]):
  309. if os.path.exists(f1):
  310. os.rename(f1, f2) # rename
  311. ispt = f2.endswith('.pt') # is *.pt
  312. strip_optimizer(f2) if ispt else None # strip optimizer
  313. os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None # upload
  314. # Finish
  315. if not opt.evolve:
  316. plot_results(save_dir=log_dir) # save as results.png
  317. print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
  318. dist.destroy_process_group() if rank not in [-1, 0] else None
  319. torch.cuda.empty_cache()
  320. return results
  321. if __name__ == '__main__':
  322. parser = argparse.ArgumentParser()
  323. parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
  324. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  325. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
  326. parser.add_argument('--hyp', type=str, default='', help='hyperparameters path, i.e. data/hyp.scratch.yaml')
  327. parser.add_argument('--epochs', type=int, default=300)
  328. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
  329. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
  330. parser.add_argument('--rect', action='store_true', help='rectangular training')
  331. parser.add_argument('--resume', nargs='?', const='get_last', default=False,
  332. help='resume from given path/last.pt, or most recent run if blank')
  333. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  334. parser.add_argument('--notest', action='store_true', help='only test final epoch')
  335. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  336. parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
  337. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  338. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
  339. parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied')
  340. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  341. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  342. parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
  343. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  344. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  345. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  346. parser.add_argument('--logdir', type=str, default='runs/', help='logging directory')
  347. opt = parser.parse_args()
  348. # Resume
  349. if opt.resume:
  350. last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run
  351. if last and not opt.weights:
  352. print(f'Resuming training from {last}')
  353. opt.weights = last if opt.resume and not opt.weights else opt.weights
  354. if opt.local_rank == -1 or ("RANK" in os.environ and os.environ["RANK"] == "0"):
  355. check_git_status()
  356. opt.hyp = opt.hyp or ('data/hyp.finetune.yaml' if opt.weights else 'data/hyp.scratch.yaml')
  357. opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
  358. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  359. opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
  360. device = select_device(opt.device, batch_size=opt.batch_size)
  361. opt.total_batch_size = opt.batch_size
  362. opt.world_size = 1
  363. opt.global_rank = -1
  364. # DDP mode
  365. if opt.local_rank != -1:
  366. assert torch.cuda.device_count() > opt.local_rank
  367. torch.cuda.set_device(opt.local_rank)
  368. device = torch.device('cuda', opt.local_rank)
  369. dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
  370. opt.world_size = dist.get_world_size()
  371. opt.global_rank = dist.get_rank()
  372. assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
  373. opt.batch_size = opt.total_batch_size // opt.world_size
  374. print(opt)
  375. with open(opt.hyp) as f:
  376. hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
  377. # Train
  378. if not opt.evolve:
  379. tb_writer = None
  380. if opt.global_rank in [-1, 0]:
  381. print('Start Tensorboard with "tensorboard --logdir %s", view at http://localhost:6006/' % opt.logdir)
  382. tb_writer = SummaryWriter(log_dir=increment_dir(Path(opt.logdir) / 'exp', opt.name)) # runs/exp
  383. train(hyp, opt, device, tb_writer)
  384. # Evolve hyperparameters (optional)
  385. else:
  386. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  387. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  388. 'momentum': (0.1, 0.6, 0.98), # SGD momentum/Adam beta1
  389. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  390. 'giou': (1, 0.02, 0.2), # GIoU loss gain
  391. 'cls': (1, 0.2, 4.0), # cls loss gain
  392. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  393. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  394. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  395. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  396. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  397. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  398. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  399. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  400. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  401. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  402. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  403. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  404. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  405. 'perspective': (1, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  406. 'flipud': (0, 0.0, 1.0), # image flip up-down (probability)
  407. 'fliplr': (1, 0.0, 1.0), # image flip left-right (probability)
  408. 'mixup': (1, 0.0, 1.0)} # image mixup (probability)
  409. assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
  410. opt.notest, opt.nosave = True, True # only test/save final epoch
  411. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  412. yaml_file = Path('runs/evolve/hyp_evolved.yaml') # save best result here
  413. if opt.bucket:
  414. os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
  415. for _ in range(100): # generations to evolve
  416. if os.path.exists('evolve.txt'): # if evolve.txt exists: select best hyps and mutate
  417. # Select parent(s)
  418. parent = 'single' # parent selection method: 'single' or 'weighted'
  419. x = np.loadtxt('evolve.txt', ndmin=2)
  420. n = min(5, len(x)) # number of previous results to consider
  421. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  422. w = fitness(x) - fitness(x).min() # weights
  423. if parent == 'single' or len(x) == 1:
  424. # x = x[random.randint(0, n - 1)] # random selection
  425. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  426. elif parent == 'weighted':
  427. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  428. # Mutate
  429. mp, s = 0.9, 0.2 # mutation probability, sigma
  430. npr = np.random
  431. npr.seed(int(time.time()))
  432. g = np.array([x[0] for x in meta.values()]) # gains 0-1
  433. ng = len(meta)
  434. v = np.ones(ng)
  435. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  436. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  437. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  438. hyp[k] = float(x[i + 7] * v[i]) # mutate
  439. # Constrain to limits
  440. for k, v in meta.items():
  441. hyp[k] = max(hyp[k], v[1]) # lower limit
  442. hyp[k] = min(hyp[k], v[2]) # upper limit
  443. hyp[k] = round(hyp[k], 5) # significant digits
  444. # Train mutation
  445. results = train(hyp.copy(), opt, device)
  446. # Write mutation results
  447. print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
  448. # Plot results
  449. plot_evolution(yaml_file)
  450. print('Hyperparameter evolution complete. Best results saved as: %s\nCommand to train a new model with these '
  451. 'hyperparameters: $ python train.py --hyp %s' % (yaml_file, yaml_file))