import argparse
import logging
import math
import os
import random
import time
from pathlib import Path
from warnings import warn

import numpy as np
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
import yaml
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import test  # import test.py to get mAP after each epoch
from models.yolo import Model
from utils.autoanchor import check_anchors
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
    fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
    print_mutation, set_logging
from utils.google_utils import attempt_download
from utils.loss import compute_loss
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first

logger = logging.getLogger(__name__)
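# Note: there is no bare `import torch` above; `import torch.utils.data` also binds the
# top-level `torch` name in this module, so later calls such as torch.load() still resolve.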


def train(hyp, opt, device, tb_writer=None, wandb=None):
    logger.info(f'Hyperparameters {hyp}')
    save_dir, epochs, batch_size, total_batch_size, weights, rank = \
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank

    # Directories
    wdir = save_dir / 'weights'
    wdir.mkdir(parents=True, exist_ok=True)  # make dir
    last = wdir / 'last.pt'
    best = wdir / 'best.pt'
    results_file = save_dir / 'results.txt'

    # Save run settings
    with open(save_dir / 'hyp.yaml', 'w') as f:
        yaml.dump(hyp, f, sort_keys=False)
    with open(save_dir / 'opt.yaml', 'w') as f:
        yaml.dump(vars(opt), f, sort_keys=False)

    # Configure
    cuda = device.type != 'cpu'
    init_seeds(2 + rank)
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.FullLoader)  # data dict
    with torch_distributed_zero_first(rank):
        check_dataset(data_dict)  # check
    train_path = data_dict['train']
    test_path = data_dict['val']
    nc, names = (1, ['item']) if opt.single_cls else (int(data_dict['nc']), data_dict['names'])  # number classes, names
    assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data)  # check

    # Model
    pretrained = weights.endswith('.pt')
    if pretrained:
        with torch_distributed_zero_first(rank):
            attempt_download(weights)  # download if not found locally
        ckpt = torch.load(weights, map_location=device)  # load checkpoint
        if hyp.get('anchors'):
            ckpt['model'].yaml['anchors'] = round(hyp['anchors'])  # force autoanchor
        model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device)  # create
        exclude = ['anchor'] if opt.cfg or hyp.get('anchors') else []  # exclude keys
        state_dict = ckpt['model'].float().state_dict()  # to FP32
        state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
        model.load_state_dict(state_dict, strict=False)  # load
        logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights))  # report
    else:
        model = Model(opt.cfg, ch=3, nc=nc).to(device)  # create

    # Freeze
    freeze = []  # parameter names to freeze (full or partial)
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        if any(x in k for x in freeze):
            print('freezing %s' % k)
            v.requires_grad = False

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / total_batch_size), 1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= total_batch_size * accumulate / nbs  # scale weight_decay
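    # Illustrative arithmetic, assuming the default nbs=64: with total_batch_size=16,
    # accumulate = round(64 / 16) = 4 and weight_decay is scaled by 16 * 4 / 64 = 1.0;
    # with total_batch_size=24, accumulate = 3 and the scale factor is 24 * 3 / 64 = 1.125.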

    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
    for k, v in model.named_modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
            pg2.append(v.bias)  # biases
        if isinstance(v, nn.BatchNorm2d):
            pg0.append(v.weight)  # no decay
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
            pg1.append(v.weight)  # apply decay

    if opt.adam:
        optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
    else:
        optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)

    optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
    logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
    lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - hyp['lrf']) + hyp['lrf']  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    # plot_lr_scheduler(optimizer, scheduler, epochs)
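    # The cosine lambda above gives lf(0) = 1.0 and lf(epochs) = hyp['lrf'], so each group's LR
    # decays from its initial value to initial_lr * lrf over one half-cosine across training.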

    # Logging
    if wandb and wandb.run is None:
        wandb_run = wandb.init(config=opt, resume="allow",
                               project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
                               name=save_dir.stem,
                               id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        # Optimizer
        if ckpt['optimizer'] is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
            best_fitness = ckpt['best_fitness']

        # Results
        if ckpt.get('training_results') is not None:
            with open(results_file, 'w') as file:
                file.write(ckpt['training_results'])  # write results.txt

        # Epochs
        start_epoch = ckpt['epoch'] + 1
        if opt.resume:
            assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
        if epochs < start_epoch:
            logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
                        (weights, ckpt['epoch'], epochs))
            epochs += ckpt['epoch']  # finetune additional epochs

        del ckpt, state_dict

    # Image sizes
    gs = int(max(model.stride))  # grid size (max stride)
    imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size]  # verify imgsz are gs-multiples

    # DP mode
    if cuda and rank == -1 and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and rank != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        logger.info('Using SyncBatchNorm()')

    # Exponential moving average
    ema = ModelEMA(model) if rank in [-1, 0] else None
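    # ModelEMA maintains an exponential moving average of the model weights on rank -1/0 only;
    # the averaged copy (ema.ema) is what test.test() evaluates and what gets checkpointed below.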

    # DDP mode
    if cuda and rank != -1:
        model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)

    # Trainloader
    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect,
                                            rank=rank, world_size=opt.world_size, workers=opt.workers)
    mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
    nb = len(dataloader)  # number of batches
    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)

    # Process 0
    if rank in [-1, 0]:
        ema.updates = start_epoch * nb // accumulate  # set EMA updates
        testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt,
                                       hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True,
                                       rank=-1, world_size=opt.world_size, workers=opt.workers)[0]  # testloader

        if not opt.resume:
            labels = np.concatenate(dataset.labels, 0)
            c = torch.tensor(labels[:, 0])  # classes
            # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
            # model._initialize_biases(cf.to(device))
            plot_labels(labels, save_dir=save_dir)
            if tb_writer:
                # tb_writer.add_hparams(hyp, {})  # causes duplicate https://github.com/ultralytics/yolov5/pull/384
                tb_writer.add_histogram('classes', c, 0)

            # Anchors
            if not opt.noautoanchor:
                check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)

    # Model parameters
    hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.gr = 1.0  # iou loss ratio (obj_loss = 1.0 or iou)
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    nw = max(round(hyp['warmup_epochs'] * nb), 1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
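    # Example: with hyp['warmup_epochs'] = 3 and nb = 100 batches per epoch, nw = max(300, 1000) = 1000,
    # i.e. warmup would span the first 1000 iterations (10 epochs at this dataset size).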
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
    scheduler.last_epoch = start_epoch - 1  # do not move
    scaler = amp.GradScaler(enabled=cuda)
    logger.info('Image sizes %g train, %g test\n'
                'Using %g dataloader workers\nLogging results to %s\n'
                'Starting training for %g epochs...' % (imgsz, imgsz_test, dataloader.num_workers, save_dir, epochs))
    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
        model.train()

        # Update image weights (optional)
        if opt.image_weights:
            # Generate indices
            if rank in [-1, 0]:
                cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2  # class weights
                iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
                dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx
            # Broadcast if DDP
            if rank != -1:
                indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
                dist.broadcast(indices, 0)
                if rank != 0:
                    dataset.indices = indices.cpu().numpy()

        # Update mosaic border
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = torch.zeros(4, device=device)  # mean losses
        if rank != -1:
            dataloader.sampler.set_epoch(epoch)
        pbar = enumerate(dataloader)
        logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'targets', 'img_size'))
        if rank in [-1, 0]:
            pbar = tqdm(pbar, total=nb)  # progress bar
        optimizer.zero_grad()
        for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                    imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
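            # Example: with imgsz=640 and gs=32, sz is drawn from roughly [320, 960] and snapped to a
            # multiple of 32 before the batch is resized by F.interpolate.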

            # Forward
            with amp.autocast(enabled=cuda):
                pred = model(imgs)  # forward
                loss, loss_items = compute_loss(pred, targets.to(device), model)  # loss scaled by batch_size
                if rank != -1:
                    loss *= opt.world_size  # gradient averaged between devices in DDP mode

            # Backward
            scaler.scale(loss).backward()

            # Optimize
            if ni % accumulate == 0:
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)
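            # Gradients accumulate for `accumulate` batches between optimizer steps, so the effective
            # batch size approximates the nominal nbs (64) regardless of the per-step batch size.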

            # Print
            if rank in [-1, 0]:
                mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
                s = ('%10s' * 2 + '%10.4g' * 6) % (
                    '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
                pbar.set_description(s)

                # Plot
                if ni < 3:
                    f = str(save_dir / f'train_batch{ni}.jpg')  # filename
                    result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
                    # if tb_writer and result is not None:
                    #     tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
                    #     tb_writer.add_graph(model, imgs)  # add model to tensorboard

            # end batch ------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
        scheduler.step()

        # DDP process 0 or single-GPU
        if rank in [-1, 0]:
            # mAP
            if ema:
                ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
            final_epoch = epoch + 1 == epochs
            if not opt.notest or final_epoch:  # Calculate mAP
                results, maps, times = test.test(opt.data,
                                                 batch_size=total_batch_size,
                                                 imgsz=imgsz_test,
                                                 model=ema.ema,
                                                 single_cls=opt.single_cls,
                                                 dataloader=testloader,
                                                 save_dir=save_dir,
                                                 plots=epoch == 0 or final_epoch,  # plot first and last
                                                 log_imgs=opt.log_imgs if wandb else 0)

            # Write
            with open(results_file, 'a') as f:
                f.write(s + '%10.4g' * 7 % results + '\n')  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
            if len(opt.name) and opt.bucket:
                os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))

            # Log
            tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
                    'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
                    'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
                    'x/lr0', 'x/lr1', 'x/lr2']  # params
            for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                if tb_writer:
                    tb_writer.add_scalar(tag, x, epoch)  # tensorboard
                if wandb:
                    wandb.log({tag: x})  # W&B

            # Update best mAP
            fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
            if fi > best_fitness:
                best_fitness = fi
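            # fitness() (utils.general) collapses results into one scalar; in this codebase it is a
            # weighted sum dominated by mAP@.5:.95 (see utils/general.py for the exact weights).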

            # Save model
            save = (not opt.nosave) or (final_epoch and not opt.evolve)
            if save:
                with open(results_file, 'r') as f:  # create checkpoint
                    ckpt = {'epoch': epoch,
                            'best_fitness': best_fitness,
                            'training_results': f.read(),
                            'model': ema.ema,
                            'optimizer': None if final_epoch else optimizer.state_dict(),
                            'wandb_id': wandb_run.id if wandb else None}

                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fi:
                    torch.save(ckpt, best)
                del ckpt
        # end epoch ----------------------------------------------------------------------------------------------------
    # end training

    if rank in [-1, 0]:
        # Strip optimizers
        n = opt.name if opt.name.isnumeric() else ''
        fresults, flast, fbest = save_dir / f'results{n}.txt', wdir / f'last{n}.pt', wdir / f'best{n}.pt'
        for f1, f2 in zip([wdir / 'last.pt', wdir / 'best.pt', results_file], [flast, fbest, fresults]):
            if f1.exists():
                os.rename(f1, f2)  # rename
                if str(f2).endswith('.pt'):  # is *.pt
                    strip_optimizer(f2)  # strip optimizer
                    os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket else None  # upload
        # Finish
        if not opt.evolve:
            plot_results(save_dir=save_dir)  # save as results.png
        logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))

    dist.destroy_process_group() if rank not in [-1, 0] else None
    torch.cuda.empty_cache()
    return results


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
    parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
    parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
    parser.add_argument('--hyp', type=str, default='data/hyp.scratch.yaml', help='hyperparameters path')
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
    parser.add_argument('--rect', action='store_true', help='rectangular training')
    parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    parser.add_argument('--notest', action='store_true', help='only test final epoch')
    parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
    parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
    parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
    parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
    parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
    parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
    parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
    parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
    parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
    parser.add_argument('--log-imgs', type=int, default=10, help='number of images for W&B logging, max 100')
    parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
    parser.add_argument('--project', default='runs/train', help='save to project/name')
    parser.add_argument('--name', default='exp', help='save to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    opt = parser.parse_args()
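    # Typical invocations (illustrative, not exhaustive):
    #   single GPU / CPU: python train.py --data coco128.yaml --cfg yolov5s.yaml --weights yolov5s.pt --batch-size 16
    #   DDP (2 GPUs):     python -m torch.distributed.launch --nproc_per_node 2 train.py --batch-size 64 --data coco.yaml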

    # Set DDP variables
    opt.total_batch_size = opt.batch_size
    opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
    set_logging(opt.global_rank)
    if opt.global_rank in [-1, 0]:
        check_git_status()

    # Resume
    if opt.resume:  # resume an interrupted run
        ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
        assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
        with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
            opt = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader))  # replace
        opt.cfg, opt.weights, opt.resume = '', ckpt, True
        logger.info('Resuming training from %s' % ckpt)
    else:
        # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
        opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp)  # check files
        assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
        opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
        opt.name = 'evolve' if opt.evolve else opt.name
        opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)  # increment run

    # DDP mode
    device = select_device(opt.device, batch_size=opt.batch_size)
    if opt.local_rank != -1:
        assert torch.cuda.device_count() > opt.local_rank
        torch.cuda.set_device(opt.local_rank)
        device = torch.device('cuda', opt.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend
        assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
        opt.batch_size = opt.total_batch_size // opt.world_size
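        # e.g. --batch-size 64 on a 4-GPU DDP launch keeps total_batch_size=64 and gives each process batch_size=16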

    # Hyperparameters
    with open(opt.hyp) as f:
        hyp = yaml.load(f, Loader=yaml.FullLoader)  # load hyps
        if 'box' not in hyp:
            warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' %
                 (opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120'))
            hyp['box'] = hyp.pop('giou')

    # Train
    logger.info(opt)
    if not opt.evolve:
        tb_writer, wandb = None, None  # init loggers
        if opt.global_rank in [-1, 0]:
            # Tensorboard
            logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/')
            tb_writer = SummaryWriter(opt.save_dir)  # runs/train/exp

            # W&B
            try:
                import wandb
                assert os.environ.get('WANDB_DISABLED') != 'true'
            except (ImportError, AssertionError):
                logger.info("Install Weights & Biases for experiment logging via 'pip install wandb' (recommended)")

        train(hyp, opt, device, tb_writer, wandb)

    # Evolve hyperparameters (optional)
    else:
        # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
        meta = {'lr0': (1, 1e-5, 1e-1),  # initial learning rate (SGD=1E-2, Adam=1E-3)
                'lrf': (1, 0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
                'momentum': (0.3, 0.6, 0.98),  # SGD momentum/Adam beta1
                'weight_decay': (1, 0.0, 0.001),  # optimizer weight decay
                'warmup_epochs': (1, 0.0, 5.0),  # warmup epochs (fractions ok)
                'warmup_momentum': (1, 0.0, 0.95),  # warmup initial momentum
                'warmup_bias_lr': (1, 0.0, 0.2),  # warmup initial bias lr
                'box': (1, 0.02, 0.2),  # box loss gain
                'cls': (1, 0.2, 4.0),  # cls loss gain
                'cls_pw': (1, 0.5, 2.0),  # cls BCELoss positive_weight
                'obj': (1, 0.2, 4.0),  # obj loss gain (scale with pixels)
                'obj_pw': (1, 0.5, 2.0),  # obj BCELoss positive_weight
                'iou_t': (0, 0.1, 0.7),  # IoU training threshold
                'anchor_t': (1, 2.0, 8.0),  # anchor-multiple threshold
                'anchors': (2, 2.0, 10.0),  # anchors per output grid (0 to ignore)
                'fl_gamma': (0, 0.0, 2.0),  # focal loss gamma (efficientDet default gamma=1.5)
                'hsv_h': (1, 0.0, 0.1),  # image HSV-Hue augmentation (fraction)
                'hsv_s': (1, 0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
                'hsv_v': (1, 0.0, 0.9),  # image HSV-Value augmentation (fraction)
                'degrees': (1, 0.0, 45.0),  # image rotation (+/- deg)
                'translate': (1, 0.0, 0.9),  # image translation (+/- fraction)
                'scale': (1, 0.0, 0.9),  # image scale (+/- gain)
                'shear': (1, 0.0, 10.0),  # image shear (+/- deg)
                'perspective': (0, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
                'flipud': (1, 0.0, 1.0),  # image flip up-down (probability)
                'fliplr': (0, 0.0, 1.0),  # image flip left-right (probability)
                'mosaic': (1, 0.0, 1.0),  # image mosaic (probability)
                'mixup': (1, 0.0, 1.0)}  # image mixup (probability)

        assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
        opt.notest, opt.nosave = True, True  # only test/save final epoch
        # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
        yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml'  # save best result here
        if opt.bucket:
            os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket)  # download evolve.txt if exists

        for _ in range(300):  # generations to evolve
            if Path('evolve.txt').exists():  # if evolve.txt exists: select best hyps and mutate
                # Select parent(s)
                parent = 'single'  # parent selection method: 'single' or 'weighted'
                x = np.loadtxt('evolve.txt', ndmin=2)
                n = min(5, len(x))  # number of previous results to consider
                x = x[np.argsort(-fitness(x))][:n]  # top n mutations
                w = fitness(x) - fitness(x).min()  # weights
                if parent == 'single' or len(x) == 1:
                    # x = x[random.randint(0, n - 1)]  # random selection
                    x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
                elif parent == 'weighted':
                    x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

                # Mutate
                mp, s = 0.8, 0.2  # mutation probability, sigma
                npr = np.random
                npr.seed(int(time.time()))
                g = np.array([x[0] for x in meta.values()])  # gains 0-1
                ng = len(meta)
                v = np.ones(ng)
                while all(v == 1):  # mutate until a change occurs (prevent duplicates)
                    v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
                for i, k in enumerate(hyp.keys()):  # plt.hist(v.ravel(), 300)
                    hyp[k] = float(x[i + 7] * v[i])  # mutate
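                # Assumed layout of evolve.txt (written by print_mutation): 7 result columns
                # (P, R, mAP@.5, mAP@.5:.95, val box/obj/cls loss) followed by the hyperparameter values,
                # hence the x[i + 7] offset when reading a parent row above.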

            # Constrain to limits
            for k, v in meta.items():
                hyp[k] = max(hyp[k], v[1])  # lower limit
                hyp[k] = min(hyp[k], v[2])  # upper limit
                hyp[k] = round(hyp[k], 5)  # significant digits

            # Train mutation
            results = train(hyp.copy(), opt, device)

            # Write mutation results
            print_mutation(hyp.copy(), results, yaml_file, opt.bucket)

        # Plot results
        plot_evolution(yaml_file)
        print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
              f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')