You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

595 lines
30KB

  1. import argparse
  2. import logging
  3. import math
  4. import os
  5. import random
  6. import time
  7. from pathlib import Path
  8. from threading import Thread
  9. from warnings import warn
  10. import numpy as np
  11. import torch.distributed as dist
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. import torch.optim as optim
  15. import torch.optim.lr_scheduler as lr_scheduler
  16. import torch.utils.data
  17. import yaml
  18. from torch.cuda import amp
  19. from torch.nn.parallel import DistributedDataParallel as DDP
  20. from torch.utils.tensorboard import SummaryWriter
  21. from tqdm import tqdm
  22. import test # import test.py to get mAP after each epoch
  23. from models.experimental import attempt_load
  24. from models.yolo import Model
  25. from utils.autoanchor import check_anchors
  26. from utils.datasets import create_dataloader
  27. from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
  28. fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
  29. print_mutation, set_logging
  30. from utils.google_utils import attempt_download
  31. from utils.loss import compute_loss
  32. from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
  33. from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first
  34. logger = logging.getLogger(__name__)
  35. try:
  36. import wandb
  37. except ImportError:
  38. wandb = None
  39. logger.info("Install Weights & Biases for experiment logging via 'pip install wandb' (recommended)")
  40. def train(hyp, opt, device, tb_writer=None, wandb=None):
  41. logger.info(f'Hyperparameters {hyp}')
  42. save_dir, epochs, batch_size, total_batch_size, weights, rank = \
  43. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
  44. # Directories
  45. wdir = save_dir / 'weights'
  46. wdir.mkdir(parents=True, exist_ok=True) # make dir
  47. last = wdir / 'last.pt'
  48. best = wdir / 'best.pt'
  49. results_file = save_dir / 'results.txt'
  50. # Save run settings
  51. with open(save_dir / 'hyp.yaml', 'w') as f:
  52. yaml.dump(hyp, f, sort_keys=False)
  53. with open(save_dir / 'opt.yaml', 'w') as f:
  54. yaml.dump(vars(opt), f, sort_keys=False)
  55. # Configure
  56. plots = not opt.evolve # create plots
  57. cuda = device.type != 'cpu'
  58. init_seeds(2 + rank)
  59. with open(opt.data) as f:
  60. data_dict = yaml.load(f, Loader=yaml.FullLoader) # data dict
  61. with torch_distributed_zero_first(rank):
  62. check_dataset(data_dict) # check
  63. train_path = data_dict['train']
  64. test_path = data_dict['val']
  65. nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
  66. names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
  67. assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
  68. # Model
  69. pretrained = weights.endswith('.pt')
  70. if pretrained:
  71. with torch_distributed_zero_first(rank):
  72. attempt_download(weights) # download if not found locally
  73. ckpt = torch.load(weights, map_location=device) # load checkpoint
  74. if hyp.get('anchors'):
  75. ckpt['model'].yaml['anchors'] = round(hyp['anchors']) # force autoanchor
  76. model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device) # create
  77. exclude = ['anchor'] if opt.cfg or hyp.get('anchors') else [] # exclude keys
  78. state_dict = ckpt['model'].float().state_dict() # to FP32
  79. state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
  80. model.load_state_dict(state_dict, strict=False) # load
  81. logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
  82. else:
  83. model = Model(opt.cfg, ch=3, nc=nc).to(device) # create
  84. # Freeze
  85. freeze = [] # parameter names to freeze (full or partial)
  86. for k, v in model.named_parameters():
  87. v.requires_grad = True # train all layers
  88. if any(x in k for x in freeze):
  89. print('freezing %s' % k)
  90. v.requires_grad = False
  91. # Optimizer
  92. nbs = 64 # nominal batch size
  93. accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
  94. hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay
  95. pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
  96. for k, v in model.named_modules():
  97. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
  98. pg2.append(v.bias) # biases
  99. if isinstance(v, nn.BatchNorm2d):
  100. pg0.append(v.weight) # no decay
  101. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
  102. pg1.append(v.weight) # apply decay
  103. if opt.adam:
  104. optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  105. else:
  106. optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  107. optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
  108. optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
  109. logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
  110. del pg0, pg1, pg2
  111. # Scheduler https://arxiv.org/pdf/1812.01187.pdf
  112. # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
  113. lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - hyp['lrf']) + hyp['lrf'] # cosine
  114. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
  115. # plot_lr_scheduler(optimizer, scheduler, epochs)
  116. # Logging
  117. if wandb and wandb.run is None:
  118. opt.hyp = hyp # add hyperparameters
  119. wandb_run = wandb.init(config=opt, resume="allow",
  120. project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
  121. name=save_dir.stem,
  122. id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
  123. loggers = {'wandb': wandb} # loggers dict
  124. # Resume
  125. start_epoch, best_fitness = 0, 0.0
  126. if pretrained:
  127. # Optimizer
  128. if ckpt['optimizer'] is not None:
  129. optimizer.load_state_dict(ckpt['optimizer'])
  130. best_fitness = ckpt['best_fitness']
  131. # Results
  132. if ckpt.get('training_results') is not None:
  133. with open(results_file, 'w') as file:
  134. file.write(ckpt['training_results']) # write results.txt
  135. # Epochs
  136. start_epoch = ckpt['epoch'] + 1
  137. if opt.resume:
  138. assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
  139. if epochs < start_epoch:
  140. logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
  141. (weights, ckpt['epoch'], epochs))
  142. epochs += ckpt['epoch'] # finetune additional epochs
  143. del ckpt, state_dict
  144. # Image sizes
  145. gs = int(max(model.stride)) # grid size (max stride)
  146. imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
  147. # DP mode
  148. if cuda and rank == -1 and torch.cuda.device_count() > 1:
  149. model = torch.nn.DataParallel(model)
  150. # SyncBatchNorm
  151. if opt.sync_bn and cuda and rank != -1:
  152. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  153. logger.info('Using SyncBatchNorm()')
  154. # EMA
  155. ema = ModelEMA(model) if rank in [-1, 0] else None
  156. # DDP mode
  157. if cuda and rank != -1:
  158. model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
  159. # Trainloader
  160. dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
  161. hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
  162. world_size=opt.world_size, workers=opt.workers,
  163. image_weights=opt.image_weights)
  164. mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
  165. nb = len(dataloader) # number of batches
  166. assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
  167. # Process 0
  168. if rank in [-1, 0]:
  169. ema.updates = start_epoch * nb // accumulate # set EMA updates
  170. testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, # testloader
  171. hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True,
  172. rank=-1, world_size=opt.world_size, workers=opt.workers, pad=0.5)[0]
  173. if not opt.resume:
  174. labels = np.concatenate(dataset.labels, 0)
  175. c = torch.tensor(labels[:, 0]) # classes
  176. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  177. # model._initialize_biases(cf.to(device))
  178. if plots:
  179. plot_labels(labels, save_dir, loggers)
  180. if tb_writer:
  181. tb_writer.add_histogram('classes', c, 0)
  182. # Anchors
  183. if not opt.noautoanchor:
  184. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  185. # Model parameters
  186. hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
  187. model.nc = nc # attach number of classes to model
  188. model.hyp = hyp # attach hyperparameters to model
  189. model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
  190. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
  191. model.names = names
  192. # Start training
  193. t0 = time.time()
  194. nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
  195. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  196. maps = np.zeros(nc) # mAP per class
  197. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  198. scheduler.last_epoch = start_epoch - 1 # do not move
  199. scaler = amp.GradScaler(enabled=cuda)
  200. logger.info('Image sizes %g train, %g test\n'
  201. 'Using %g dataloader workers\nLogging results to %s\n'
  202. 'Starting training for %g epochs...' % (imgsz, imgsz_test, dataloader.num_workers, save_dir, epochs))
  203. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  204. model.train()
  205. # Update image weights (optional)
  206. if opt.image_weights:
  207. # Generate indices
  208. if rank in [-1, 0]:
  209. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
  210. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  211. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  212. # Broadcast if DDP
  213. if rank != -1:
  214. indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
  215. dist.broadcast(indices, 0)
  216. if rank != 0:
  217. dataset.indices = indices.cpu().numpy()
  218. # Update mosaic border
  219. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  220. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  221. mloss = torch.zeros(4, device=device) # mean losses
  222. if rank != -1:
  223. dataloader.sampler.set_epoch(epoch)
  224. pbar = enumerate(dataloader)
  225. logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'targets', 'img_size'))
  226. if rank in [-1, 0]:
  227. pbar = tqdm(pbar, total=nb) # progress bar
  228. optimizer.zero_grad()
  229. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  230. ni = i + nb * epoch # number integrated batches (since train start)
  231. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  232. # Warmup
  233. if ni <= nw:
  234. xi = [0, nw] # x interp
  235. # model.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  236. accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
  237. for j, x in enumerate(optimizer.param_groups):
  238. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  239. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  240. if 'momentum' in x:
  241. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  242. # Multi-scale
  243. if opt.multi_scale:
  244. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  245. sf = sz / max(imgs.shape[2:]) # scale factor
  246. if sf != 1:
  247. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  248. imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  249. # Forward
  250. with amp.autocast(enabled=cuda):
  251. pred = model(imgs) # forward
  252. loss, loss_items = compute_loss(pred, targets.to(device), model) # loss scaled by batch_size
  253. if rank != -1:
  254. loss *= opt.world_size # gradient averaged between devices in DDP mode
  255. # Backward
  256. scaler.scale(loss).backward()
  257. # Optimize
  258. if ni % accumulate == 0:
  259. scaler.step(optimizer) # optimizer.step
  260. scaler.update()
  261. optimizer.zero_grad()
  262. if ema:
  263. ema.update(model)
  264. # Print
  265. if rank in [-1, 0]:
  266. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  267. mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
  268. s = ('%10s' * 2 + '%10.4g' * 6) % (
  269. '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
  270. pbar.set_description(s)
  271. # Plot
  272. if plots and ni < 3:
  273. f = save_dir / f'train_batch{ni}.jpg' # filename
  274. Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
  275. # if tb_writer:
  276. # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
  277. # tb_writer.add_graph(model, imgs) # add model to tensorboard
  278. elif plots and ni == 3 and wandb:
  279. wandb.log({"Mosaics": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('train*.jpg')]})
  280. # end batch ------------------------------------------------------------------------------------------------
  281. # end epoch ----------------------------------------------------------------------------------------------------
  282. # Scheduler
  283. lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard
  284. scheduler.step()
  285. # DDP process 0 or single-GPU
  286. if rank in [-1, 0]:
  287. # mAP
  288. if ema:
  289. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
  290. final_epoch = epoch + 1 == epochs
  291. if not opt.notest or final_epoch: # Calculate mAP
  292. results, maps, times = test.test(opt.data,
  293. batch_size=total_batch_size,
  294. imgsz=imgsz_test,
  295. model=ema.ema,
  296. single_cls=opt.single_cls,
  297. dataloader=testloader,
  298. save_dir=save_dir,
  299. plots=plots and final_epoch,
  300. log_imgs=opt.log_imgs if wandb else 0)
  301. # Write
  302. with open(results_file, 'a') as f:
  303. f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  304. if len(opt.name) and opt.bucket:
  305. os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
  306. # Log
  307. tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
  308. 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
  309. 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
  310. 'x/lr0', 'x/lr1', 'x/lr2'] # params
  311. for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
  312. if tb_writer:
  313. tb_writer.add_scalar(tag, x, epoch) # tensorboard
  314. if wandb:
  315. wandb.log({tag: x}) # W&B
  316. # Update best mAP
  317. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  318. if fi > best_fitness:
  319. best_fitness = fi
  320. # Save model
  321. save = (not opt.nosave) or (final_epoch and not opt.evolve)
  322. if save:
  323. with open(results_file, 'r') as f: # create checkpoint
  324. ckpt = {'epoch': epoch,
  325. 'best_fitness': best_fitness,
  326. 'training_results': f.read(),
  327. 'model': ema.ema,
  328. 'optimizer': None if final_epoch else optimizer.state_dict(),
  329. 'wandb_id': wandb_run.id if wandb else None}
  330. # Save last, best and delete
  331. torch.save(ckpt, last)
  332. if best_fitness == fi:
  333. torch.save(ckpt, best)
  334. del ckpt
  335. # end epoch ----------------------------------------------------------------------------------------------------
  336. # end training
  337. if rank in [-1, 0]:
  338. # Strip optimizers
  339. final = best if best.exists() else last # final model
  340. for f in [last, best]:
  341. if f.exists():
  342. strip_optimizer(f) # strip optimizers
  343. if opt.bucket:
  344. os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
  345. # Plots
  346. if plots:
  347. plot_results(save_dir=save_dir) # save as results.png
  348. if wandb:
  349. files = ['results.png', 'precision_recall_curve.png', 'confusion_matrix.png']
  350. wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
  351. if (save_dir / f).exists()]})
  352. if opt.log_artifacts:
  353. wandb.log_artifact(artifact_or_path=str(final), type='model', name=save_dir.stem)
  354. # Test best.pt
  355. logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
  356. if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
  357. for conf, iou, save_json in ([0.25, 0.45, False], [0.001, 0.65, True]): # speed, mAP tests
  358. results, _, _ = test.test(opt.data,
  359. batch_size=total_batch_size,
  360. imgsz=imgsz_test,
  361. conf_thres=conf,
  362. iou_thres=iou,
  363. model=attempt_load(final, device).half(),
  364. single_cls=opt.single_cls,
  365. dataloader=testloader,
  366. save_dir=save_dir,
  367. save_json=save_json,
  368. plots=False)
  369. else:
  370. dist.destroy_process_group()
  371. wandb.run.finish() if wandb and wandb.run else None
  372. torch.cuda.empty_cache()
  373. return results
  374. if __name__ == '__main__':
  375. parser = argparse.ArgumentParser()
  376. parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
  377. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  378. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
  379. parser.add_argument('--hyp', type=str, default='data/hyp.scratch.yaml', help='hyperparameters path')
  380. parser.add_argument('--epochs', type=int, default=300)
  381. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
  382. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
  383. parser.add_argument('--rect', action='store_true', help='rectangular training')
  384. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  385. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  386. parser.add_argument('--notest', action='store_true', help='only test final epoch')
  387. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  388. parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
  389. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  390. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
  391. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  392. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  393. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  394. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
  395. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  396. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  397. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  398. parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100')
  399. parser.add_argument('--log-artifacts', action='store_true', help='log artifacts, i.e. final trained model')
  400. parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
  401. parser.add_argument('--project', default='runs/train', help='save to project/name')
  402. parser.add_argument('--name', default='exp', help='save to project/name')
  403. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  404. opt = parser.parse_args()
  405. # Set DDP variables
  406. opt.total_batch_size = opt.batch_size
  407. opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
  408. opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
  409. set_logging(opt.global_rank)
  410. if opt.global_rank in [-1, 0]:
  411. check_git_status()
  412. # Resume
  413. if opt.resume: # resume an interrupted run
  414. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  415. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  416. with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
  417. opt = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader)) # replace
  418. opt.cfg, opt.weights, opt.resume = '', ckpt, True
  419. logger.info('Resuming training from %s' % ckpt)
  420. else:
  421. # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
  422. opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
  423. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  424. opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
  425. opt.name = 'evolve' if opt.evolve else opt.name
  426. opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve) # increment run
  427. # DDP mode
  428. device = select_device(opt.device, batch_size=opt.batch_size)
  429. if opt.local_rank != -1:
  430. assert torch.cuda.device_count() > opt.local_rank
  431. torch.cuda.set_device(opt.local_rank)
  432. device = torch.device('cuda', opt.local_rank)
  433. dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
  434. assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
  435. opt.batch_size = opt.total_batch_size // opt.world_size
  436. # Hyperparameters
  437. with open(opt.hyp) as f:
  438. hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
  439. if 'box' not in hyp:
  440. warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' %
  441. (opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120'))
  442. hyp['box'] = hyp.pop('giou')
  443. # Train
  444. logger.info(opt)
  445. if not opt.evolve:
  446. tb_writer = None # init loggers
  447. if opt.global_rank in [-1, 0]:
  448. logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/')
  449. tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
  450. train(hyp, opt, device, tb_writer, wandb)
  451. # Evolve hyperparameters (optional)
  452. else:
  453. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  454. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  455. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  456. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  457. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  458. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  459. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  460. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  461. 'box': (1, 0.02, 0.2), # box loss gain
  462. 'cls': (1, 0.2, 4.0), # cls loss gain
  463. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  464. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  465. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  466. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  467. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  468. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  469. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  470. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  471. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  472. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  473. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  474. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  475. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  476. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  477. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  478. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  479. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  480. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  481. 'mixup': (1, 0.0, 1.0)} # image mixup (probability)
  482. assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
  483. opt.notest, opt.nosave = True, True # only test/save final epoch
  484. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  485. yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
  486. if opt.bucket:
  487. os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
  488. for _ in range(300): # generations to evolve
  489. if Path('evolve.txt').exists(): # if evolve.txt exists: select best hyps and mutate
  490. # Select parent(s)
  491. parent = 'single' # parent selection method: 'single' or 'weighted'
  492. x = np.loadtxt('evolve.txt', ndmin=2)
  493. n = min(5, len(x)) # number of previous results to consider
  494. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  495. w = fitness(x) - fitness(x).min() # weights
  496. if parent == 'single' or len(x) == 1:
  497. # x = x[random.randint(0, n - 1)] # random selection
  498. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  499. elif parent == 'weighted':
  500. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  501. # Mutate
  502. mp, s = 0.8, 0.2 # mutation probability, sigma
  503. npr = np.random
  504. npr.seed(int(time.time()))
  505. g = np.array([x[0] for x in meta.values()]) # gains 0-1
  506. ng = len(meta)
  507. v = np.ones(ng)
  508. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  509. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  510. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  511. hyp[k] = float(x[i + 7] * v[i]) # mutate
  512. # Constrain to limits
  513. for k, v in meta.items():
  514. hyp[k] = max(hyp[k], v[1]) # lower limit
  515. hyp[k] = min(hyp[k], v[2]) # upper limit
  516. hyp[k] = round(hyp[k], 5) # significant digits
  517. # Train mutation
  518. results = train(hyp.copy(), opt, device, wandb=wandb)
  519. # Write mutation results
  520. print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
  521. # Plot results
  522. plot_evolution(yaml_file)
  523. print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
  524. f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')