No puede seleccionar más de 25 temas. Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.

564 líneas
29KB

  1. import argparse
  2. import logging
  3. import math
  4. import os
  5. import random
  6. import shutil
  7. import time
  8. from pathlib import Path
  9. from warnings import warn
  10. import numpy as np
  11. import torch.distributed as dist
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. import torch.optim as optim
  15. import torch.optim.lr_scheduler as lr_scheduler
  16. import torch.utils.data
  17. import yaml
  18. from torch.cuda import amp
  19. from torch.nn.parallel import DistributedDataParallel as DDP
  20. from torch.utils.tensorboard import SummaryWriter
  21. from tqdm import tqdm
  22. import test # import test.py to get mAP after each epoch
  23. from models.yolo import Model
  24. from utils.datasets import create_dataloader
  25. from utils.general import (
  26. torch_distributed_zero_first, labels_to_class_weights, plot_labels, check_anchors, labels_to_image_weights,
  27. compute_loss, plot_images, fitness, strip_optimizer, plot_results, get_latest_run, check_dataset, check_file,
  28. check_git_status, check_img_size, increment_dir, print_mutation, plot_evolution, set_logging, init_seeds)
  29. from utils.google_utils import attempt_download
  30. from utils.torch_utils import ModelEMA, select_device, intersect_dicts
  31. logger = logging.getLogger(__name__)
  32. def train(hyp, opt, device, tb_writer=None, wandb=None):
  33. logger.info(f'Hyperparameters {hyp}')
  34. log_dir = Path(tb_writer.log_dir) if tb_writer else Path(opt.logdir) / 'evolve' # logging directory
  35. wdir = log_dir / 'weights' # weights directory
  36. wdir.mkdir(parents=True, exist_ok=True)
  37. last = wdir / 'last.pt'
  38. best = wdir / 'best.pt'
  39. results_file = log_dir / 'results.txt'
  40. epochs, batch_size, total_batch_size, weights, rank = \
  41. opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
  42. # Save run settings
  43. with open(log_dir / 'hyp.yaml', 'w') as f:
  44. yaml.dump(hyp, f, sort_keys=False)
  45. with open(log_dir / 'opt.yaml', 'w') as f:
  46. yaml.dump(vars(opt), f, sort_keys=False)
  47. # Configure
  48. cuda = device.type != 'cpu'
  49. init_seeds(2 + rank)
  50. with open(opt.data) as f:
  51. data_dict = yaml.load(f, Loader=yaml.FullLoader) # data dict
  52. with torch_distributed_zero_first(rank):
  53. check_dataset(data_dict) # check
  54. train_path = data_dict['train']
  55. test_path = data_dict['val']
  56. nc, names = (1, ['item']) if opt.single_cls else (int(data_dict['nc']), data_dict['names']) # number classes, names
  57. assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
  58. # Model
  59. pretrained = weights.endswith('.pt')
  60. if pretrained:
  61. with torch_distributed_zero_first(rank):
  62. attempt_download(weights) # download if not found locally
  63. ckpt = torch.load(weights, map_location=device) # load checkpoint
  64. if hyp.get('anchors'):
  65. ckpt['model'].yaml['anchors'] = round(hyp['anchors']) # force autoanchor
  66. model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device) # create
  67. exclude = ['anchor'] if opt.cfg or hyp.get('anchors') else [] # exclude keys
  68. state_dict = ckpt['model'].float().state_dict() # to FP32
  69. state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
  70. model.load_state_dict(state_dict, strict=False) # load
  71. logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
  72. else:
  73. model = Model(opt.cfg, ch=3, nc=nc).to(device) # create
  74. # Freeze
  75. freeze = [] # parameter names to freeze (full or partial)
  76. for k, v in model.named_parameters():
  77. v.requires_grad = True # train all layers
  78. if any(x in k for x in freeze):
  79. print('freezing %s' % k)
  80. v.requires_grad = False
  81. # Optimizer
  82. nbs = 64 # nominal batch size
  83. accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
  84. hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay
  85. pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
  86. for k, v in model.named_modules():
  87. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
  88. pg2.append(v.bias) # biases
  89. if isinstance(v, nn.BatchNorm2d):
  90. pg0.append(v.weight) # no decay
  91. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
  92. pg1.append(v.weight) # apply decay
  93. if opt.adam:
  94. optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  95. else:
  96. optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  97. optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
  98. optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
  99. logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
  100. del pg0, pg1, pg2
  101. # Scheduler https://arxiv.org/pdf/1812.01187.pdf
  102. # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
  103. lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - hyp['lrf']) + hyp['lrf'] # cosine
  104. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
  105. # plot_lr_scheduler(optimizer, scheduler, epochs)
  106. # Logging
  107. if wandb and wandb.run is None:
  108. id = ckpt.get('wandb_id') if 'ckpt' in locals() else None
  109. wandb_run = wandb.init(config=opt, resume="allow", project="YOLOv5", name=log_dir.stem, id=id)
  110. # Resume
  111. start_epoch, best_fitness = 0, 0.0
  112. if pretrained:
  113. # Optimizer
  114. if ckpt['optimizer'] is not None:
  115. optimizer.load_state_dict(ckpt['optimizer'])
  116. best_fitness = ckpt['best_fitness']
  117. # Results
  118. if ckpt.get('training_results') is not None:
  119. with open(results_file, 'w') as file:
  120. file.write(ckpt['training_results']) # write results.txt
  121. # Epochs
  122. start_epoch = ckpt['epoch'] + 1
  123. if opt.resume:
  124. assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
  125. shutil.copytree(wdir, wdir.parent / f'weights_backup_epoch{start_epoch - 1}') # save previous weights
  126. if epochs < start_epoch:
  127. logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
  128. (weights, ckpt['epoch'], epochs))
  129. epochs += ckpt['epoch'] # finetune additional epochs
  130. del ckpt, state_dict
  131. # Image sizes
  132. gs = int(max(model.stride)) # grid size (max stride)
  133. imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
  134. # DP mode
  135. if cuda and rank == -1 and torch.cuda.device_count() > 1:
  136. model = torch.nn.DataParallel(model)
  137. # SyncBatchNorm
  138. if opt.sync_bn and cuda and rank != -1:
  139. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  140. logger.info('Using SyncBatchNorm()')
  141. # Exponential moving average
  142. ema = ModelEMA(model) if rank in [-1, 0] else None
  143. # DDP mode
  144. if cuda and rank != -1:
  145. model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
  146. # Trainloader
  147. dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
  148. hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect,
  149. rank=rank, world_size=opt.world_size, workers=opt.workers)
  150. mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
  151. nb = len(dataloader) # number of batches
  152. assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
  153. # Process 0
  154. if rank in [-1, 0]:
  155. ema.updates = start_epoch * nb // accumulate # set EMA updates
  156. testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt,
  157. hyp=hyp, augment=False, cache=opt.cache_images and not opt.notest, rect=True,
  158. rank=-1, world_size=opt.world_size, workers=opt.workers)[0] # testloader
  159. if not opt.resume:
  160. labels = np.concatenate(dataset.labels, 0)
  161. c = torch.tensor(labels[:, 0]) # classes
  162. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  163. # model._initialize_biases(cf.to(device))
  164. plot_labels(labels, save_dir=log_dir)
  165. if tb_writer:
  166. # tb_writer.add_hparams(hyp, {}) # causes duplicate https://github.com/ultralytics/yolov5/pull/384
  167. tb_writer.add_histogram('classes', c, 0)
  168. # Anchors
  169. if not opt.noautoanchor:
  170. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  171. # Model parameters
  172. hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
  173. model.nc = nc # attach number of classes to model
  174. model.hyp = hyp # attach hyperparameters to model
  175. model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
  176. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
  177. model.names = names
  178. # Start training
  179. t0 = time.time()
  180. nw = max(round(hyp['warmup_epochs'] * nb), 1e3) # number of warmup iterations, max(3 epochs, 1k iterations)
  181. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  182. maps = np.zeros(nc) # mAP per class
  183. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  184. scheduler.last_epoch = start_epoch - 1 # do not move
  185. scaler = amp.GradScaler(enabled=cuda)
  186. logger.info('Image sizes %g train, %g test\n'
  187. 'Using %g dataloader workers\nLogging results to %s\n'
  188. 'Starting training for %g epochs...' % (imgsz, imgsz_test, dataloader.num_workers, log_dir, epochs))
  189. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  190. model.train()
  191. # Update image weights (optional)
  192. if opt.image_weights:
  193. # Generate indices
  194. if rank in [-1, 0]:
  195. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
  196. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  197. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  198. # Broadcast if DDP
  199. if rank != -1:
  200. indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
  201. dist.broadcast(indices, 0)
  202. if rank != 0:
  203. dataset.indices = indices.cpu().numpy()
  204. # Update mosaic border
  205. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  206. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  207. mloss = torch.zeros(4, device=device) # mean losses
  208. if rank != -1:
  209. dataloader.sampler.set_epoch(epoch)
  210. pbar = enumerate(dataloader)
  211. logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'targets', 'img_size'))
  212. if rank in [-1, 0]:
  213. pbar = tqdm(pbar, total=nb) # progress bar
  214. optimizer.zero_grad()
  215. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  216. ni = i + nb * epoch # number integrated batches (since train start)
  217. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  218. # Warmup
  219. if ni <= nw:
  220. xi = [0, nw] # x interp
  221. # model.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  222. accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
  223. for j, x in enumerate(optimizer.param_groups):
  224. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  225. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  226. if 'momentum' in x:
  227. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  228. # Multi-scale
  229. if opt.multi_scale:
  230. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  231. sf = sz / max(imgs.shape[2:]) # scale factor
  232. if sf != 1:
  233. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  234. imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  235. # Forward
  236. with amp.autocast(enabled=cuda):
  237. pred = model(imgs) # forward
  238. loss, loss_items = compute_loss(pred, targets.to(device), model) # loss scaled by batch_size
  239. if rank != -1:
  240. loss *= opt.world_size # gradient averaged between devices in DDP mode
  241. # Backward
  242. scaler.scale(loss).backward()
  243. # Optimize
  244. if ni % accumulate == 0:
  245. scaler.step(optimizer) # optimizer.step
  246. scaler.update()
  247. optimizer.zero_grad()
  248. if ema:
  249. ema.update(model)
  250. # Print
  251. if rank in [-1, 0]:
  252. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  253. mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
  254. s = ('%10s' * 2 + '%10.4g' * 6) % (
  255. '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
  256. pbar.set_description(s)
  257. # Plot
  258. if ni < 3:
  259. f = str(log_dir / f'train_batch{ni}.jpg') # filename
  260. result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
  261. # if tb_writer and result is not None:
  262. # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
  263. # tb_writer.add_graph(model, imgs) # add model to tensorboard
  264. # end batch ------------------------------------------------------------------------------------------------
  265. # Scheduler
  266. lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard
  267. scheduler.step()
  268. # DDP process 0 or single-GPU
  269. if rank in [-1, 0]:
  270. # mAP
  271. if ema:
  272. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
  273. final_epoch = epoch + 1 == epochs
  274. if not opt.notest or final_epoch: # Calculate mAP
  275. results, maps, times = test.test(opt.data,
  276. batch_size=total_batch_size,
  277. imgsz=imgsz_test,
  278. model=ema.ema,
  279. single_cls=opt.single_cls,
  280. dataloader=testloader,
  281. save_dir=log_dir,
  282. plots=epoch == 0 or final_epoch, # plot first and last
  283. log_imgs=opt.log_imgs)
  284. # Write
  285. with open(results_file, 'a') as f:
  286. f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  287. if len(opt.name) and opt.bucket:
  288. os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
  289. # Log
  290. tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss', # train loss
  291. 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
  292. 'val/giou_loss', 'val/obj_loss', 'val/cls_loss', # val loss
  293. 'x/lr0', 'x/lr1', 'x/lr2'] # params
  294. for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
  295. if tb_writer:
  296. tb_writer.add_scalar(tag, x, epoch) # tensorboard
  297. if wandb:
  298. wandb.log({tag: x}) # W&B
  299. # Update best mAP
  300. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  301. if fi > best_fitness:
  302. best_fitness = fi
  303. # Save model
  304. save = (not opt.nosave) or (final_epoch and not opt.evolve)
  305. if save:
  306. with open(results_file, 'r') as f: # create checkpoint
  307. ckpt = {'epoch': epoch,
  308. 'best_fitness': best_fitness,
  309. 'training_results': f.read(),
  310. 'model': ema.ema,
  311. 'optimizer': None if final_epoch else optimizer.state_dict(),
  312. 'wandb_id': wandb_run.id if wandb else None}
  313. # Save last, best and delete
  314. torch.save(ckpt, last)
  315. if best_fitness == fi:
  316. torch.save(ckpt, best)
  317. del ckpt
  318. # end epoch ----------------------------------------------------------------------------------------------------
  319. # end training
  320. if rank in [-1, 0]:
  321. # Strip optimizers
  322. n = opt.name if opt.name.isnumeric() else ''
  323. fresults, flast, fbest = log_dir / f'results{n}.txt', wdir / f'last{n}.pt', wdir / f'best{n}.pt'
  324. for f1, f2 in zip([wdir / 'last.pt', wdir / 'best.pt', results_file], [flast, fbest, fresults]):
  325. if f1.exists():
  326. os.rename(f1, f2) # rename
  327. if str(f2).endswith('.pt'): # is *.pt
  328. strip_optimizer(f2) # strip optimizer
  329. os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket else None # upload
  330. # Finish
  331. if not opt.evolve:
  332. plot_results(save_dir=log_dir) # save as results.png
  333. logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
  334. dist.destroy_process_group() if rank not in [-1, 0] else None
  335. torch.cuda.empty_cache()
  336. return results
  337. if __name__ == '__main__':
  338. parser = argparse.ArgumentParser()
  339. parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
  340. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  341. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
  342. parser.add_argument('--hyp', type=str, default='data/hyp.scratch.yaml', help='hyperparameters path')
  343. parser.add_argument('--epochs', type=int, default=300)
  344. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
  345. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
  346. parser.add_argument('--rect', action='store_true', help='rectangular training')
  347. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  348. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  349. parser.add_argument('--notest', action='store_true', help='only test final epoch')
  350. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  351. parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
  352. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  353. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
  354. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  355. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  356. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  357. parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
  358. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  359. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  360. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  361. parser.add_argument('--logdir', type=str, default='runs/train', help='logging directory')
  362. parser.add_argument('--name', default='', help='name to append to --save-dir: i.e. runs/{N} -> runs/{N}_{name}')
  363. parser.add_argument('--log-imgs', type=int, default=10, help='number of images for W&B logging, max 100')
  364. parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
  365. opt = parser.parse_args()
  366. # Set DDP variables
  367. opt.total_batch_size = opt.batch_size
  368. opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
  369. opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
  370. set_logging(opt.global_rank)
  371. if opt.global_rank in [-1, 0]:
  372. check_git_status()
  373. # Resume
  374. if opt.resume: # resume an interrupted run
  375. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  376. log_dir = Path(ckpt).parent.parent # runs/train/exp0
  377. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  378. with open(log_dir / 'opt.yaml') as f:
  379. opt = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader)) # replace
  380. opt.cfg, opt.weights, opt.resume = '', ckpt, True
  381. logger.info('Resuming training from %s' % ckpt)
  382. else:
  383. # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
  384. opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
  385. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  386. opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
  387. log_dir = increment_dir(Path(opt.logdir) / 'exp', opt.name) # runs/exp1
  388. # DDP mode
  389. device = select_device(opt.device, batch_size=opt.batch_size)
  390. if opt.local_rank != -1:
  391. assert torch.cuda.device_count() > opt.local_rank
  392. torch.cuda.set_device(opt.local_rank)
  393. device = torch.device('cuda', opt.local_rank)
  394. dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
  395. assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
  396. opt.batch_size = opt.total_batch_size // opt.world_size
  397. # Hyperparameters
  398. with open(opt.hyp) as f:
  399. hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
  400. if 'box' not in hyp:
  401. warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' %
  402. (opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120'))
  403. hyp['box'] = hyp.pop('giou')
  404. # Train
  405. logger.info(opt)
  406. if not opt.evolve:
  407. tb_writer, wandb = None, None # init loggers
  408. if opt.global_rank in [-1, 0]:
  409. # Tensorboard
  410. logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.logdir}", view at http://localhost:6006/')
  411. tb_writer = SummaryWriter(log_dir=log_dir) # runs/train/exp0
  412. # W&B
  413. try:
  414. import wandb
  415. assert os.environ.get('WANDB_DISABLED') != 'true'
  416. except (ImportError, AssertionError):
  417. opt.log_imgs = 0
  418. logger.info("Install Weights & Biases for experiment logging via 'pip install wandb' (recommended)")
  419. train(hyp, opt, device, tb_writer, wandb)
  420. # Evolve hyperparameters (optional)
  421. else:
  422. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  423. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  424. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  425. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  426. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  427. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  428. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  429. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  430. 'box': (1, 0.02, 0.2), # box loss gain
  431. 'cls': (1, 0.2, 4.0), # cls loss gain
  432. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  433. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  434. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  435. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  436. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  437. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  438. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  439. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  440. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  441. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  442. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  443. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  444. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  445. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  446. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  447. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  448. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  449. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  450. 'mixup': (1, 0.0, 1.0)} # image mixup (probability)
  451. assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
  452. opt.notest, opt.nosave = True, True # only test/save final epoch
  453. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  454. yaml_file = Path(opt.logdir) / 'evolve' / 'hyp_evolved.yaml' # save best result here
  455. if opt.bucket:
  456. os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
  457. for _ in range(300): # generations to evolve
  458. if Path('evolve.txt').exists(): # if evolve.txt exists: select best hyps and mutate
  459. # Select parent(s)
  460. parent = 'single' # parent selection method: 'single' or 'weighted'
  461. x = np.loadtxt('evolve.txt', ndmin=2)
  462. n = min(5, len(x)) # number of previous results to consider
  463. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  464. w = fitness(x) - fitness(x).min() # weights
  465. if parent == 'single' or len(x) == 1:
  466. # x = x[random.randint(0, n - 1)] # random selection
  467. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  468. elif parent == 'weighted':
  469. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  470. # Mutate
  471. mp, s = 0.8, 0.2 # mutation probability, sigma
  472. npr = np.random
  473. npr.seed(int(time.time()))
  474. g = np.array([x[0] for x in meta.values()]) # gains 0-1
  475. ng = len(meta)
  476. v = np.ones(ng)
  477. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  478. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  479. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  480. hyp[k] = float(x[i + 7] * v[i]) # mutate
  481. # Constrain to limits
  482. for k, v in meta.items():
  483. hyp[k] = max(hyp[k], v[1]) # lower limit
  484. hyp[k] = min(hyp[k], v[2]) # upper limit
  485. hyp[k] = round(hyp[k], 5) # significant digits
  486. # Train mutation
  487. results = train(hyp.copy(), opt, device)
  488. # Write mutation results
  489. print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
  490. # Plot results
  491. plot_evolution(yaml_file)
  492. print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
  493. f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')