You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

575 lines
29KB

  1. import argparse
  2. import logging
  3. import os
  4. import random
  5. import time
  6. from pathlib import Path
  7. from threading import Thread
  8. from warnings import warn
  9. import math
  10. import numpy as np
  11. import torch.distributed as dist
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. import torch.optim as optim
  15. import torch.optim.lr_scheduler as lr_scheduler
  16. import torch.utils.data
  17. import yaml
  18. from torch.cuda import amp
  19. from torch.nn.parallel import DistributedDataParallel as DDP
  20. from torch.utils.tensorboard import SummaryWriter
  21. from tqdm import tqdm
  22. import test # import test.py to get mAP after each epoch
  23. from models.yolo import Model
  24. from utils.autoanchor import check_anchors
  25. from utils.datasets import create_dataloader
  26. from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
  27. fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
  28. print_mutation, set_logging
  29. from utils.google_utils import attempt_download
  30. from utils.loss import compute_loss
  31. from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
  32. from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first
  33. logger = logging.getLogger(__name__)
  34. try:
  35. import wandb
  36. except ImportError:
  37. wandb = None
  38. logger.info("Install Weights & Biases for experiment logging via 'pip install wandb' (recommended)")
  39. def train(hyp, opt, device, tb_writer=None, wandb=None):
  40. logger.info(f'Hyperparameters {hyp}')
  41. save_dir, epochs, batch_size, total_batch_size, weights, rank = \
  42. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
  43. # Directories
  44. wdir = save_dir / 'weights'
  45. wdir.mkdir(parents=True, exist_ok=True) # make dir
  46. last = wdir / 'last.pt'
  47. best = wdir / 'best.pt'
  48. results_file = save_dir / 'results.txt'
  49. # Save run settings
  50. with open(save_dir / 'hyp.yaml', 'w') as f:
  51. yaml.dump(hyp, f, sort_keys=False)
  52. with open(save_dir / 'opt.yaml', 'w') as f:
  53. yaml.dump(vars(opt), f, sort_keys=False)
  54. # Configure
  55. plots = not opt.evolve # create plots
  56. cuda = device.type != 'cpu'
  57. init_seeds(2 + rank)
  58. with open(opt.data) as f:
  59. data_dict = yaml.load(f, Loader=yaml.FullLoader) # data dict
  60. with torch_distributed_zero_first(rank):
  61. check_dataset(data_dict) # check
  62. train_path = data_dict['train']
  63. test_path = data_dict['val']
  64. nc, names = (1, ['item']) if opt.single_cls else (int(data_dict['nc']), data_dict['names']) # number classes, names
  65. assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
  66. # Model
  67. pretrained = weights.endswith('.pt')
  68. if pretrained:
  69. with torch_distributed_zero_first(rank):
  70. attempt_download(weights) # download if not found locally
  71. ckpt = torch.load(weights, map_location=device) # load checkpoint
  72. if hyp.get('anchors'):
  73. ckpt['model'].yaml['anchors'] = round(hyp['anchors']) # force autoanchor
  74. model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device) # create
  75. exclude = ['anchor'] if opt.cfg or hyp.get('anchors') else [] # exclude keys
  76. state_dict = ckpt['model'].float().state_dict() # to FP32
  77. state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
  78. model.load_state_dict(state_dict, strict=False) # load
  79. logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
  80. else:
  81. model = Model(opt.cfg, ch=3, nc=nc).to(device) # create
  82. # Freeze
  83. freeze = [] # parameter names to freeze (full or partial)
  84. for k, v in model.named_parameters():
  85. v.requires_grad = True # train all layers
  86. if any(x in k for x in freeze):
  87. print('freezing %s' % k)
  88. v.requires_grad = False
  89. # Optimizer
  90. nbs = 64 # nominal batch size
  91. accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
  92. hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay
  93. pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
  94. for k, v in model.named_modules():
  95. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
  96. pg2.append(v.bias) # biases
  97. if isinstance(v, nn.BatchNorm2d):
  98. pg0.append(v.weight) # no decay
  99. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
  100. pg1.append(v.weight) # apply decay
  101. if opt.adam:
  102. optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  103. else:
  104. optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  105. optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
  106. optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
  107. logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
  108. del pg0, pg1, pg2
  109. # Scheduler https://arxiv.org/pdf/1812.01187.pdf
  110. # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
  111. lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - hyp['lrf']) + hyp['lrf'] # cosine
  112. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
  113. # plot_lr_scheduler(optimizer, scheduler, epochs)
  114. # Logging
  115. if wandb and wandb.run is None:
  116. opt.hyp = hyp # add hyperparameters
  117. wandb_run = wandb.init(config=opt, resume="allow",
  118. project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
  119. name=save_dir.stem,
  120. id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
  121. loggers = {'wandb': wandb} # loggers dict
  122. # Resume
  123. start_epoch, best_fitness = 0, 0.0
  124. if pretrained:
  125. # Optimizer
  126. if ckpt['optimizer'] is not None:
  127. optimizer.load_state_dict(ckpt['optimizer'])
  128. best_fitness = ckpt['best_fitness']
  129. # Results
  130. if ckpt.get('training_results') is not None:
  131. with open(results_file, 'w') as file:
  132. file.write(ckpt['training_results']) # write results.txt
  133. # Epochs
  134. start_epoch = ckpt['epoch'] + 1
  135. if opt.resume:
  136. assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
  137. if epochs < start_epoch:
  138. logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
  139. (weights, ckpt['epoch'], epochs))
  140. epochs += ckpt['epoch'] # finetune additional epochs
  141. del ckpt, state_dict
  142. # Image sizes
  143. gs = int(max(model.stride)) # grid size (max stride)
  144. imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
  145. # DP mode
  146. if cuda and rank == -1 and torch.cuda.device_count() > 1:
  147. model = torch.nn.DataParallel(model)
  148. # SyncBatchNorm
  149. if opt.sync_bn and cuda and rank != -1:
  150. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  151. logger.info('Using SyncBatchNorm()')
  152. # EMA
  153. ema = ModelEMA(model) if rank in [-1, 0] else None
  154. # DDP mode
  155. if cuda and rank != -1:
  156. model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
  157. # Trainloader
  158. dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
  159. hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
  160. world_size=opt.world_size, workers=opt.workers,
  161. image_weights=opt.image_weights)
  162. mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
  163. nb = len(dataloader) # number of batches
  164. assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
  165. # Process 0
  166. if rank in [-1, 0]:
  167. ema.updates = start_epoch * nb // accumulate # set EMA updates
  168. testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt,
  169. hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True,
  170. rank=-1, world_size=opt.world_size, workers=opt.workers)[0] # testloader
  171. if not opt.resume:
  172. labels = np.concatenate(dataset.labels, 0)
  173. c = torch.tensor(labels[:, 0]) # classes
  174. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  175. # model._initialize_biases(cf.to(device))
  176. if plots:
  177. Thread(target=plot_labels, args=(labels, save_dir, loggers), daemon=True).start()
  178. if tb_writer:
  179. tb_writer.add_histogram('classes', c, 0)
  180. # Anchors
  181. if not opt.noautoanchor:
  182. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  183. # Model parameters
  184. hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
  185. model.nc = nc # attach number of classes to model
  186. model.hyp = hyp # attach hyperparameters to model
  187. model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
  188. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
  189. model.names = names
  190. # Start training
  191. t0 = time.time()
  192. nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
  193. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  194. maps = np.zeros(nc) # mAP per class
  195. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  196. scheduler.last_epoch = start_epoch - 1 # do not move
  197. scaler = amp.GradScaler(enabled=cuda)
  198. logger.info('Image sizes %g train, %g test\n'
  199. 'Using %g dataloader workers\nLogging results to %s\n'
  200. 'Starting training for %g epochs...' % (imgsz, imgsz_test, dataloader.num_workers, save_dir, epochs))
  201. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  202. model.train()
  203. # Update image weights (optional)
  204. if opt.image_weights:
  205. # Generate indices
  206. if rank in [-1, 0]:
  207. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
  208. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  209. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  210. # Broadcast if DDP
  211. if rank != -1:
  212. indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
  213. dist.broadcast(indices, 0)
  214. if rank != 0:
  215. dataset.indices = indices.cpu().numpy()
  216. # Update mosaic border
  217. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  218. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  219. mloss = torch.zeros(4, device=device) # mean losses
  220. if rank != -1:
  221. dataloader.sampler.set_epoch(epoch)
  222. pbar = enumerate(dataloader)
  223. logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'targets', 'img_size'))
  224. if rank in [-1, 0]:
  225. pbar = tqdm(pbar, total=nb) # progress bar
  226. optimizer.zero_grad()
  227. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  228. ni = i + nb * epoch # number integrated batches (since train start)
  229. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  230. # Warmup
  231. if ni <= nw:
  232. xi = [0, nw] # x interp
  233. # model.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  234. accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
  235. for j, x in enumerate(optimizer.param_groups):
  236. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  237. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  238. if 'momentum' in x:
  239. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  240. # Multi-scale
  241. if opt.multi_scale:
  242. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  243. sf = sz / max(imgs.shape[2:]) # scale factor
  244. if sf != 1:
  245. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  246. imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  247. # Forward
  248. with amp.autocast(enabled=cuda):
  249. pred = model(imgs) # forward
  250. loss, loss_items = compute_loss(pred, targets.to(device), model) # loss scaled by batch_size
  251. if rank != -1:
  252. loss *= opt.world_size # gradient averaged between devices in DDP mode
  253. # Backward
  254. scaler.scale(loss).backward()
  255. # Optimize
  256. if ni % accumulate == 0:
  257. scaler.step(optimizer) # optimizer.step
  258. scaler.update()
  259. optimizer.zero_grad()
  260. if ema:
  261. ema.update(model)
  262. # Print
  263. if rank in [-1, 0]:
  264. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  265. mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
  266. s = ('%10s' * 2 + '%10.4g' * 6) % (
  267. '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
  268. pbar.set_description(s)
  269. # Plot
  270. if plots and ni < 3:
  271. f = save_dir / f'train_batch{ni}.jpg' # filename
  272. Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
  273. # if tb_writer:
  274. # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
  275. # tb_writer.add_graph(model, imgs) # add model to tensorboard
  276. elif plots and ni == 3 and wandb:
  277. wandb.log({"Mosaics": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('train*.jpg')]})
  278. # end batch ------------------------------------------------------------------------------------------------
  279. # end epoch ----------------------------------------------------------------------------------------------------
  280. # Scheduler
  281. lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard
  282. scheduler.step()
  283. # DDP process 0 or single-GPU
  284. if rank in [-1, 0]:
  285. # mAP
  286. if ema:
  287. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
  288. final_epoch = epoch + 1 == epochs
  289. if not opt.notest or final_epoch: # Calculate mAP
  290. results, maps, times = test.test(opt.data,
  291. batch_size=total_batch_size,
  292. imgsz=imgsz_test,
  293. model=ema.ema,
  294. single_cls=opt.single_cls,
  295. dataloader=testloader,
  296. save_dir=save_dir,
  297. plots=plots and final_epoch,
  298. log_imgs=opt.log_imgs if wandb else 0)
  299. # Write
  300. with open(results_file, 'a') as f:
  301. f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  302. if len(opt.name) and opt.bucket:
  303. os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
  304. # Log
  305. tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
  306. 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
  307. 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
  308. 'x/lr0', 'x/lr1', 'x/lr2'] # params
  309. for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
  310. if tb_writer:
  311. tb_writer.add_scalar(tag, x, epoch) # tensorboard
  312. if wandb:
  313. wandb.log({tag: x}) # W&B
  314. # Update best mAP
  315. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  316. if fi > best_fitness:
  317. best_fitness = fi
  318. # Save model
  319. save = (not opt.nosave) or (final_epoch and not opt.evolve)
  320. if save:
  321. with open(results_file, 'r') as f: # create checkpoint
  322. ckpt = {'epoch': epoch,
  323. 'best_fitness': best_fitness,
  324. 'training_results': f.read(),
  325. 'model': ema.ema,
  326. 'optimizer': None if final_epoch else optimizer.state_dict(),
  327. 'wandb_id': wandb_run.id if wandb else None}
  328. # Save last, best and delete
  329. torch.save(ckpt, last)
  330. if best_fitness == fi:
  331. torch.save(ckpt, best)
  332. del ckpt
  333. # end epoch ----------------------------------------------------------------------------------------------------
  334. # end training
  335. if rank in [-1, 0]:
  336. # Strip optimizers
  337. n = opt.name if opt.name.isnumeric() else ''
  338. fresults, flast, fbest = save_dir / f'results{n}.txt', wdir / f'last{n}.pt', wdir / f'best{n}.pt'
  339. for f1, f2 in zip([wdir / 'last.pt', wdir / 'best.pt', results_file], [flast, fbest, fresults]):
  340. if f1.exists():
  341. os.rename(f1, f2) # rename
  342. if str(f2).endswith('.pt'): # is *.pt
  343. strip_optimizer(f2) # strip optimizer
  344. os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket else None # upload
  345. # Finish
  346. if plots:
  347. plot_results(save_dir=save_dir) # save as results.png
  348. if wandb:
  349. files = ['results.png', 'precision_recall_curve.png', 'confusion_matrix.png']
  350. wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
  351. if (save_dir / f).exists()]})
  352. logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
  353. else:
  354. dist.destroy_process_group()
  355. wandb.run.finish() if wandb and wandb.run else None
  356. torch.cuda.empty_cache()
  357. return results
  358. if __name__ == '__main__':
  359. parser = argparse.ArgumentParser()
  360. parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
  361. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  362. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
  363. parser.add_argument('--hyp', type=str, default='data/hyp.scratch.yaml', help='hyperparameters path')
  364. parser.add_argument('--epochs', type=int, default=300)
  365. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
  366. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
  367. parser.add_argument('--rect', action='store_true', help='rectangular training')
  368. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  369. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  370. parser.add_argument('--notest', action='store_true', help='only test final epoch')
  371. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  372. parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
  373. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  374. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
  375. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  376. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  377. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  378. parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
  379. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  380. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  381. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  382. parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100')
  383. parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
  384. parser.add_argument('--project', default='runs/train', help='save to project/name')
  385. parser.add_argument('--name', default='exp', help='save to project/name')
  386. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  387. opt = parser.parse_args()
  388. # Set DDP variables
  389. opt.total_batch_size = opt.batch_size
  390. opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
  391. opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
  392. set_logging(opt.global_rank)
  393. if opt.global_rank in [-1, 0]:
  394. check_git_status()
  395. # Resume
  396. if opt.resume: # resume an interrupted run
  397. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  398. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  399. with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
  400. opt = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader)) # replace
  401. opt.cfg, opt.weights, opt.resume = '', ckpt, True
  402. logger.info('Resuming training from %s' % ckpt)
  403. else:
  404. # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
  405. opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
  406. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  407. opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
  408. opt.name = 'evolve' if opt.evolve else opt.name
  409. opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve) # increment run
  410. # DDP mode
  411. device = select_device(opt.device, batch_size=opt.batch_size)
  412. if opt.local_rank != -1:
  413. assert torch.cuda.device_count() > opt.local_rank
  414. torch.cuda.set_device(opt.local_rank)
  415. device = torch.device('cuda', opt.local_rank)
  416. dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
  417. assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
  418. opt.batch_size = opt.total_batch_size // opt.world_size
  419. # Hyperparameters
  420. with open(opt.hyp) as f:
  421. hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
  422. if 'box' not in hyp:
  423. warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' %
  424. (opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120'))
  425. hyp['box'] = hyp.pop('giou')
  426. # Train
  427. logger.info(opt)
  428. if not opt.evolve:
  429. tb_writer = None # init loggers
  430. if opt.global_rank in [-1, 0]:
  431. logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/')
  432. tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
  433. train(hyp, opt, device, tb_writer, wandb)
  434. # Evolve hyperparameters (optional)
  435. else:
  436. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  437. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  438. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  439. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  440. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  441. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  442. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  443. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  444. 'box': (1, 0.02, 0.2), # box loss gain
  445. 'cls': (1, 0.2, 4.0), # cls loss gain
  446. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  447. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  448. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  449. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  450. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  451. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  452. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  453. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  454. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  455. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  456. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  457. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  458. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  459. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  460. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  461. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  462. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  463. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  464. 'mixup': (1, 0.0, 1.0)} # image mixup (probability)
  465. assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
  466. opt.notest, opt.nosave = True, True # only test/save final epoch
  467. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  468. yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
  469. if opt.bucket:
  470. os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
  471. for _ in range(300): # generations to evolve
  472. if Path('evolve.txt').exists(): # if evolve.txt exists: select best hyps and mutate
  473. # Select parent(s)
  474. parent = 'single' # parent selection method: 'single' or 'weighted'
  475. x = np.loadtxt('evolve.txt', ndmin=2)
  476. n = min(5, len(x)) # number of previous results to consider
  477. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  478. w = fitness(x) - fitness(x).min() # weights
  479. if parent == 'single' or len(x) == 1:
  480. # x = x[random.randint(0, n - 1)] # random selection
  481. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  482. elif parent == 'weighted':
  483. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  484. # Mutate
  485. mp, s = 0.8, 0.2 # mutation probability, sigma
  486. npr = np.random
  487. npr.seed(int(time.time()))
  488. g = np.array([x[0] for x in meta.values()]) # gains 0-1
  489. ng = len(meta)
  490. v = np.ones(ng)
  491. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  492. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  493. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  494. hyp[k] = float(x[i + 7] * v[i]) # mutate
  495. # Constrain to limits
  496. for k, v in meta.items():
  497. hyp[k] = max(hyp[k], v[1]) # lower limit
  498. hyp[k] = min(hyp[k], v[2]) # upper limit
  499. hyp[k] = round(hyp[k], 5) # significant digits
  500. # Train mutation
  501. results = train(hyp.copy(), opt, device, wandb=wandb)
  502. # Write mutation results
  503. print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
  504. # Plot results
  505. plot_evolution(yaml_file)
  506. print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
  507. f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')