You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

609 lines
31KB

  1. import argparse
  2. import logging
  3. import math
  4. import os
  5. import random
  6. import time
  7. from pathlib import Path
  8. from threading import Thread
  9. import numpy as np
  10. import torch.distributed as dist
  11. import torch.nn as nn
  12. import torch.nn.functional as F
  13. import torch.optim as optim
  14. import torch.optim.lr_scheduler as lr_scheduler
  15. import torch.utils.data
  16. import yaml
  17. from torch.cuda import amp
  18. from torch.nn.parallel import DistributedDataParallel as DDP
  19. from torch.utils.tensorboard import SummaryWriter
  20. from tqdm import tqdm
  21. import test # import test.py to get mAP after each epoch
  22. from models.experimental import attempt_load
  23. from models.yolo import Model
  24. from utils.autoanchor import check_anchors
  25. from utils.datasets import create_dataloader
  26. from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
  27. fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
  28. check_requirements, print_mutation, set_logging, one_cycle, colorstr
  29. from utils.google_utils import attempt_download
  30. from utils.loss import ComputeLoss
  31. from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
  32. from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first
  33. logger = logging.getLogger(__name__)
  34. def train(hyp, opt, device, tb_writer=None, wandb=None):
  35. logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
  36. save_dir, epochs, batch_size, total_batch_size, weights, rank = \
  37. Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
  38. # Directories
  39. wdir = save_dir / 'weights'
  40. wdir.mkdir(parents=True, exist_ok=True) # make dir
  41. last = wdir / 'last.pt'
  42. best = wdir / 'best.pt'
  43. results_file = save_dir / 'results.txt'
  44. # Save run settings
  45. with open(save_dir / 'hyp.yaml', 'w') as f:
  46. yaml.dump(hyp, f, sort_keys=False)
  47. with open(save_dir / 'opt.yaml', 'w') as f:
  48. yaml.dump(vars(opt), f, sort_keys=False)
  49. # Configure
  50. plots = not opt.evolve # create plots
  51. cuda = device.type != 'cpu'
  52. init_seeds(2 + rank)
  53. with open(opt.data) as f:
  54. data_dict = yaml.load(f, Loader=yaml.SafeLoader) # data dict
  55. with torch_distributed_zero_first(rank):
  56. check_dataset(data_dict) # check
  57. train_path = data_dict['train']
  58. test_path = data_dict['val']
  59. nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
  60. names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
  61. assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
  62. # Model
  63. pretrained = weights.endswith('.pt')
  64. if pretrained:
  65. with torch_distributed_zero_first(rank):
  66. attempt_download(weights) # download if not found locally
  67. ckpt = torch.load(weights, map_location=device) # load checkpoint
  68. if hyp.get('anchors'):
  69. ckpt['model'].yaml['anchors'] = round(hyp['anchors']) # force autoanchor
  70. model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device) # create
  71. exclude = ['anchor'] if opt.cfg or hyp.get('anchors') else [] # exclude keys
  72. state_dict = ckpt['model'].float().state_dict() # to FP32
  73. state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
  74. model.load_state_dict(state_dict, strict=False) # load
  75. logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
  76. else:
  77. model = Model(opt.cfg, ch=3, nc=nc).to(device) # create
  78. # Freeze
  79. freeze = [] # parameter names to freeze (full or partial)
  80. for k, v in model.named_parameters():
  81. v.requires_grad = True # train all layers
  82. if any(x in k for x in freeze):
  83. print('freezing %s' % k)
  84. v.requires_grad = False
  85. # Optimizer
  86. nbs = 64 # nominal batch size
  87. accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing
  88. hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay
  89. logger.info(f"Scaled weight_decay = {hyp['weight_decay']}")
  90. pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
  91. for k, v in model.named_modules():
  92. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
  93. pg2.append(v.bias) # biases
  94. if isinstance(v, nn.BatchNorm2d):
  95. pg0.append(v.weight) # no decay
  96. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
  97. pg1.append(v.weight) # apply decay
  98. if opt.adam:
  99. optimizer = optim.Adam(pg0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999)) # adjust beta1 to momentum
  100. else:
  101. optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
  102. optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']}) # add pg1 with weight_decay
  103. optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
  104. logger.info('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
  105. del pg0, pg1, pg2
  106. # Scheduler https://arxiv.org/pdf/1812.01187.pdf
  107. # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
  108. if opt.linear_lr:
  109. lf = lambda x: (1 - x / (epochs - 1)) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
  110. else:
  111. lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
  112. scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
  113. # plot_lr_scheduler(optimizer, scheduler, epochs)
  114. # Logging
  115. if rank in [-1, 0] and wandb and wandb.run is None:
  116. opt.hyp = hyp # add hyperparameters
  117. wandb_run = wandb.init(config=opt, resume="allow",
  118. project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
  119. name=save_dir.stem,
  120. id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
  121. loggers = {'wandb': wandb} # loggers dict
  122. # Resume
  123. start_epoch, best_fitness = 0, 0.0
  124. if pretrained:
  125. # Optimizer
  126. if ckpt['optimizer'] is not None:
  127. optimizer.load_state_dict(ckpt['optimizer'])
  128. best_fitness = ckpt['best_fitness']
  129. # Results
  130. if ckpt.get('training_results') is not None:
  131. with open(results_file, 'w') as file:
  132. file.write(ckpt['training_results']) # write results.txt
  133. # Epochs
  134. start_epoch = ckpt['epoch'] + 1
  135. if opt.resume:
  136. assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
  137. if epochs < start_epoch:
  138. logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
  139. (weights, ckpt['epoch'], epochs))
  140. epochs += ckpt['epoch'] # finetune additional epochs
  141. del ckpt, state_dict
  142. # Image sizes
  143. gs = max(int(model.stride.max()), 32) # grid size (max stride)
  144. nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj'])
  145. imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
  146. # DP mode
  147. if cuda and rank == -1 and torch.cuda.device_count() > 1:
  148. model = torch.nn.DataParallel(model)
  149. # SyncBatchNorm
  150. if opt.sync_bn and cuda and rank != -1:
  151. model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
  152. logger.info('Using SyncBatchNorm()')
  153. # EMA
  154. ema = ModelEMA(model) if rank in [-1, 0] else None
  155. # DDP mode
  156. if cuda and rank != -1:
  157. model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
  158. # Trainloader
  159. dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
  160. hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
  161. world_size=opt.world_size, workers=opt.workers,
  162. image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
  163. mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
  164. nb = len(dataloader) # number of batches
  165. assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
  166. # Process 0
  167. if rank in [-1, 0]:
  168. ema.updates = start_epoch * nb // accumulate # set EMA updates
  169. testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt, # testloader
  170. hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
  171. world_size=opt.world_size, workers=opt.workers,
  172. pad=0.5, prefix=colorstr('val: '))[0]
  173. if not opt.resume:
  174. labels = np.concatenate(dataset.labels, 0)
  175. c = torch.tensor(labels[:, 0]) # classes
  176. # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
  177. # model._initialize_biases(cf.to(device))
  178. if plots:
  179. plot_labels(labels, save_dir, loggers)
  180. if tb_writer:
  181. tb_writer.add_histogram('classes', c, 0)
  182. # Anchors
  183. if not opt.noautoanchor:
  184. check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
  185. # Model parameters
  186. hyp['box'] *= 3. / nl # scale to layers
  187. hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers
  188. hyp['obj'] *= (imgsz / 640) ** 2 * 3. / nl # scale to image size and layers
  189. model.nc = nc # attach number of classes to model
  190. model.hyp = hyp # attach hyperparameters to model
  191. model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou)
  192. model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
  193. model.names = names
  194. # Start training
  195. t0 = time.time()
  196. nw = max(round(hyp['warmup_epochs'] * nb), 1000) # number of warmup iterations, max(3 epochs, 1k iterations)
  197. # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
  198. maps = np.zeros(nc) # mAP per class
  199. results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  200. scheduler.last_epoch = start_epoch - 1 # do not move
  201. scaler = amp.GradScaler(enabled=cuda)
  202. compute_loss = ComputeLoss(model) # init loss class
  203. logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
  204. f'Using {dataloader.num_workers} dataloader workers\n'
  205. f'Logging results to {save_dir}\n'
  206. f'Starting training for {epochs} epochs...')
  207. for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
  208. model.train()
  209. # Update image weights (optional)
  210. if opt.image_weights:
  211. # Generate indices
  212. if rank in [-1, 0]:
  213. cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
  214. iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
  215. dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
  216. # Broadcast if DDP
  217. if rank != -1:
  218. indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
  219. dist.broadcast(indices, 0)
  220. if rank != 0:
  221. dataset.indices = indices.cpu().numpy()
  222. # Update mosaic border
  223. # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
  224. # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
  225. mloss = torch.zeros(4, device=device) # mean losses
  226. if rank != -1:
  227. dataloader.sampler.set_epoch(epoch)
  228. pbar = enumerate(dataloader)
  229. logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'targets', 'img_size'))
  230. if rank in [-1, 0]:
  231. pbar = tqdm(pbar, total=nb) # progress bar
  232. optimizer.zero_grad()
  233. for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
  234. ni = i + nb * epoch # number integrated batches (since train start)
  235. imgs = imgs.to(device, non_blocking=True).float() / 255.0 # uint8 to float32, 0-255 to 0.0-1.0
  236. # Warmup
  237. if ni <= nw:
  238. xi = [0, nw] # x interp
  239. # model.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
  240. accumulate = max(1, np.interp(ni, xi, [1, nbs / total_batch_size]).round())
  241. for j, x in enumerate(optimizer.param_groups):
  242. # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
  243. x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
  244. if 'momentum' in x:
  245. x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
  246. # Multi-scale
  247. if opt.multi_scale:
  248. sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
  249. sf = sz / max(imgs.shape[2:]) # scale factor
  250. if sf != 1:
  251. ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
  252. imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
  253. # Forward
  254. with amp.autocast(enabled=cuda):
  255. pred = model(imgs) # forward
  256. loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
  257. if rank != -1:
  258. loss *= opt.world_size # gradient averaged between devices in DDP mode
  259. if opt.quad:
  260. loss *= 4.
  261. # Backward
  262. scaler.scale(loss).backward()
  263. # Optimize
  264. if ni % accumulate == 0:
  265. scaler.step(optimizer) # optimizer.step
  266. scaler.update()
  267. optimizer.zero_grad()
  268. if ema:
  269. ema.update(model)
  270. # Print
  271. if rank in [-1, 0]:
  272. mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
  273. mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
  274. s = ('%10s' * 2 + '%10.4g' * 6) % (
  275. '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
  276. pbar.set_description(s)
  277. # Plot
  278. if plots and ni < 3:
  279. f = save_dir / f'train_batch{ni}.jpg' # filename
  280. Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
  281. # if tb_writer:
  282. # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
  283. # tb_writer.add_graph(model, imgs) # add model to tensorboard
  284. elif plots and ni == 10 and wandb:
  285. wandb.log({"Mosaics": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('train*.jpg')
  286. if x.exists()]}, commit=False)
  287. # end batch ------------------------------------------------------------------------------------------------
  288. # end epoch ----------------------------------------------------------------------------------------------------
  289. # Scheduler
  290. lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard
  291. scheduler.step()
  292. # DDP process 0 or single-GPU
  293. if rank in [-1, 0]:
  294. # mAP
  295. if ema:
  296. ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
  297. final_epoch = epoch + 1 == epochs
  298. if not opt.notest or final_epoch: # Calculate mAP
  299. results, maps, times = test.test(opt.data,
  300. batch_size=batch_size * 2,
  301. imgsz=imgsz_test,
  302. model=ema.ema,
  303. single_cls=opt.single_cls,
  304. dataloader=testloader,
  305. save_dir=save_dir,
  306. verbose=nc < 50 and final_epoch,
  307. plots=plots and final_epoch,
  308. log_imgs=opt.log_imgs if wandb else 0,
  309. compute_loss=compute_loss)
  310. # Write
  311. with open(results_file, 'a') as f:
  312. f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
  313. if len(opt.name) and opt.bucket:
  314. os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
  315. # Log
  316. tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
  317. 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
  318. 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
  319. 'x/lr0', 'x/lr1', 'x/lr2'] # params
  320. for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
  321. if tb_writer:
  322. tb_writer.add_scalar(tag, x, epoch) # tensorboard
  323. if wandb:
  324. wandb.log({tag: x}, step=epoch, commit=tag == tags[-1]) # W&B
  325. # Update best mAP
  326. fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
  327. if fi > best_fitness:
  328. best_fitness = fi
  329. # Save model
  330. save = (not opt.nosave) or (final_epoch and not opt.evolve)
  331. if save:
  332. with open(results_file, 'r') as f: # create checkpoint
  333. ckpt = {'epoch': epoch,
  334. 'best_fitness': best_fitness,
  335. 'training_results': f.read(),
  336. 'model': ema.ema,
  337. 'optimizer': None if final_epoch else optimizer.state_dict(),
  338. 'wandb_id': wandb_run.id if wandb else None}
  339. # Save last, best and delete
  340. torch.save(ckpt, last)
  341. if best_fitness == fi:
  342. torch.save(ckpt, best)
  343. del ckpt
  344. # end epoch ----------------------------------------------------------------------------------------------------
  345. # end training
  346. if rank in [-1, 0]:
  347. # Strip optimizers
  348. final = best if best.exists() else last # final model
  349. for f in [last, best]:
  350. if f.exists():
  351. strip_optimizer(f) # strip optimizers
  352. if opt.bucket:
  353. os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload
  354. # Plots
  355. if plots:
  356. plot_results(save_dir=save_dir) # save as results.png
  357. if wandb:
  358. files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
  359. wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
  360. if (save_dir / f).exists()]})
  361. if opt.log_artifacts:
  362. wandb.log_artifact(artifact_or_path=str(final), type='model', name=save_dir.stem)
  363. # Test best.pt
  364. logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
  365. if opt.data.endswith('coco.yaml') and nc == 80: # if COCO
  366. for conf, iou, save_json in ([0.25, 0.45, False], [0.001, 0.65, True]): # speed, mAP tests
  367. results, _, _ = test.test(opt.data,
  368. batch_size=batch_size * 2,
  369. imgsz=imgsz_test,
  370. conf_thres=conf,
  371. iou_thres=iou,
  372. model=attempt_load(final, device).half(),
  373. single_cls=opt.single_cls,
  374. dataloader=testloader,
  375. save_dir=save_dir,
  376. save_json=save_json,
  377. plots=False)
  378. else:
  379. dist.destroy_process_group()
  380. wandb.run.finish() if wandb and wandb.run else None
  381. torch.cuda.empty_cache()
  382. return results
  383. if __name__ == '__main__':
  384. parser = argparse.ArgumentParser()
  385. parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
  386. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  387. parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path')
  388. parser.add_argument('--hyp', type=str, default='data/hyp.scratch.yaml', help='hyperparameters path')
  389. parser.add_argument('--epochs', type=int, default=300)
  390. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
  391. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes')
  392. parser.add_argument('--rect', action='store_true', help='rectangular training')
  393. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  394. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  395. parser.add_argument('--notest', action='store_true', help='only test final epoch')
  396. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
  397. parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
  398. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  399. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
  400. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  401. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  402. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  403. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
  404. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
  405. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  406. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
  407. parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100')
  408. parser.add_argument('--log-artifacts', action='store_true', help='log artifacts, i.e. final trained model')
  409. parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
  410. parser.add_argument('--project', default='runs/train', help='save to project/name')
  411. parser.add_argument('--name', default='exp', help='save to project/name')
  412. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  413. parser.add_argument('--quad', action='store_true', help='quad dataloader')
  414. parser.add_argument('--linear-lr', action='store_true', help='linear LR')
  415. opt = parser.parse_args()
  416. # Set DDP variables
  417. opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
  418. opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
  419. set_logging(opt.global_rank)
  420. if opt.global_rank in [-1, 0]:
  421. check_git_status()
  422. check_requirements()
  423. # Resume
  424. if opt.resume: # resume an interrupted run
  425. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  426. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  427. apriori = opt.global_rank, opt.local_rank
  428. with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
  429. opt = argparse.Namespace(**yaml.load(f, Loader=yaml.SafeLoader)) # replace
  430. opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = '', ckpt, True, opt.total_batch_size, *apriori # reinstate
  431. logger.info('Resuming training from %s' % ckpt)
  432. else:
  433. # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
  434. opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
  435. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  436. opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
  437. opt.name = 'evolve' if opt.evolve else opt.name
  438. opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve) # increment run
  439. # DDP mode
  440. opt.total_batch_size = opt.batch_size
  441. device = select_device(opt.device, batch_size=opt.batch_size)
  442. if opt.local_rank != -1:
  443. assert torch.cuda.device_count() > opt.local_rank
  444. torch.cuda.set_device(opt.local_rank)
  445. device = torch.device('cuda', opt.local_rank)
  446. dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
  447. assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
  448. opt.batch_size = opt.total_batch_size // opt.world_size
  449. # Hyperparameters
  450. with open(opt.hyp) as f:
  451. hyp = yaml.load(f, Loader=yaml.SafeLoader) # load hyps
  452. # Train
  453. logger.info(opt)
  454. try:
  455. import wandb
  456. except ImportError:
  457. wandb = None
  458. prefix = colorstr('wandb: ')
  459. logger.info(f"{prefix}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
  460. if not opt.evolve:
  461. tb_writer = None # init loggers
  462. if opt.global_rank in [-1, 0]:
  463. logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/')
  464. tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
  465. train(hyp, opt, device, tb_writer, wandb)
  466. # Evolve hyperparameters (optional)
  467. else:
  468. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  469. meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  470. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  471. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  472. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  473. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  474. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  475. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  476. 'box': (1, 0.02, 0.2), # box loss gain
  477. 'cls': (1, 0.2, 4.0), # cls loss gain
  478. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  479. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  480. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  481. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  482. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  483. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  484. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  485. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  486. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  487. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  488. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  489. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  490. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  491. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  492. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  493. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  494. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  495. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  496. 'mixup': (1, 0.0, 1.0)} # image mixup (probability)
  497. assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
  498. opt.notest, opt.nosave = True, True # only test/save final epoch
  499. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  500. yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
  501. if opt.bucket:
  502. os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
  503. for _ in range(300): # generations to evolve
  504. if Path('evolve.txt').exists(): # if evolve.txt exists: select best hyps and mutate
  505. # Select parent(s)
  506. parent = 'single' # parent selection method: 'single' or 'weighted'
  507. x = np.loadtxt('evolve.txt', ndmin=2)
  508. n = min(5, len(x)) # number of previous results to consider
  509. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  510. w = fitness(x) - fitness(x).min() # weights
  511. if parent == 'single' or len(x) == 1:
  512. # x = x[random.randint(0, n - 1)] # random selection
  513. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  514. elif parent == 'weighted':
  515. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  516. # Mutate
  517. mp, s = 0.8, 0.2 # mutation probability, sigma
  518. npr = np.random
  519. npr.seed(int(time.time()))
  520. g = np.array([x[0] for x in meta.values()]) # gains 0-1
  521. ng = len(meta)
  522. v = np.ones(ng)
  523. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  524. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  525. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  526. hyp[k] = float(x[i + 7] * v[i]) # mutate
  527. # Constrain to limits
  528. for k, v in meta.items():
  529. hyp[k] = max(hyp[k], v[1]) # lower limit
  530. hyp[k] = min(hyp[k], v[2]) # upper limit
  531. hyp[k] = round(hyp[k], 5) # significant digits
  532. # Train mutation
  533. results = train(hyp.copy(), opt, device, wandb=wandb)
  534. # Write mutation results
  535. print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
  536. # Plot results
  537. plot_evolution(yaml_file)
  538. print(f'Hyperparameter evolution complete. Best results saved as: {yaml_file}\n'
  539. f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')