You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

515 lines
25KB

  1. import argparse
  2. import torch.distributed as dist
  3. import torch.nn.functional as F
  4. import torch.optim as optim
  5. import torch.optim.lr_scheduler as lr_scheduler
  6. import torch.utils.data
  7. from torch.cuda import amp
  8. from torch.nn.parallel import DistributedDataParallel as DDP
  9. from torch.utils.tensorboard import SummaryWriter
  10. import test # import test.py to get mAP after each epoch
  11. from models.yolo import Model
  12. from utils import google_utils
  13. from utils.datasets import *
  14. from utils.utils import *
  15. # Hyperparameters
  16. hyp = {'optimizer': 'SGD', # ['Adam', 'SGD', ...] from torch.optim
  17. 'lr0': 0.01, # initial learning rate (SGD=1E-2, Adam=1E-3)
  18. 'momentum': 0.937, # SGD momentum/Adam beta1
  19. 'weight_decay': 5e-4, # optimizer weight decay
  20. 'giou': 0.05, # GIoU loss gain
  21. 'cls': 0.5, # cls loss gain
  22. 'cls_pw': 1.0, # cls BCELoss positive_weight
  23. 'obj': 1.0, # obj loss gain (scale with pixels)
  24. 'obj_pw': 1.0, # obj BCELoss positive_weight
  25. 'iou_t': 0.20, # IoU training threshold
  26. 'anchor_t': 4.0, # anchor-multiple threshold
  27. 'fl_gamma': 0.0, # focal loss gamma (efficientDet default gamma=1.5)
  28. 'hsv_h': 0.015, # image HSV-Hue augmentation (fraction)
  29. 'hsv_s': 0.7, # image HSV-Saturation augmentation (fraction)
  30. 'hsv_v': 0.4, # image HSV-Value augmentation (fraction)
  31. 'degrees': 0.0, # image rotation (+/- deg)
  32. 'translate': 0.5, # image translation (+/- fraction)
  33. 'scale': 0.5, # image scale (+/- gain)
  34. 'shear': 0.0, # image shear (+/- deg)
  35. 'perspective': 0.0, # image perspective (+/- fraction), range 0-0.001
  36. 'flipud': 0.0, # image flip up-down (probability)
  37. 'fliplr': 0.5, # image flip left-right (probability)
  38. 'mixup': 0.0} # image mixup (probability)
  39. def train(hyp, tb_writer, opt, device):
  40. print(f'Hyperparameters {hyp}')
  41. log_dir = tb_writer.log_dir if tb_writer else 'runs/evolution' # run directory
  42. wdir = str(Path(log_dir) / 'weights') + os.sep # weights directory
  43. os.makedirs(wdir, exist_ok=True)
  44. last = wdir + 'last.pt'
  45. best = wdir + 'best.pt'
  46. results_file = log_dir + os.sep + 'results.txt'
  47. epochs, batch_size, total_batch_size, weights, rank = \
  48. opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.local_rank
  49. # TODO: Use DDP logging. Only the first process is allowed to log.
  50. # Save run settings
  51. with open(Path(log_dir) / 'hyp.yaml', 'w') as f:
  52. yaml.dump(hyp, f, sort_keys=False)
  53. with open(Path(log_dir) / 'opt.yaml', 'w') as f:
  54. yaml.dump(vars(opt), f, sort_keys=False)
  55. # Configure
  56. cuda = device.type != 'cpu'
  57. init_seeds(2 + rank)
  58. with open(opt.data) as f:
  59. data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
  60. train_path = data_dict['train']
  61. test_path = data_dict['val']
  62. nc, names = (1, ['item']) if opt.single_cls else (int(data_dict['nc']), data_dict['names']) # number classes, names
  63. assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
  64. # Remove previous results
  65. if rank in [-1, 0]:
  66. for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
  67. os.remove(f)
  68. # Create model
  69. model = Model(opt.cfg, nc=nc).to(device)
  70. # Image sizes
  71. gs = int(max(model.stride)) # grid size (max stride)
  72. imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
  73. # Optimizer
  74. nbs = 64 # nominal batch size
  75. # default DDP implementation is slow for accumulation according to: https://pytorch.org/docs/stable/notes/ddp.html
  76. # all-reduce operation is carried out during loss.backward().
  77. # Thus, there would be redundant all-reduce communications in a accumulation procedure,
  78. # which means, the result is still right but the training speed gets slower.
  79. # TODO: If acceleration is needed, there is an implementation of allreduce_post_accumulation
  80. # in https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/BERT/run_pretraining.py
  81. accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
  82. hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay
  83. pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
  84. for k, v in model.named_parameters():
  85. if v.requires_grad:
  86. if '.bias' in k:
  87. pg2.append(v) # biases
  88. elif '.weight' in k and '.bn' not in k:
  89. pg1.append(v) # apply weight decay
  90. else:
  91. pg0.append(v) # all else
  92. if hyp['optimizer'] == 'Adam':
  93. optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  94. else:
  95. optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  96. optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
  97. optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
  98. print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
  99. del pg0, pg1, pg2
  100. # Scheduler https://arxiv.org/pdf/1812.01187.pdf
  101. # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
  102. lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2 # cosine
  103. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
  104. # plot_lr_scheduler(optimizer, scheduler, epochs)
  105. # Load Model
  106. with torch_distributed_zero_first(rank):
  107. google_utils.attempt_download(weights)
  108. start_epoch, best_fitness = 0, 0.0
  109. if weights.endswith('.pt'): # pytorch format
  110. ckpt = torch.load(weights, map_location=device) # load checkpoint
  111. # load model
  112. try:
  113. exclude = ['anchor'] # exclude keys
  114. ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items()
  115. if k in model.state_dict() and not any(x in k for x in exclude)
  116. and model.state_dict()[k].shape == v.shape}
  117. model.load_state_dict(ckpt['model'], strict=False)
  118. print('Transferred %g/%g items from %s' % (len(ckpt['model']), len(model.state_dict()), weights))
  119. except KeyError as e:
  120. s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
  121. "Please delete or update %s and try again, or use --weights '' to train from scratch." \
  122. % (weights, opt.cfg, weights, weights)
  123. raise KeyError(s) from e
  124. # load optimizer
  125. if ckpt['optimizer'] is not None:
  126. optimizer.load_state_dict(ckpt['optimizer'])
  127. best_fitness = ckpt['best_fitness']
  128. # load results
  129. if ckpt.get('training_results') is not None:
  130. with open(results_file, 'w') as file:
  131. file.write(ckpt['training_results']) # write results.txt
  132. # epochs
  133. start_epoch = ckpt['epoch'] + 1
  134. if epochs < start_epoch:
  135. print('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
  136. (weights, ckpt['epoch'], epochs))
  137. epochs += ckpt['epoch'] # finetune additional epochs
  138. del ckpt
  139. # DP mode
  140. if cuda and rank == -1 and torch.cuda.device_count() > 1:
  141. model = torch.nn.DataParallel(model)
  142. # SyncBatchNorm
  143. if opt.sync_bn and cuda and rank != -1:
  144. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  145. print('Using SyncBatchNorm()')
  146. # Exponential moving average
  147. ema = torch_utils.ModelEMA(model) if rank in [-1, 0] else None
  148. # DDP mode
  149. if cuda and rank != -1:
  150. model = DDP(model, device_ids=[rank], output_device=rank)
  151. # Trainloader
  152. dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
  153. cache=opt.cache_images, rect=opt.rect, local_rank=rank,
  154. world_size=opt.world_size)
  155. mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
  156. nb = len(dataloader) # number of batches
  157. assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
  158. # Testloader
  159. if rank in [-1, 0]:
  160. # local_rank is set to -1. Because only the first process is expected to do evaluation.
  161. testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False,
  162. cache=opt.cache_images, rect=True, local_rank=-1, world_size=opt.world_size)[0]
  163. # Model parameters
  164. hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
  165. model.nc = nc # attach number of classes to model
  166. model.hyp = hyp # attach hyperparameters to model
  167. model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
  168. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
  169. model.names = names
  170. # Class frequency
  171. if rank in [-1, 0]:
  172. labels = np.concatenate(dataset.labels, 0)
  173. c = torch.tensor(labels[:, 0]) # classes
  174. # cf = torch.bincount(c.long(), minlength=nc) + 1.
  175. # model._initialize_biases(cf.to(device))
  176. plot_labels(labels, save_dir=log_dir)
  177. if tb_writer:
  178. # tb_writer.add_hparams(hyp, {}) # causes duplicate https://github.com/ultralytics/yolov5/pull/384
  179. tb_writer.add_histogram('classes', c, 0)
  180. # Check anchors
  181. if not opt.noautoanchor:
  182. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  183. # Start training
  184. t0 = time.time()
  185. nw = max(3 * nb, 1e3) # number of warmup iterations, max(3 epochs, 1k iterations)
  186. maps = np.zeros(nc) # mAP per class
  187. results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
  188. scheduler.last_epoch = start_epoch - 1 # do not move
  189. scaler = amp.GradScaler(enabled=cuda)
  190. if rank in [0, -1]:
  191. print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
  192. print('Using %g dataloader workers' % dataloader.num_workers)
  193. print('Starting training for %g epochs...' % epochs)
  194. # torch.autograd.set_detect_anomaly(True)
  195. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  196. model.train()
  197. # Update image weights (optional)
  198. if dataset.image_weights:
  199. # Generate indices
  200. if rank in [-1, 0]:
  201. w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
  202. image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
  203. dataset.indices = random.choices(range(dataset.n), weights=image_weights,
  204. k=dataset.n) # rand weighted idx
  205. # Broadcast if DDP
  206. if rank != -1:
  207. indices = torch.zeros([dataset.n], dtype=torch.int)
  208. if rank == 0:
  209. indices[:] = torch.from_tensor(dataset.indices, dtype=torch.int)
  210. dist.broadcast(indices, 0)
  211. if rank != 0:
  212. dataset.indices = indices.cpu().numpy()
  213. # Update mosaic border
  214. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  215. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  216. mloss = torch.zeros(4, device=device) # mean losses
  217. if rank != -1:
  218. dataloader.sampler.set_epoch(epoch)
  219. pbar = enumerate(dataloader)
  220. if rank in [-1, 0]:
  221. print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
  222. pbar = tqdm(pbar, total=nb) # progress bar
  223. optimizer.zero_grad()
  224. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  225. ni = i + nb * epoch # number integrated batches (since train start)
  226. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  227. # Warmup
  228. if ni <= nw:
  229. xi = [0, nw] # x interp
  230. # model.gr = np.interp(ni, xi, [0.0, 1.0]) # giou loss ratio (obj_loss = 1.0 or giou)
  231. accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
  232. for j, x in enumerate(optimizer.param_groups):
  233. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  234. x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  235. if 'momentum' in x:
  236. x['momentum'] = np.interp(ni, xi, [0.9, hyp['momentum']])
  237. # Multi-scale
  238. if opt.multi_scale:
  239. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  240. sf = sz / max(imgs.shape[2:]) # scale factor
  241. if sf != 1:
  242. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  243. imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  244. # Autocast
  245. with amp.autocast():
  246. # Forward
  247. pred = model(imgs)
  248. # Loss
  249. loss, loss_items = compute_loss(pred, targets.to(device), model) # scaled by batch_size
  250. if rank != -1:
  251. loss *= opt.world_size # gradient averaged between devices in DDP mode
  252. # if not torch.isfinite(loss):
  253. # print('WARNING: non-finite loss, ending training ', loss_items)
  254. # return results
  255. # Backward
  256. scaler.scale(loss).backward()
  257. # Optimize
  258. if ni % accumulate == 0:
  259. scaler.step(optimizer) # optimizer.step
  260. scaler.update()
  261. optimizer.zero_grad()
  262. if ema is not None:
  263. ema.update(model)
  264. # Print
  265. if rank in [-1, 0]:
  266. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  267. mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
  268. s = ('%10s' * 2 + '%10.4g' * 6) % (
  269. '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
  270. pbar.set_description(s)
  271. # Plot
  272. if ni < 3:
  273. f = str(Path(log_dir) / ('train_batch%g.jpg' % ni)) # filename
  274. result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
  275. if tb_writer and result is not None:
  276. tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
  277. # tb_writer.add_graph(model, imgs) # add model to tensorboard
  278. # end batch ------------------------------------------------------------------------------------------------
  279. # Scheduler
  280. scheduler.step()
  281. # DDP process 0 or single-GPU
  282. if rank in [-1, 0]:
  283. # mAP
  284. if ema is not None:
  285. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
  286. final_epoch = epoch + 1 == epochs
  287. if not opt.notest or final_epoch: # Calculate mAP
  288. results, maps, times = test.test(opt.data,
  289. batch_size=total_batch_size,
  290. imgsz=imgsz_test,
  291. save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
  292. model=ema.ema.module if hasattr(ema.ema, 'module') else ema.ema,
  293. single_cls=opt.single_cls,
  294. dataloader=testloader,
  295. save_dir=log_dir)
  296. # Write
  297. with open(results_file, 'a') as f:
  298. f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
  299. if len(opt.name) and opt.bucket:
  300. os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
  301. # Tensorboard
  302. if tb_writer:
  303. tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
  304. 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
  305. 'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
  306. for x, tag in zip(list(mloss[:-1]) + list(results), tags):
  307. tb_writer.add_scalar(tag, x, epoch)
  308. # Update best mAP
  309. fi = fitness(np.array(results).reshape(1, -1)) # fitness_i = weighted combination of [P, R, mAP, F1]
  310. if fi > best_fitness:
  311. best_fitness = fi
  312. # Save model
  313. save = (not opt.nosave) or (final_epoch and not opt.evolve)
  314. if save:
  315. with open(results_file, 'r') as f: # create checkpoint
  316. ckpt = {'epoch': epoch,
  317. 'best_fitness': best_fitness,
  318. 'training_results': f.read(),
  319. 'model': ema.ema.module if hasattr(ema, 'module') else ema.ema,
  320. 'optimizer': None if final_epoch else optimizer.state_dict()}
  321. # Save last, best and delete
  322. torch.save(ckpt, last)
  323. if best_fitness == fi:
  324. torch.save(ckpt, best)
  325. del ckpt
  326. # end epoch ----------------------------------------------------------------------------------------------------
  327. # end training
  328. if rank in [-1, 0]:
  329. # Strip optimizers
  330. n = ('_' if len(opt.name) and not opt.name.isnumeric() else '') + opt.name
  331. fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
  332. for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'], [flast, fbest, fresults]):
  333. if os.path.exists(f1):
  334. os.rename(f1, f2) # rename
  335. ispt = f2.endswith('.pt') # is *.pt
  336. strip_optimizer(f2) if ispt else None # strip optimizer
  337. os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None # upload
  338. # Finish
  339. if not opt.evolve:
  340. plot_results(save_dir=log_dir) # save as results.png
  341. print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
  342. dist.destroy_process_group() if rank not in [-1, 0] else None
  343. torch.cuda.empty_cache()
  344. return results
  345. if __name__ == '__main__':
  346. parser = argparse.ArgumentParser()
  347. parser.add_argument('--cfg', type=str, default='models/yolov5s.yaml', help='model.yaml path')
  348. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
  349. parser.add_argument('--hyp', type=str, default='', help='hyp.yaml path (optional)')
  350. parser.add_argument('--epochs', type=int, default=300)
  351. parser.add_argument('--batch-size', type=int, default=16, help="Total batch size for all gpus.")
  352. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
  353. parser.add_argument('--rect', action='store_true', help='rectangular training')
  354. parser.add_argument('--resume', nargs='?', const='get_last', default=False,
  355. help='resume from given path/to/last.pt, or most recent run if blank.')
  356. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  357. parser.add_argument('--notest', action='store_true', help='only test final epoch')
  358. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  359. parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
  360. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  361. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
  362. parser.add_argument('--weights', type=str, default='', help='initial weights path')
  363. parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied')
  364. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  365. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  366. parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
  367. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  368. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  369. opt = parser.parse_args()
  370. # Resume
  371. last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run
  372. if last and not opt.weights:
  373. print(f'Resuming training from {last}')
  374. opt.weights = last if opt.resume and not opt.weights else opt.weights
  375. if opt.local_rank in [-1, 0]:
  376. check_git_status()
  377. opt.cfg = check_file(opt.cfg) # check file
  378. opt.data = check_file(opt.data) # check file
  379. if opt.hyp: # update hyps
  380. opt.hyp = check_file(opt.hyp) # check file
  381. with open(opt.hyp) as f:
  382. hyp.update(yaml.load(f, Loader=yaml.FullLoader)) # update hyps
  383. opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
  384. device = torch_utils.select_device(opt.device, batch_size=opt.batch_size)
  385. opt.total_batch_size = opt.batch_size
  386. opt.world_size = 1
  387. # DDP mode
  388. if opt.local_rank != -1:
  389. assert torch.cuda.device_count() > opt.local_rank
  390. torch.cuda.set_device(opt.local_rank)
  391. device = torch.device("cuda", opt.local_rank)
  392. dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
  393. opt.world_size = dist.get_world_size()
  394. assert opt.batch_size % opt.world_size == 0, "Batch size is not a multiple of the number of devices given!"
  395. opt.batch_size = opt.total_batch_size // opt.world_size
  396. print(opt)
  397. # Train
  398. if not opt.evolve:
  399. if opt.local_rank in [-1, 0]:
  400. print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
  401. tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name))
  402. else:
  403. tb_writer = None
  404. train(hyp, tb_writer, opt, device)
  405. # Evolve hyperparameters (optional)
  406. else:
  407. assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
  408. tb_writer = None
  409. opt.notest, opt.nosave = True, True # only test/save final epoch
  410. if opt.bucket:
  411. os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
  412. for _ in range(10): # generations to evolve
  413. if os.path.exists('evolve.txt'): # if evolve.txt exists: select best hyps and mutate
  414. # Select parent(s)
  415. parent = 'single' # parent selection method: 'single' or 'weighted'
  416. x = np.loadtxt('evolve.txt', ndmin=2)
  417. n = min(5, len(x)) # number of previous results to consider
  418. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  419. w = fitness(x) - fitness(x).min() # weights
  420. if parent == 'single' or len(x) == 1:
  421. # x = x[random.randint(0, n - 1)] # random selection
  422. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  423. elif parent == 'weighted':
  424. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  425. # Mutate
  426. mp, s = 0.9, 0.2 # mutation probability, sigma
  427. npr = np.random
  428. npr.seed(int(time.time()))
  429. g = np.array([1, 1, 1, 1, 1, 1, 1, 0, .1, 1, 0, 1, 1, 1, 1, 1, 1, 1]) # gains
  430. ng = len(g)
  431. v = np.ones(ng)
  432. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  433. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  434. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  435. hyp[k] = x[i + 7] * v[i] # mutate
  436. # Clip to limits
  437. keys = ['lr0', 'iou_t', 'momentum', 'weight_decay', 'hsv_s', 'hsv_v', 'translate', 'scale', 'fl_gamma']
  438. limits = [(1e-5, 1e-2), (0.00, 0.70), (0.60, 0.98), (0, 0.001), (0, .9), (0, .9), (0, .9), (0, .9), (0, 3)]
  439. for k, v in zip(keys, limits):
  440. hyp[k] = np.clip(hyp[k], v[0], v[1])
  441. # Train mutation
  442. results = train(hyp.copy(), tb_writer, opt, device)
  443. # Write mutation results
  444. print_mutation(hyp, results, opt.bucket)
  445. # Plot results
  446. # plot_evolution_results(hyp)