You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

586 lines
30KB

  1. import argparse
  2. import logging
  3. import os
  4. import random
  5. import time
  6. from pathlib import Path
  7. from threading import Thread
  8. from warnings import warn
  9. import math
  10. import numpy as np
  11. import torch.distributed as dist
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. import torch.optim as optim
  15. import torch.optim.lr_scheduler as lr_scheduler
  16. import torch.utils.data
  17. import yaml
  18. from torch.cuda import amp
  19. from torch.nn.parallel import DistributedDataParallel as DDP
  20. from torch.utils.tensorboard import SummaryWriter
  21. from tqdm import tqdm
  22. import test # import test.py to get mAP after each epoch
  23. from models.experimental import attempt_load
  24. from models.yolo import Model
  25. from utils.autoanchor import check_anchors
  26. from utils.datasets import create_dataloader
  27. from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
  28. fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
  29. print_mutation, set_logging
  30. from utils.google_utils import attempt_download
  31. from utils.loss import compute_loss
  32. from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
  33. from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first
  34. logger = logging.getLogger(__name__)
  35. try:
  36. import wandb
  37. except ImportError:
  38. wandb = None
  39. logger.info("Install Weights & Biases for experiment logging via 'pip install wandb' (recommended)")
  40. def train(hyp, opt, device, tb_writer=None, wandb=None):
  41. logger.info(f'Hyperparameters {hyp}')
  42. save_dir, epochs, batch_size, total_batch_size, weights, rank = \
  43. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
  44. # Directories
  45. wdir = save_dir / 'weights'
  46. wdir.mkdir(parents=True, exist_ok=True) # make dir
  47. last = wdir / 'last.pt'
  48. best = wdir / 'best.pt'
  49. results_file = save_dir / 'results.txt'
  50. # Save run settings
  51. with open(save_dir / 'hyp.yaml', 'w') as f:
  52. yaml.dump(hyp, f, sort_keys=False)
  53. with open(save_dir / 'opt.yaml', 'w') as f:
  54. yaml.dump(vars(opt), f, sort_keys=False)
  55. # Configure
  56. plots = not opt.evolve # create plots
  57. cuda = device.type != 'cpu'
  58. init_seeds(2 + rank)
  59. with open(opt.data) as f:
  60. data_dict = yaml.load(f, Loader=yaml.FullLoader) # data dict
  61. with torch_distributed_zero_first(rank):
  62. check_dataset(data_dict) # check
  63. train_path = data_dict['train']
  64. test_path = data_dict['val']
  65. nc, names = (1, ['item']) if opt.single_cls else (int(data_dict['nc']), data_dict['names']) # number classes, names
  66. assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
  67. # Model
  68. pretrained = weights.endswith('.pt')
  69. if pretrained:
  70. with torch_distributed_zero_first(rank):
  71. attempt_download(weights) # download if not found locally
  72. ckpt = torch.load(weights, map_location=device) # load checkpoint
  73. if hyp.get('anchors'):
  74. ckpt['model'].yaml['anchors'] = round(hyp['anchors']) # force autoanchor
  75. model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device) # create
  76. exclude = ['anchor'] if opt.cfg or hyp.get('anchors') else [] # exclude keys
  77. state_dict = ckpt['model'].float().state_dict() # to FP32
  78. state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
  79. model.load_state_dict(state_dict, strict=False) # load
  80. logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
  81. else:
  82. model = Model(opt.cfg, ch=3, nc=nc).to(device) # create
  83. # Freeze
  84. freeze = [] # parameter names to freeze (full or partial)
  85. for k, v in model.named_parameters():
  86. v.requires_grad = True # train all layers
  87. if any(x in k for x in freeze):
  88. print('freezing %s' % k)
  89. v.requires_grad = False
  90. # Optimizer
  91. nbs = 64 # nominal batch size
  92. accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
  93. hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay
  94. pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
  95. for k, v in model.named_modules():
  96. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
  97. pg2.append(v.bias) # biases
  98. if isinstance(v, nn.BatchNorm2d):
  99. pg0.append(v.weight) # no decay
  100. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
  101. pg1.append(v.weight) # apply decay
  102. if opt.adam:
  103. optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  104. else:
  105. optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  106. optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
  107. optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
  108. logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
  109. del pg0, pg1, pg2
  110. # Scheduler https://arxiv.org/pdf/1812.01187.pdf
  111. # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
  112. lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - hyp['lrf']) + hyp['lrf'] # cosine
  113. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
  114. # plot_lr_scheduler(optimizer, scheduler, epochs)
  115. # Logging
  116. if wandb and wandb.run is None:
  117. opt.hyp = hyp # add hyperparameters
  118. wandb_run = wandb.init(config=opt, resume="allow",
  119. project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
  120. name=save_dir.stem,
  121. id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
  122. loggers = {'wandb': wandb} # loggers dict
  123. # Resume
  124. start_epoch, best_fitness = 0, 0.0
  125. if pretrained:
  126. # Optimizer
  127. if ckpt['optimizer'] is not None:
  128. optimizer.load_state_dict(ckpt['optimizer'])
  129. best_fitness = ckpt['best_fitness']
  130. # Results
  131. if ckpt.get('training_results') is not None:
  132. with open(results_file, 'w') as file:
  133. file.write(ckpt['training_results']) # write results.txt
  134. # Epochs
  135. start_epoch = ckpt['epoch'] + 1
  136. if opt.resume:
  137. assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
  138. if epochs < start_epoch:
  139. logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
  140. (weights, ckpt['epoch'], epochs))
  141. epochs += ckpt['epoch'] # finetune additional epochs
  142. del ckpt, state_dict
  143. # Image sizes
  144. gs = int(max(model.stride)) # grid size (max stride)
  145. imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
  146. # DP mode
  147. if cuda and rank == -1 and torch.cuda.device_count() > 1:
  148. model = torch.nn.DataParallel(model)
  149. # SyncBatchNorm
  150. if opt.sync_bn and cuda and rank != -1:
  151. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  152. logger.info('Using SyncBatchNorm()')
  153. # EMA
  154. ema = ModelEMA(model) if rank in [-1, 0] else None
  155. # DDP mode
  156. if cuda and rank != -1:
  157. model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
  158. # Trainloader
  159. dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
  160. hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
  161. world_size=opt.world_size, workers=opt.workers,
  162. image_weights=opt.image_weights)
  163. mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
  164. nb = len(dataloader) # number of batches
  165. assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
  166. # Process 0
  167. if rank in [-1, 0]:
  168. ema.updates = start_epoch * nb // accumulate # set EMA updates
  169. testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, # testloader
  170. hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True,
  171. rank=-1, world_size=opt.world_size, workers=opt.workers, pad=0.5)[0]
  172. if not opt.resume:
  173. labels = np.concatenate(dataset.labels, 0)
  174. c = torch.tensor(labels[:, 0]) # classes
  175. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  176. # model._initialize_biases(cf.to(device))
  177. if plots:
  178. Thread(target=plot_labels, args=(labels, save_dir, loggers), daemon=True).start()
  179. if tb_writer:
  180. tb_writer.add_histogram('classes', c, 0)
  181. # Anchors
  182. if not opt.noautoanchor:
  183. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  184. # Model parameters
  185. hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
  186. model.nc = nc # attach number of classes to model
  187. model.hyp = hyp # attach hyperparameters to model
  188. model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
  189. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
  190. model.names = names
  191. # Start training
  192. t0 = time.time()
  193. nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
  194. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  195. maps = np.zeros(nc) # mAP per class
  196. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  197. scheduler.last_epoch = start_epoch - 1 # do not move
  198. scaler = amp.GradScaler(enabled=cuda)
  199. logger.info('Image sizes %g train, %g test\n'
  200. 'Using %g dataloader workers\nLogging results to %s\n'
  201. 'Starting training for %g epochs...' % (imgsz, imgsz_test, dataloader.num_workers, save_dir, epochs))
  202. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  203. model.train()
  204. # Update image weights (optional)
  205. if opt.image_weights:
  206. # Generate indices
  207. if rank in [-1, 0]:
  208. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
  209. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  210. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  211. # Broadcast if DDP
  212. if rank != -1:
  213. indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
  214. dist.broadcast(indices, 0)
  215. if rank != 0:
  216. dataset.indices = indices.cpu().numpy()
  217. # Update mosaic border
  218. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  219. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  220. mloss = torch.zeros(4, device=device) # mean losses
  221. if rank != -1:
  222. dataloader.sampler.set_epoch(epoch)
  223. pbar = enumerate(dataloader)
  224. logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'targets', 'img_size'))
  225. if rank in [-1, 0]:
  226. pbar = tqdm(pbar, total=nb) # progress bar
  227. optimizer.zero_grad()
  228. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  229. ni = i + nb * epoch # number integrated batches (since train start)
  230. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  231. # Warmup
  232. if ni <= nw:
  233. xi = [0, nw] # x interp
  234. # model.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  235. accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
  236. for j, x in enumerate(optimizer.param_groups):
  237. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  238. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  239. if 'momentum' in x:
  240. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  241. # Multi-scale
  242. if opt.multi_scale:
  243. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  244. sf = sz / max(imgs.shape[2:]) # scale factor
  245. if sf != 1:
  246. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  247. imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  248. # Forward
  249. with amp.autocast(enabled=cuda):
  250. pred = model(imgs) # forward
  251. loss, loss_items = compute_loss(pred, targets.to(device), model) # loss scaled by batch_size
  252. if rank != -1:
  253. loss *= opt.world_size # gradient averaged between devices in DDP mode
  254. # Backward
  255. scaler.scale(loss).backward()
  256. # Optimize
  257. if ni % accumulate == 0:
  258. scaler.step(optimizer) # optimizer.step
  259. scaler.update()
  260. optimizer.zero_grad()
  261. if ema:
  262. ema.update(model)
  263. # Print
  264. if rank in [-1, 0]:
  265. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  266. mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
  267. s = ('%10s' * 2 + '%10.4g' * 6) % (
  268. '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
  269. pbar.set_description(s)
  270. # Plot
  271. if plots and ni < 3:
  272. f = save_dir / f'train_batch{ni}.jpg' # filename
  273. Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
  274. # if tb_writer:
  275. # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
  276. # tb_writer.add_graph(model, imgs) # add model to tensorboard
  277. elif plots and ni == 3 and wandb:
  278. wandb.log({"Mosaics": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('train*.jpg')]})
  279. # end batch ------------------------------------------------------------------------------------------------
  280. # end epoch ----------------------------------------------------------------------------------------------------
  281. # Scheduler
  282. lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard
  283. scheduler.step()
  284. # DDP process 0 or single-GPU
  285. if rank in [-1, 0]:
  286. # mAP
  287. if ema:
  288. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
  289. final_epoch = epoch + 1 == epochs
  290. if not opt.notest or final_epoch: # Calculate mAP
  291. results, maps, times = test.test(opt.data,
  292. batch_size=total_batch_size,
  293. imgsz=imgsz_test,
  294. model=ema.ema,
  295. single_cls=opt.single_cls,
  296. dataloader=testloader,
  297. save_dir=save_dir,
  298. plots=plots and final_epoch,
  299. log_imgs=opt.log_imgs if wandb else 0)
  300. # Write
  301. with open(results_file, 'a') as f:
  302. f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  303. if len(opt.name) and opt.bucket:
  304. os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
  305. # Log
  306. tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
  307. 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
  308. 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
  309. 'x/lr0', 'x/lr1', 'x/lr2'] # params
  310. for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
  311. if tb_writer:
  312. tb_writer.add_scalar(tag, x, epoch) # tensorboard
  313. if wandb:
  314. wandb.log({tag: x}) # W&B
  315. # Update best mAP
  316. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  317. if fi > best_fitness:
  318. best_fitness = fi
  319. # Save model
  320. save = (not opt.nosave) or (final_epoch and not opt.evolve)
  321. if save:
  322. with open(results_file, 'r') as f: # create checkpoint
  323. ckpt = {'epoch': epoch,
  324. 'best_fitness': best_fitness,
  325. 'training_results': f.read(),
  326. 'model': ema.ema,
  327. 'optimizer': None if final_epoch else optimizer.state_dict(),
  328. 'wandb_id': wandb_run.id if wandb else None}
  329. # Save last, best and delete
  330. torch.save(ckpt, last)
  331. if best_fitness == fi:
  332. torch.save(ckpt, best)
  333. del ckpt
  334. # end epoch ----------------------------------------------------------------------------------------------------
  335. # end training
  336. if rank in [-1, 0]:
  337. # Strip optimizers
  338. for f in [last, best]:
  339. if f.exists(): # is *.pt
  340. strip_optimizer(f) # strip optimizer
  341. os.system('gsutil cp %s gs://%s/weights' % (f, opt.bucket)) if opt.bucket else None # upload
  342. # Plots
  343. if plots:
  344. plot_results(save_dir=save_dir) # save as results.png
  345. if wandb:
  346. files = ['results.png', 'precision_recall_curve.png', 'confusion_matrix.png']
  347. wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
  348. if (save_dir / f).exists()]})
  349. logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
  350. # Test best.pt
  351. if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
  352. results, _, _ = test.test(opt.data,
  353. batch_size=total_batch_size,
  354. imgsz=imgsz_test,
  355. model=attempt_load(best if best.exists() else last, device).half(),
  356. single_cls=opt.single_cls,
  357. dataloader=testloader,
  358. save_dir=save_dir,
  359. save_json=True, # use pycocotools
  360. plots=False)
  361. else:
  362. dist.destroy_process_group()
  363. wandb.run.finish() if wandb and wandb.run else None
  364. torch.cuda.empty_cache()
  365. return results
  366. if __name__ == '__main__':
  367. parser = argparse.ArgumentParser()
  368. parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
  369. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  370. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
  371. parser.add_argument('--hyp', type=str, default='data/hyp.scratch.yaml', help='hyperparameters path')
  372. parser.add_argument('--epochs', type=int, default=300)
  373. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
  374. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
  375. parser.add_argument('--rect', action='store_true', help='rectangular training')
  376. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  377. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  378. parser.add_argument('--notest', action='store_true', help='only test final epoch')
  379. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  380. parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
  381. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  382. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
  383. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  384. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  385. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  386. parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
  387. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  388. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  389. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  390. parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100')
  391. parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
  392. parser.add_argument('--project', default='runs/train', help='save to project/name')
  393. parser.add_argument('--name', default='exp', help='save to project/name')
  394. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  395. opt = parser.parse_args()
  396. # Set DDP variables
  397. opt.total_batch_size = opt.batch_size
  398. opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
  399. opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
  400. set_logging(opt.global_rank)
  401. if opt.global_rank in [-1, 0]:
  402. check_git_status()
  403. # Resume
  404. if opt.resume: # resume an interrupted run
  405. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  406. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  407. with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
  408. opt = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader)) # replace
  409. opt.cfg, opt.weights, opt.resume = '', ckpt, True
  410. logger.info('Resuming training from %s' % ckpt)
  411. else:
  412. # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
  413. opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
  414. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  415. opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
  416. opt.name = 'evolve' if opt.evolve else opt.name
  417. opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve) # increment run
  418. # DDP mode
  419. device = select_device(opt.device, batch_size=opt.batch_size)
  420. if opt.local_rank != -1:
  421. assert torch.cuda.device_count() > opt.local_rank
  422. torch.cuda.set_device(opt.local_rank)
  423. device = torch.device('cuda', opt.local_rank)
  424. dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
  425. assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
  426. opt.batch_size = opt.total_batch_size // opt.world_size
  427. # Hyperparameters
  428. with open(opt.hyp) as f:
  429. hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
  430. if 'box' not in hyp:
  431. warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' %
  432. (opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120'))
  433. hyp['box'] = hyp.pop('giou')
  434. # Train
  435. logger.info(opt)
  436. if not opt.evolve:
  437. tb_writer = None # init loggers
  438. if opt.global_rank in [-1, 0]:
  439. logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/')
  440. tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
  441. train(hyp, opt, device, tb_writer, wandb)
  442. # Evolve hyperparameters (optional)
  443. else:
  444. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  445. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  446. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  447. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  448. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  449. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  450. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  451. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  452. 'box': (1, 0.02, 0.2), # box loss gain
  453. 'cls': (1, 0.2, 4.0), # cls loss gain
  454. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  455. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  456. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  457. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  458. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  459. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  460. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  461. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  462. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  463. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  464. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  465. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  466. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  467. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  468. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  469. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  470. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  471. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  472. 'mixup': (1, 0.0, 1.0)} # image mixup (probability)
  473. assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
  474. opt.notest, opt.nosave = True, True # only test/save final epoch
  475. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  476. yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
  477. if opt.bucket:
  478. os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
  479. for _ in range(300): # generations to evolve
  480. if Path('evolve.txt').exists(): # if evolve.txt exists: select best hyps and mutate
  481. # Select parent(s)
  482. parent = 'single' # parent selection method: 'single' or 'weighted'
  483. x = np.loadtxt('evolve.txt', ndmin=2)
  484. n = min(5, len(x)) # number of previous results to consider
  485. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  486. w = fitness(x) - fitness(x).min() # weights
  487. if parent == 'single' or len(x) == 1:
  488. # x = x[random.randint(0, n - 1)] # random selection
  489. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  490. elif parent == 'weighted':
  491. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  492. # Mutate
  493. mp, s = 0.8, 0.2 # mutation probability, sigma
  494. npr = np.random
  495. npr.seed(int(time.time()))
  496. g = np.array([x[0] for x in meta.values()]) # gains 0-1
  497. ng = len(meta)
  498. v = np.ones(ng)
  499. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  500. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  501. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  502. hyp[k] = float(x[i + 7] * v[i]) # mutate
  503. # Constrain to limits
  504. for k, v in meta.items():
  505. hyp[k] = max(hyp[k], v[1]) # lower limit
  506. hyp[k] = min(hyp[k], v[2]) # upper limit
  507. hyp[k] = round(hyp[k], 5) # significant digits
  508. # Train mutation
  509. results = train(hyp.copy(), opt, device, wandb=wandb)
  510. # Write mutation results
  511. print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
  512. # Plot results
  513. plot_evolution(yaml_file)
  514. print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
  515. f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')